├── .gitmodules ├── bazel ├── BUILD ├── dependencies.bzl ├── nlohmann_json.BUILD └── rules_def.bzl ├── test ├── __init__.py ├── pjrt │ ├── __init__.py │ ├── args_parse.py │ └── test_dynamic_plugin_tpu.py ├── spmd │ ├── __init__.py │ └── test_xla_sharding_hlo.py ├── utils │ └── __init__.py ├── stablehlo │ └── __init__.py ├── benchmarks │ ├── .gitignore │ ├── a6000.training.latest.empty.test │ ├── a6000.inference.speedup.test │ ├── a6000.training.latest.test │ ├── v100.inference.speedup.test │ ├── v100.inference.latest.tier1.test │ ├── v100.inference.speedup.baseline_latest.test │ ├── v100.inference.histogram.test │ ├── v100.inference.latest_grouped.test │ ├── v100.inference.speedup.lazytensor.test │ ├── v100.inference.latest.openxla_baseline.test │ ├── v100.inference.histogram.lazytensor.test │ ├── v100.inference.latest.test │ ├── test_benchmark_model.py │ ├── Makefile │ ├── v100.inference.speedup.tab.test │ ├── test_benchmark_experiment.py │ ├── v100.inference.histogram.tab.test │ └── v100.inference.speedup.lazytensor_tab.test ├── cpp │ ├── get_coverage.sh │ ├── main.cpp │ ├── test_runtime.cpp │ ├── test_status_dont_show_cpp_stacktraces.cpp │ └── test_status_show_cpp_stacktraces.cpp ├── tpu │ ├── run_expensive_test_2.sh │ ├── run_expensive_test_1.sh │ └── run_pallas_test.sh ├── dynamo │ └── test_dynamo_config.py ├── test_mp_mesh_reduce.py ├── test_torch_distributed_fsdp_frozen_weight.py └── test_mp_collective_permute.py ├── torch_xla ├── py.typed ├── core │ ├── __init__.py │ └── dynamo_bridge.py ├── debug │ ├── __init__.py │ └── graph_saver.py ├── test │ └── __init__.py ├── utils │ ├── __init__.py │ ├── buffer_donor_context.py │ └── dlpack.py ├── _internal │ ├── __init__.py │ ├── utils.py │ ├── decomp_registration.py │ ├── xpu.py │ └── c10d_registration.py ├── distributed │ ├── __init__.py │ └── fsdp │ │ └── __init__.py ├── experimental │ ├── pallas_kernels │ │ └── __init__.py │ ├── dynamo_mark_sharding.py │ ├── __init__.py │ ├── dynamo_set_buffer_donor.py │ ├── distributed_checkpoint │ │ └── __init__.py │ ├── pjrt_backend.py │ └── eager.py ├── csrc │ ├── hash_util.cpp │ ├── runtime │ │ ├── env_vars.cpp │ │ ├── xla_mlir_debuginfo_helper.h │ │ ├── env_hash.h │ │ ├── stablehlo_composite_helper.h │ │ ├── tf_logging.cpp │ │ ├── metrics_reader.h │ │ ├── runtime.h │ │ └── pjrt_registry.h │ ├── version.h │ ├── dl_convertor.h │ ├── aten_fallback.h │ ├── thread_pool.h │ ├── function_call_tracker.h │ ├── shape_helper.cpp │ ├── generated_file_include.h │ ├── xla_backend_impl.h │ ├── ops │ │ ├── multinomial.h │ │ ├── nms.h │ │ ├── normal.h │ │ ├── optimization_barrier.h │ │ ├── bernoulli.h │ │ ├── eigh.h │ │ ├── exponential.h │ │ ├── uniform.h │ │ ├── qr.h │ │ ├── randperm.h │ │ ├── mark_tensor.h │ │ ├── discrete_uniform.h │ │ ├── cast_int4.h │ │ ├── stack.h │ │ ├── cummax.h │ │ ├── gather.h │ │ ├── einsum.h │ │ ├── resize.h │ │ ├── nonzero.h │ │ ├── linspace.h │ │ ├── constant.h │ │ ├── expand.h │ │ ├── squeeze.h │ │ ├── svd.h │ │ ├── put.h │ │ ├── index_select.h │ │ ├── scatter.h │ │ ├── view.h │ │ ├── count_nonzero.h │ │ ├── max_in_dim.h │ │ ├── min_in_dim.h │ │ ├── not_supported.h │ │ ├── scatter_add.h │ │ ├── symeig.h │ │ ├── unsqueeze.h │ │ ├── flip.h │ │ ├── masked_select.h │ │ ├── tpu_custom_call.h │ │ ├── index_get.h │ │ ├── cdist.h │ │ ├── linear_interpolation.h │ │ ├── threshold_backward.h │ │ ├── replication_pad.h │ │ ├── einsum_backward.h │ │ ├── threshold.h │ │ ├── arithmetic_ir_ops.h │ │ ├── mse_loss.h │ │ ├── reflection_pad2d.h │ │ ├── softmax_backward.h │ │ ├── masked_scatter.h │ │ 
├── max_unpool_nd.h │ │ ├── amp_foreach_non_finite_check_and_unscale.h │ │ ├── roll.h │ │ ├── upsample_nearest2d.h │ │ ├── cat.h │ │ ├── get_dimensions_size.h │ │ ├── kth_value.h │ │ ├── adaptive_max_pool2d.h │ │ ├── send.h │ │ ├── update_slice.h │ │ ├── dynamic_expand.h │ │ ├── log_softmax_backward.h │ │ ├── dot_general.h │ │ ├── recv.h │ │ ├── expand_symint.h │ │ ├── cumsum.h │ │ ├── scatter_reduce.h │ │ ├── native_dropout.h │ │ ├── cumprod.h │ │ ├── dynamic_view.h │ │ ├── softmax.h │ │ ├── amp_update_scale.h │ │ ├── constant_pad_nd.h │ │ ├── split.h │ │ ├── embedding_bag.h │ │ ├── infer_output_shape.h │ │ ├── logsumexp.h │ │ ├── mse_loss_backward.h │ │ ├── permute.h │ │ ├── replication_pad_backward.h │ │ ├── reflection_pad2d_backward.h │ │ ├── quant_tensor.h │ │ ├── upsample_bilinear2d.h │ │ ├── user_computation.h │ │ ├── unselect.h │ │ ├── hardtanh_backward.h │ │ ├── dequant_tensor.h │ │ ├── topk.h │ │ ├── generic_slice.h │ │ ├── index_put.h │ │ ├── diagonal_view_update.h │ │ ├── collective_permute.h │ │ ├── log_softmax.h │ │ ├── nll_loss.h │ │ ├── not_supported.cpp │ │ ├── diagonal.h │ │ ├── flip.cpp │ │ ├── nll_loss2d.h │ │ ├── std.h │ │ ├── var.h │ │ ├── custom_sharding.h │ │ └── std_mean.h │ ├── token_handler.h │ ├── thread_pool.cpp │ ├── matrix.h │ ├── shape_helper.h │ ├── softmax_builder.h │ ├── unwrap_data.h │ ├── tensor_common.h │ ├── dtype.h │ ├── random.h │ ├── shape_builder.h │ ├── nll_loss.h │ └── xla_op_builder.h ├── _dynamo │ ├── __init__.py │ └── config.py └── amp │ ├── __init__.py │ └── syncfree │ └── __init__.py ├── .bazelversion ├── openxla_patches └── BUILD ├── benchmarks ├── __init__.py ├── requirements.txt ├── check_xla_device.py └── patches │ └── mismatched_batch_size.patch ├── external ├── docs ├── .gitattributes ├── _static │ └── img │ │ ├── image-1.png │ │ ├── image-2.png │ │ ├── image-3.png │ │ ├── image-4.png │ │ ├── image.png │ │ ├── spmd_mode.png │ │ ├── mesh_spmd2.png │ │ ├── spmd_debug_1.png │ │ ├── spmd_debug_2.png │ │ ├── IRgraph_markstep.png │ │ ├── llama2_2b_bsz128.png │ │ ├── IRgraph_no_markstep.png │ │ ├── ci_test_dependency.png │ │ ├── gpt2_v4_8_mfu_batch.png │ │ ├── perf_auto_vs_manual.png │ │ ├── spmd_debug_1_light.png │ │ ├── spmd_debug_2_light.png │ │ ├── ci_test_dependency_gpu.png │ │ ├── dynamic_shape_mlp_perf.png │ │ ├── ddp_md_mnist_with_real_data.png │ │ └── gpt2_2b_step_time_vs_batch.png ├── source │ └── _static │ │ └── img │ │ ├── image.png │ │ ├── image-1.png │ │ ├── image-2.png │ │ ├── image-3.png │ │ ├── image-4.png │ │ ├── mesh_spmd2.png │ │ ├── spmd_mode.png │ │ ├── spmd_debug_1.png │ │ ├── spmd_debug_2.png │ │ ├── debugger0_pack.png │ │ ├── debugger1_file.png │ │ ├── debugger5_break.png │ │ ├── dist_op_stack.png │ │ ├── IRgraph_markstep.png │ │ ├── debugger3_session.png │ │ ├── debugger4_active.png │ │ ├── llama2_2b_bsz128.png │ │ ├── IRgraph_no_markstep.png │ │ ├── ci_test_dependency.png │ │ ├── debugger2_breakpoint.png │ │ ├── gpt2_v4_8_mfu_batch.png │ │ ├── perf_auto_vs_manual.png │ │ ├── spmd_debug_1_light.png │ │ ├── spmd_debug_2_light.png │ │ ├── ci_test_dependency_gpu.png │ │ ├── dynamic_shape_mlp_perf.png │ │ ├── ddp_md_mnist_with_real_data.png │ │ └── gpt2_2b_step_time_vs_batch.png ├── docs_build.sh └── requirements.txt ├── examples ├── scan │ ├── README.md │ └── decoder_with_scan.py ├── fsdp │ └── README.md ├── data_parallel │ ├── README.md │ └── train_resnet_xla_ddp.py └── eager │ ├── train_decoder_only_eager_with_compile.py │ ├── train_decoder_only_eager.py │ └── train_decoder_only_eager_multi_process.py ├── 
.torch_commit ├── infra ├── ansible │ ├── roles │ │ ├── bazel │ │ │ ├── defaults │ │ │ │ └── main.yaml │ │ │ └── tasks │ │ │ │ ├── tests.yaml │ │ │ │ └── main.yaml │ │ ├── fetch_srcs │ │ │ ├── defaults │ │ │ │ └── main.yaml │ │ │ └── tasks │ │ │ │ ├── tests.yaml │ │ │ │ └── main.yaml │ │ ├── build_srcs │ │ │ └── tasks │ │ │ │ └── tests.yaml │ │ └── configure_env │ │ │ └── tasks │ │ │ └── main.yaml │ ├── .ansible-lint │ ├── config │ │ └── vars.yaml │ ├── development.Dockerfile │ └── ansible.cfg ├── tpu-pytorch │ ├── README.md │ ├── infra_triggers.tf │ ├── iam.auto.tfvars │ ├── provider.tf │ ├── tpu_ci.tf │ └── misc.tf ├── tpu-pytorch-releases │ ├── dev_images.auto.tfvars │ ├── infra_triggers.tf │ ├── iam.auto.tfvars │ └── provider.tf └── terraform_modules │ ├── arc_v4_container_cluster │ ├── README.md │ └── arc-values.yaml │ ├── worker_pool │ └── worker_pool.tf │ └── trigger_schedule_account │ └── service_account.tf ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── questions-help-support.md │ ├── documentation.md │ └── feature-request.md ├── upstream │ └── install_valgrind.sh ├── ISSUE_TEMPLATE.md └── stale.yml ├── experimental └── reference_models │ ├── sdxl_inference │ ├── README.md │ ├── astronaut_rides_horse.png │ └── sdxl_beginning.py │ └── README.md ├── .clangd ├── docker ├── gcb_pool.yaml ├── debug_image_cleanup.sh └── docker-entrypoint.sh ├── requirements.in ├── plugins └── cpu │ ├── pjrt_c_api_cpu_version_script.lds │ ├── test_cpu_plugin.h │ ├── setup.py │ ├── torch_xla_cpu_plugin │ └── __init__.py │ ├── pyproject.toml │ ├── BUILD │ └── test_cpu_plugin.cpp ├── scripts ├── update_nightly_torch_wheels.sh ├── tf_log_filter.py ├── normalize_graph_text.py ├── update_torch_wheels.sh ├── dump_stacks.py ├── metrics_to_tensorboard.py └── run_bazel_coverage.sh ├── codegen ├── BUILD └── fix_includes.sh ├── .gitignore ├── .circleci ├── test_xrt.sh └── test.sh ├── .devcontainer ├── tpu-contributor │ └── devcontainer.json └── tpu-internal │ └── devcontainer.json └── contrib └── vscode └── settings.json /.gitmodules: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bazel/BUILD: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_xla/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.bazelversion: -------------------------------------------------------------------------------- 1 | 7.4.1 2 | -------------------------------------------------------------------------------- /openxla_patches/BUILD: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/pjrt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/spmd/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /external: -------------------------------------------------------------------------------- 1 | bazel-xla/external -------------------------------------------------------------------------------- /test/stablehlo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_xla/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_xla/debug/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_xla/test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_xla/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_xla/_internal/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_xla/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/benchmarks/.gitignore: -------------------------------------------------------------------------------- 1 | *.tmp 2 | -------------------------------------------------------------------------------- /test/pjrt/args_parse.py: -------------------------------------------------------------------------------- 1 | ../args_parse.py -------------------------------------------------------------------------------- /bazel/dependencies.bzl: -------------------------------------------------------------------------------- 1 | PYTORCH_LOCAL_DIR = "../" 2 | -------------------------------------------------------------------------------- /torch_xla/experimental/pallas_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb diff=none merge=binary 2 | -------------------------------------------------------------------------------- /examples/scan/README.md: -------------------------------------------------------------------------------- 1 | ../../docs/source/features/scan.md -------------------------------------------------------------------------------- /benchmarks/requirements.txt: -------------------------------------------------------------------------------- 1 | tabulate 2 | scipy 3 | pandas 4 | -------------------------------------------------------------------------------- /.torch_commit: 
-------------------------------------------------------------------------------- 1 | # 2025-09-29 2 | 21fec65781bebe867faf209f89bb687ffd236ca4 -------------------------------------------------------------------------------- /infra/ansible/roles/bazel/defaults/main.yaml: -------------------------------------------------------------------------------- 1 | bazelisk_version: 1.15.0 2 | -------------------------------------------------------------------------------- /torch_xla/csrc/hash_util.cpp: -------------------------------------------------------------------------------- 1 | #include "torch_xla/csrc/hash_util.h" 2 | -------------------------------------------------------------------------------- /torch_xla/_dynamo/__init__.py: -------------------------------------------------------------------------------- 1 | import torch_xla._dynamo.config as config 2 | -------------------------------------------------------------------------------- /torch_xla/amp/__init__.py: -------------------------------------------------------------------------------- 1 | from .autocast_mode import autocast # noqa: F401 2 | -------------------------------------------------------------------------------- /test/cpp/get_coverage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "$PATH" 3 | exec llvm-cov gcov "$@" 4 | -------------------------------------------------------------------------------- /torch_xla/csrc/runtime/env_vars.cpp: -------------------------------------------------------------------------------- 1 | #include "torch_xla/csrc/runtime/env_vars.h" 2 | -------------------------------------------------------------------------------- /test/benchmarks/a6000.training.latest.empty.test: -------------------------------------------------------------------------------- 1 | # ARGS: --backends openxla+lazytensor -- 2 | -------------------------------------------------------------------------------- /docs/_static/img/image-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/image-1.png -------------------------------------------------------------------------------- /docs/_static/img/image-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/image-2.png -------------------------------------------------------------------------------- /docs/_static/img/image-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/image-3.png -------------------------------------------------------------------------------- /docs/_static/img/image-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/image-4.png -------------------------------------------------------------------------------- /docs/_static/img/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/image.png -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | /infra @bhavya01 @qihqi @zhanyong-wan 2 | /docs @mikegre-google @qihqi @zhanyong-wan 3 | 
-------------------------------------------------------------------------------- /docs/_static/img/spmd_mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/spmd_mode.png -------------------------------------------------------------------------------- /experimental/reference_models/sdxl_inference/README.md: -------------------------------------------------------------------------------- 1 | # How to run: 2 | 3 | ``` 4 | python sdxl.py 5 | ``` -------------------------------------------------------------------------------- /docs/_static/img/mesh_spmd2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/mesh_spmd2.png -------------------------------------------------------------------------------- /docs/_static/img/spmd_debug_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/spmd_debug_1.png -------------------------------------------------------------------------------- /docs/_static/img/spmd_debug_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/spmd_debug_2.png -------------------------------------------------------------------------------- /docs/source/_static/img/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/image.png -------------------------------------------------------------------------------- /docs/source/_static/img/image-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/image-1.png -------------------------------------------------------------------------------- /docs/source/_static/img/image-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/image-2.png -------------------------------------------------------------------------------- /docs/source/_static/img/image-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/image-3.png -------------------------------------------------------------------------------- /docs/source/_static/img/image-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/image-4.png -------------------------------------------------------------------------------- /infra/ansible/.ansible-lint: -------------------------------------------------------------------------------- 1 | --- 2 | # .ansible-lint 3 | 4 | profile: moderate 5 | skip_list: 6 | - schema[tasks] -------------------------------------------------------------------------------- /torch_xla/amp/syncfree/__init__.py: -------------------------------------------------------------------------------- 1 | from .adam import Adam 2 | from .adamw import AdamW 3 | from .sgd import SGD 4 | -------------------------------------------------------------------------------- /.clangd: 
-------------------------------------------------------------------------------- 1 | CompileFlags: 2 | CompilationDatabase: build # Specifies that compile_commands.json is in this directory. 3 | -------------------------------------------------------------------------------- /docs/_static/img/IRgraph_markstep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/IRgraph_markstep.png -------------------------------------------------------------------------------- /docs/_static/img/llama2_2b_bsz128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/llama2_2b_bsz128.png -------------------------------------------------------------------------------- /docs/source/_static/img/mesh_spmd2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/mesh_spmd2.png -------------------------------------------------------------------------------- /docs/source/_static/img/spmd_mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/spmd_mode.png -------------------------------------------------------------------------------- /docker/gcb_pool.yaml: -------------------------------------------------------------------------------- 1 | privatePoolV1Config: 2 | workerConfig: 3 | diskSizeGb: '500' 4 | machineType: e2-standard-32 5 | -------------------------------------------------------------------------------- /docs/_static/img/IRgraph_no_markstep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/IRgraph_no_markstep.png -------------------------------------------------------------------------------- /docs/_static/img/ci_test_dependency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/ci_test_dependency.png -------------------------------------------------------------------------------- /docs/_static/img/gpt2_v4_8_mfu_batch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/gpt2_v4_8_mfu_batch.png -------------------------------------------------------------------------------- /docs/_static/img/perf_auto_vs_manual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/perf_auto_vs_manual.png -------------------------------------------------------------------------------- /docs/_static/img/spmd_debug_1_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/spmd_debug_1_light.png -------------------------------------------------------------------------------- /docs/_static/img/spmd_debug_2_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/spmd_debug_2_light.png -------------------------------------------------------------------------------- /docs/source/_static/img/spmd_debug_1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/spmd_debug_1.png -------------------------------------------------------------------------------- /docs/source/_static/img/spmd_debug_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/spmd_debug_2.png -------------------------------------------------------------------------------- /torch_xla/core/dynamo_bridge.py: -------------------------------------------------------------------------------- 1 | # TODO(JackCaoG): remove after updated upstream 2 | from torch_xla._dynamo.dynamo_bridge import * -------------------------------------------------------------------------------- /docs/_static/img/ci_test_dependency_gpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/ci_test_dependency_gpu.png -------------------------------------------------------------------------------- /docs/_static/img/dynamic_shape_mlp_perf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/dynamic_shape_mlp_perf.png -------------------------------------------------------------------------------- /docs/source/_static/img/debugger0_pack.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/debugger0_pack.png -------------------------------------------------------------------------------- /docs/source/_static/img/debugger1_file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/debugger1_file.png -------------------------------------------------------------------------------- /docs/source/_static/img/debugger5_break.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/debugger5_break.png -------------------------------------------------------------------------------- /docs/source/_static/img/dist_op_stack.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/dist_op_stack.png -------------------------------------------------------------------------------- /docs/source/_static/img/IRgraph_markstep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/IRgraph_markstep.png -------------------------------------------------------------------------------- /docs/source/_static/img/debugger3_session.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/debugger3_session.png -------------------------------------------------------------------------------- /docs/source/_static/img/debugger4_active.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/debugger4_active.png 
-------------------------------------------------------------------------------- /docs/source/_static/img/llama2_2b_bsz128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/llama2_2b_bsz128.png -------------------------------------------------------------------------------- /docs/_static/img/ddp_md_mnist_with_real_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/ddp_md_mnist_with_real_data.png -------------------------------------------------------------------------------- /docs/_static/img/gpt2_2b_step_time_vs_batch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/_static/img/gpt2_2b_step_time_vs_batch.png -------------------------------------------------------------------------------- /docs/source/_static/img/IRgraph_no_markstep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/IRgraph_no_markstep.png -------------------------------------------------------------------------------- /docs/source/_static/img/ci_test_dependency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/ci_test_dependency.png -------------------------------------------------------------------------------- /docs/source/_static/img/debugger2_breakpoint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/debugger2_breakpoint.png -------------------------------------------------------------------------------- /docs/source/_static/img/gpt2_v4_8_mfu_batch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/gpt2_v4_8_mfu_batch.png -------------------------------------------------------------------------------- /docs/source/_static/img/perf_auto_vs_manual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/perf_auto_vs_manual.png -------------------------------------------------------------------------------- /docs/source/_static/img/spmd_debug_1_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/spmd_debug_1_light.png -------------------------------------------------------------------------------- /docs/source/_static/img/spmd_debug_2_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/spmd_debug_2_light.png -------------------------------------------------------------------------------- /docs/source/_static/img/ci_test_dependency_gpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/ci_test_dependency_gpu.png -------------------------------------------------------------------------------- /docs/source/_static/img/dynamic_shape_mlp_perf.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/dynamic_shape_mlp_perf.png -------------------------------------------------------------------------------- /infra/ansible/roles/bazel/tasks/tests.yaml: -------------------------------------------------------------------------------- 1 | - name: "Bazel --version runs successfully" 2 | ansible.builtin.command: 3 | cmd: bazel --version 4 | -------------------------------------------------------------------------------- /docs/source/_static/img/ddp_md_mnist_with_real_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/ddp_md_mnist_with_real_data.png -------------------------------------------------------------------------------- /docs/source/_static/img/gpt2_2b_step_time_vs_batch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/docs/source/_static/img/gpt2_2b_step_time_vs_batch.png -------------------------------------------------------------------------------- /docs/docs_build.sh: -------------------------------------------------------------------------------- 1 | # Installs requirements and builds HTML version of PyTorch/XLA docs. 2 | pip install -r requirements.txt 3 | sphinx-build -b html source build -------------------------------------------------------------------------------- /infra/tpu-pytorch/README.md: -------------------------------------------------------------------------------- 1 | # Terraform setup for pytorch-xla GCP project 2 | 3 | This setup configures: 4 | * Private docker repository (`docker`). 5 | * CI test trigger. -------------------------------------------------------------------------------- /examples/fsdp/README.md: -------------------------------------------------------------------------------- 1 | ## Recommendation 2 | Please consider using `train_decoder_only_fsdp_v2.py` since it uses SPMD internally and is very likely to yield better performance. 3 | -------------------------------------------------------------------------------- /experimental/reference_models/sdxl_inference/astronaut_rides_horse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/HEAD/experimental/reference_models/sdxl_inference/astronaut_rides_horse.png -------------------------------------------------------------------------------- /examples/data_parallel/README.md: -------------------------------------------------------------------------------- 1 | ## Recommendation 2 | Please consider using `train_resnet_spmd_data_parallel.py` since it uses SPMD internally and is very likely to yield better performance.
3 | -------------------------------------------------------------------------------- /test/benchmarks/a6000.inference.speedup.test: -------------------------------------------------------------------------------- 1 | # Datetime(UTC),Speedup(Inductor/Oldest Inductor),StdDev,Speedup(XLA+Dynamo/Oldest Inductor),StdDev 2 | 2023-11-11 04:43:56.070348,1.0,0.0,, 3 | -------------------------------------------------------------------------------- /torch_xla/_internal/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def parse_xla_device(device: str): 5 | m = re.match(r'([A-Z]+):(\d+)$', device) 6 | if m: 7 | return (m.group(1), int(m.group(2))) 8 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | expecttest 2 | filelock 3 | fsspec 4 | jinja2 5 | markupsafe 6 | mpmath 7 | networkx 8 | pyyaml 9 | sympy 10 | typing-extensions 11 | setuptools; python_version >= "3.12" 12 | -------------------------------------------------------------------------------- /torch_xla/experimental/dynamo_mark_sharding.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.warn( 4 | "dynamo_mark_sharding will be auto registered starting 2.5 release, please remove this import" 5 | ) 6 | -------------------------------------------------------------------------------- /torch_xla/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | from .eager import eager_mode, is_eager_mode, eager_mode_context 2 | 3 | __all__ = [ 4 | "eager_mode", 5 | "is_eager_mode", 6 | "eager_mode_context", 7 | ] 8 | -------------------------------------------------------------------------------- /torch_xla/experimental/dynamo_set_buffer_donor.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.warn( 4 | "dynamo_set_buffer_donor_ will be auto registered starting 2.5 release, please remove this import" 5 | ) 6 | -------------------------------------------------------------------------------- /infra/ansible/roles/fetch_srcs/defaults/main.yaml: -------------------------------------------------------------------------------- 1 | # See https://docs.ansible.com/ansible/latest/collections/ansible/builtin/git_module.html#parameter-version 2 | pytorch_git_rev: HEAD 3 | xla_git_rev: HEAD 4 | -------------------------------------------------------------------------------- /torch_xla/_internal/decomp_registration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.library.register_kernel("aten::upsample_trilinear3d", "xla", 4 | torch._decomp.decompositions.upsample_trilinear3d) 5 | -------------------------------------------------------------------------------- /test/cpp/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | int main(int argc, char* argv[]) { 7 | ::testing::InitGoogleTest(&argc, argv); 8 | return RUN_ALL_TESTS(); 9 | } 10 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/questions-help-support.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "❓Questions/Help/Support" 3 | about: Do you need support? 
We have resources. 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## ❓ Questions and Help 11 | -------------------------------------------------------------------------------- /plugins/cpu/pjrt_c_api_cpu_version_script.lds: -------------------------------------------------------------------------------- 1 | # Only symbols in the global section are available to other frameworks. 2 | VERS_1.0 { 3 | global: 4 | extern "C" { 5 | GetPjrtApi; 6 | }; 7 | 8 | local: 9 | *; 10 | }; 11 | -------------------------------------------------------------------------------- /test/tpu/run_expensive_test_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -xue 3 | CDIR="$(cd "$(dirname "$0")" ; pwd -P)" 4 | TEST_CDIR="$(dirname "$CDIR")" 5 | 6 | # This test takes ~1000 seconds to run 7 | python3 "$TEST_CDIR/test_ragged_paged_attention_kernel.py" 8 | -------------------------------------------------------------------------------- /torch_xla/_internal/xpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch_xla.experimental import plugins 4 | 5 | 6 | class XpuPlugin(plugins.DevicePlugin): 7 | 8 | def library_path(self): 9 | return os.environ.get('XPU_LIBRARY_PATH', 'libxpu.so') 10 | -------------------------------------------------------------------------------- /test/tpu/run_expensive_test_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -xue 3 | CDIR="$(cd "$(dirname "$0")" ; pwd -P)" 4 | TEST_CDIR="$(dirname "$CDIR")" 5 | 6 | # This test takes ~1350 seconds to run 7 | python3 "$TEST_CDIR/test_multi_queries_paged_attention_kernel.py" 8 | -------------------------------------------------------------------------------- /infra/tpu-pytorch-releases/dev_images.auto.tfvars: -------------------------------------------------------------------------------- 1 | dev_images = [ 2 | { 3 | accelerator = "tpu" 4 | python_version = "3.10" 5 | }, 6 | { 7 | accelerator = "tpu" 8 | extra_tags = ["tpu"] 9 | python_version = "3.12" 10 | } 11 | ] 12 | -------------------------------------------------------------------------------- /torch_xla/_dynamo/config.py: -------------------------------------------------------------------------------- 1 | import torch_xla 2 | 3 | # Whether to skip checking input is a device data or not in the optim_mod. 4 | # Enabling it will reduce the overhead a bit but will throw a runtime error 5 | # if input is a pending IR. 
6 | skip_input_data_check = False -------------------------------------------------------------------------------- /test/benchmarks/a6000.training.latest.test: -------------------------------------------------------------------------------- 1 | # Workload,Speedup(Inductor/Oldest Inductor),StdDev,ModelName(Inductor),Speedup(XLA+Dynamo/Oldest Inductor),StdDev,ModelName(XLA+Dynamo) 2 | 0,1.0,0.0,BERT_pytorch,2.84589073,0.0,BERT_pytorch 3 | 1,1.0,0.0,Background_Matting,,, 4 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.speedup.test: -------------------------------------------------------------------------------- 1 | # Datetime(UTC),Speedup(Inductor/Oldest Inductor),StdDev,Speedup(XLA+Dynamo/Oldest Inductor),StdDev 2 | 2023-11-11 05:32:18.723407,1.0,0.03108182,0.84533378,0.01857889 3 | 2023-11-12 05:32:18,1.40315838,0.03083885,1.10120312,0.02420242 4 | -------------------------------------------------------------------------------- /bazel/nlohmann_json.BUILD: -------------------------------------------------------------------------------- 1 | cc_library( 2 | name = "json", 3 | hdrs = [ 4 | "single_include/nlohmann/json.hpp", 5 | "single_include/nlohmann/json_fwd.hpp", 6 | ], 7 | includes = ["single_include"], 8 | visibility = ["//visibility:public"], 9 | ) 10 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.latest.tier1.test: -------------------------------------------------------------------------------- 1 | # ARGS: --filter-by-tier=1 2 | # Workload,Speedup(Inductor/Oldest Inductor),StdDev,ModelName(Inductor),Speedup(XLA+Dynamo/Oldest Inductor),StdDev,ModelName(XLA+Dynamo) 3 | 0,1.51952596,0.06679279,BERT_pytorch,1.56880282,0.06895882,BERT_pytorch 4 | -------------------------------------------------------------------------------- /infra/terraform_modules/arc_v4_container_cluster/README.md: -------------------------------------------------------------------------------- 1 | # Cluster creation for TPU CI for PyTorch/XLA 2 | 3 | This module configures: 4 | * A regional GKE cluster 5 | * A CPU node pool 6 | * An autoscaling v4 TPU node pool 7 | * The installation of Actions Runner Controller (ARC) on the GKE cluster 8 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.speedup.baseline_latest.test: -------------------------------------------------------------------------------- 1 | # ARGS: --baseline=latest 2 | # Datetime(UTC),Speedup(Inductor/Latest Inductor),StdDev,Speedup(XLA+Dynamo/Latest Inductor),StdDev 3 | 2023-11-11 05:32:18.723407,0.71267792,0.01566335,0.60245072,0.0 4 | 2023-11-12 05:32:18,1.0,0.0,0.78480315,0.0 5 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.histogram.test: -------------------------------------------------------------------------------- 1 | # Datetime(UTC),Inductor p95,Inductor p50,Inductor p5,XLA+Dynamo p95,XLA+Dynamo p50,XLA+Dynamo p5 2 | 2023-11-11 05:32:18.723407,1.0,1.0,1.0,0.97631327,0.85586259,0.7354119 3 | 2023-11-12 05:32:18,1.50833479,1.40761418,1.30689358,1.52901152,1.17088985,0.81276817 4 | -------------------------------------------------------------------------------- /torch_xla/csrc/version.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_VERSION_H_ 2 | #define XLA_TORCH_XLA_CSRC_VERSION_H_ 3 | 4 | namespace torch_xla { 5 | 6 | 
extern const char XLA_GITREV[]; 7 | extern const char TORCH_GITREV[]; 8 | 9 | } // namespace torch_xla 10 | 11 | #endif // XLA_TORCH_XLA_CSRC_VERSION_H_ -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.latest_grouped.test: -------------------------------------------------------------------------------- 1 | # ModelName,Speedup(Inductor/Oldest Inductor),StdDev,Speedup(XLA+Dynamo/Oldest Inductor),StdDev 2 | Background_Matting,1.2957024,0.0,0.77297688,0.0 3 | BERT_pytorch,1.51952596,0.06679279,1.56880282,0.06895882 4 | GEOMEAN,1.40315838,0.03083885,1.10120312,0.02420242 5 | -------------------------------------------------------------------------------- /scripts/update_nightly_torch_wheels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | # Activate torch-xla-nightly conda env if not already in it 6 | if [ "$CONDA_DEFAULT_ENV" != "torch-xla-nightly" ]; then 7 | conda activate torch-xla-nightly 8 | fi 9 | 10 | $(dirname ${BASH_SOURCE[0]})/update_torch_wheels.sh 11 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.speedup.lazytensor.test: -------------------------------------------------------------------------------- 1 | # ARGS: --backends inductor openxla+lazytensor -- 2 | # Datetime(UTC),Speedup(Inductor/Oldest Inductor),StdDev,Speedup(XLA+LazyTensor/Oldest Inductor),StdDev 3 | 2023-11-11 05:32:18.723407,1.0,0.03108182,, 4 | 2023-11-12 05:32:18,1.40315838,0.03083885,0.41071322,0.0 5 | -------------------------------------------------------------------------------- /infra/tpu-pytorch-releases/infra_triggers.tf: -------------------------------------------------------------------------------- 1 | module "terraform_apply" { 2 | source = "../terraform_modules/apply_terraform_trigger" 3 | 4 | included_files = ["infra/**"] 5 | branch = "master" 6 | config_directory = "infra/tpu-pytorch-releases" 7 | 8 | worker_pool_id = module.worker_pool.id 9 | } 10 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.latest.openxla_baseline.test: -------------------------------------------------------------------------------- 1 | # ARGS: --backends openxla+dynamo inductor --baseline=latest --filter-by-tier=1 2 | # Workload,Speedup(XLA+Dynamo/Latest XLA+Dynamo),StdDev,ModelName(XLA+Dynamo),Speedup(Inductor/Latest XLA+Dynamo),StdDev,ModelName(Inductor) 3 | 0,1.0,0.0,BERT_pytorch,0.96858952,0.0,BERT_pytorch 4 | -------------------------------------------------------------------------------- /torch_xla/experimental/distributed_checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | from .manager import CheckpointManager 2 | from .planners import SPMDSavePlanner, SPMDLoadPlanner 3 | from .util import prime_optimizer 4 | 5 | __all__ = [ 6 | "CheckpointManager", 7 | "SPMDSavePlanner", 8 | "SPMDLoadPlanner", 9 | "prime_optimizer", 10 | ] 11 | -------------------------------------------------------------------------------- /infra/tpu-pytorch/infra_triggers.tf: -------------------------------------------------------------------------------- 1 | module "terraform_apply" { 2 | source = "../terraform_modules/apply_terraform_trigger" 3 | 4 | included_files = ["infra/**"] 5 | branch = "master" 6 | config_directory = "infra/tpu-pytorch" 7 | 8 | worker_pool_id = module.worker_pool.id 9 | location = "global" 10 | 
} 11 | -------------------------------------------------------------------------------- /infra/tpu-pytorch/iam.auto.tfvars: -------------------------------------------------------------------------------- 1 | project_remote_build_writers = [ 2 | "group:cloud-tpus-dev-team@twosync.google.com", 3 | "user:mlewko@google.com", 4 | "user:goranpetrovic@google.com", 5 | # tpu-pytorch-releases project: default Service Account for running Cloud Build jobs. 6 | "serviceAccount:1001674285173@cloudbuild.gserviceaccount.com" 7 | ] 8 | -------------------------------------------------------------------------------- /plugins/cpu/test_cpu_plugin.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_PJRT_C_PJRT_C_API_CPU_H_ 2 | #define XLA_PJRT_C_PJRT_C_API_CPU_H_ 3 | 4 | #include "xla/pjrt/c/pjrt_c_api.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | const PJRT_Api* GetPjrtApi(); 11 | 12 | #ifdef __cplusplus 13 | } 14 | #endif 15 | 16 | #endif // XLA_PJRT_C_PJRT_C_API_CPU_H_ 17 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.histogram.lazytensor.test: -------------------------------------------------------------------------------- 1 | # ARGS: --backends inductor openxla+lazytensor -- 2 | # Datetime(UTC),Inductor p95,Inductor p50,Inductor p5,XLA+LazyTensor p95,XLA+LazyTensor p50,XLA+LazyTensor p5 3 | 2023-11-11 05:32:18.723407,1.0,1.0,1.0,,, 4 | 2023-11-12 05:32:18,1.50833479,1.40761418,1.30689358,0.41071322,0.41071322,0.41071322 5 | -------------------------------------------------------------------------------- /plugins/cpu/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # add `build_util` to import path 5 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) 6 | 7 | import build_util 8 | import setuptools 9 | 10 | build_util.bazel_build('//plugins/cpu:pjrt_c_api_cpu_plugin.so', 11 | 'torch_xla_cpu_plugin/lib') 12 | 13 | setuptools.setup() 14 | -------------------------------------------------------------------------------- /docker/debug_image_cleanup.sh: -------------------------------------------------------------------------------- 1 | IMAGE="gcr.io/tpu-pytorch/xla_debug" 2 | DATE=$(date --date='-90 days' +"%Y-%m-%dT%H:%M:%S") 3 | 4 | for digest in $(gcloud container images list-tags ${IMAGE} --limit=999999 --sort-by=TIMESTAMP --filter="timestamp.datetime < '${DATE}'" --format='get(digest)'); do 5 | echo $digest 6 | gcloud container images delete -q --force-delete-tags "${IMAGE}@${digest}" 7 | done 8 | -------------------------------------------------------------------------------- /infra/ansible/roles/bazel/tasks/main.yaml: -------------------------------------------------------------------------------- 1 | - name: "Download bazelisk v{{ bazelisk_version }}" 2 | ansible.builtin.get_url: 3 | url: "https://github.com/bazelbuild/bazelisk/releases/download/v{{ bazelisk_version }}/bazelisk-linux-amd64" 4 | dest: /usr/local/bin/bazel 5 | mode: 'u=rxw,g=rw,o=r' 6 | 7 | - name: "Tests" 8 | include_tasks: tests.yaml 9 | tags: 10 | - tests 11 | -------------------------------------------------------------------------------- /infra/ansible/roles/build_srcs/tasks/tests.yaml: -------------------------------------------------------------------------------- 1 | - name: "Check that various import statements work" 2 | ansible.builtin.command: 3 | cmd: "{{ item }}" 4 | environment: "{{ env_vars 
| combine({'USE_CUDA': 0}) }}" 5 | loop: 6 | - python -c "import torchgen" 7 | - python -c "import torch" 8 | - python -c "import torch_xla" 9 | - python -c "import torch_xla.core.xla_model" 10 | -------------------------------------------------------------------------------- /plugins/cpu/torch_xla_cpu_plugin/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch_xla.experimental import plugins 3 | from torch_xla._internal import tpu 4 | 5 | 6 | class CpuPlugin(plugins.DevicePlugin): 7 | 8 | def library_path(self) -> str: 9 | return os.path.join( 10 | os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so') 11 | 12 | def physical_chip_count(self) -> int: 13 | return 1 14 | -------------------------------------------------------------------------------- /torch_xla/csrc/dl_convertor.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_ 2 | #define XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_ 3 | 4 | #include 5 | #include 6 | 7 | namespace torch_xla { 8 | 9 | DLManagedTensor* toDLPack(const at::Tensor& src); 10 | at::Tensor fromDLPack(DLManagedTensor* src); 11 | 12 | } // namespace torch_xla 13 | 14 | #endif // XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_ 15 | -------------------------------------------------------------------------------- /infra/tpu-pytorch-releases/iam.auto.tfvars: -------------------------------------------------------------------------------- 1 | project_admins = [ 2 | "group:cloud-tpus-grpadm@twosync.google.com", 3 | "group:pytorchxla-dev@google.com", 4 | ] 5 | 6 | project_remote_build_writers = [ 7 | # "group:pytorchxla-announce@google.com", 8 | "serviceAccount:1001674285173@cloudbuild.gserviceaccount.com", 9 | "group:cloud-tpus-dev-team@twosync.google.com", 10 | ] 11 | 12 | cloudbuild_editors = [ 13 | ] 14 | -------------------------------------------------------------------------------- /torch_xla/csrc/aten_fallback.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_ATEN_FALLBACK_H_ 2 | #define XLA_TORCH_XLA_CSRC_ATEN_FALLBACK_H_ 3 | 4 | #include 5 | 6 | namespace torch_xla { 7 | 8 | void xla_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack); 9 | 10 | std::vector GetFallbackOperations(); 11 | 12 | } // namespace torch_xla 13 | 14 | #endif // XLA_TORCH_XLA_CSRC_ATEN_FALLBACK_H_ 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F4DA Documentation" 3 | about: Report an issue about missing/wrong documentation 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 📚 Documentation 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /torch_xla/csrc/thread_pool.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_CLIENT_THREAD_POOL_H_ 2 | #define XLA_CLIENT_THREAD_POOL_H_ 3 | 4 | #include 5 | 6 | namespace torch_xla { 7 | namespace thread { 8 | 9 | // Schedules a closure to be run. The closure should not block waiting for other 10 | // events. 
11 | void Schedule(std::function fn); 12 | 13 | } // namespace thread 14 | } // namespace torch_xla 15 | 16 | #endif // XLA_CLIENT_THREAD_POOL_H_ 17 | -------------------------------------------------------------------------------- /.github/upstream/install_valgrind.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ex 4 | 5 | mkdir valgrind_build && cd valgrind_build 6 | VALGRIND_VERSION=3.16.1 7 | wget https://ossci-linux.s3.amazonaws.com/valgrind-${VALGRIND_VERSION}.tar.bz2 8 | tar -xjf valgrind-${VALGRIND_VERSION}.tar.bz2 9 | cd valgrind-${VALGRIND_VERSION} 10 | ./configure --prefix=/usr/local 11 | make -j6 12 | make install 13 | cd ../../ 14 | rm -rf valgrind_build 15 | alias valgrind="/usr/local/bin/valgrind" 16 | -------------------------------------------------------------------------------- /torch_xla/distributed/fsdp/__init__.py: -------------------------------------------------------------------------------- 1 | from .xla_fully_sharded_data_parallel import XlaFullyShardedDataParallel 2 | from .state_dict_utils import (consolidate_sharded_state_dicts, 3 | consolidate_sharded_model_checkpoints) 4 | from .utils import checkpoint_module 5 | 6 | __all__ = [ 7 | "XlaFullyShardedDataParallel", 8 | "consolidate_sharded_state_dicts", 9 | "consolidate_sharded_model_checkpoints", 10 | "checkpoint_module", 11 | ] 12 | -------------------------------------------------------------------------------- /torch_xla/csrc/function_call_tracker.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_FUNCTION_CALL_TRACKER_H_ 2 | #define XLA_TORCH_XLA_CSRC_FUNCTION_CALL_TRACKER_H_ 3 | 4 | namespace torch_xla { 5 | namespace fn_tracker { 6 | 7 | #define XLA_FN_TRACK(level) \ 8 | torch_xla::fn_tracker::TrackFunction(__FUNCTION__, level) 9 | 10 | void TrackFunction(const char* tag, int level); 11 | 12 | } // namespace fn_tracker 13 | } // namespace torch_xla 14 | 15 | #endif // XLA_TORCH_XLA_CSRC_FUNCTION_CALL_TRACKER_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/runtime/xla_mlir_debuginfo_helper.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_MLIR_DEBUGINFO_HELPER_H_ 2 | #define XLA_MLIR_DEBUGINFO_HELPER_H_ 3 | 4 | #include "mlir/Dialect/Func/IR/FuncOps.h" 5 | #include "mlir/IR/BuiltinOps.h" 6 | #include "mlir/Pass/Pass.h" 7 | 8 | namespace torch_xla { 9 | namespace runtime { 10 | 11 | std::unique_ptr> 12 | CreatePrepareXlaMlirDebuginfoPass(); 13 | 14 | } // namespace runtime 15 | } // namespace torch_xla 16 | 17 | #endif 18 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.latest.test: -------------------------------------------------------------------------------- 1 | # ARGS: --backends inductor openxla+dynamo openxla+lazytensor -- 2 | # Workload,Speedup(Inductor/Oldest Inductor),StdDev,ModelName(Inductor),Speedup(XLA+Dynamo/Oldest Inductor),StdDev,ModelName(XLA+Dynamo),Speedup(XLA+LazyTensor/Oldest Inductor),StdDev,ModelName(XLA+LazyTensor) 3 | 0,1.2957024,0.0,Background_Matting,0.77297688,0.0,Background_Matting,0.41071322,0.0,Background_Matting 4 | 1,1.51952596,0.06679279,BERT_pytorch,1.56880282,0.06895882,BERT_pytorch,,, 5 | -------------------------------------------------------------------------------- /torch_xla/utils/buffer_donor_context.py: 
-------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | import torch_xla 4 | 5 | 6 | @contextmanager 7 | def alias_with_buffer_donor_config(should_alias: bool = True): 8 | saved_config = torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config( 9 | ) 10 | torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(should_alias) 11 | try: 12 | yield saved_config 13 | finally: 14 | torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(saved_config) 15 | -------------------------------------------------------------------------------- /examples/eager/train_decoder_only_eager_with_compile.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))) 4 | sys.path.append(example_folder) 5 | from train_decoder_only_base import TrainDecoderOnlyBase 6 | 7 | import torch_xla 8 | 9 | if __name__ == '__main__': 10 | # The step fn will still be compiled, random input generation happens eagerly. 11 | torch_xla.experimental.eager_mode(True) 12 | trainer = TrainDecoderOnlyBase() 13 | trainer.start_training() 14 | -------------------------------------------------------------------------------- /docker/docker-entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Explicitly source bashrc even when running commands directly. 4 | # Since commands run as a separate subshell, we need to source manually. 5 | # ex. docker run -it gcr.io/tpu-pytorch/xla:nightly bash ... 6 | # The above will not source bashrc without entrypoint. 7 | source ~/.bashrc 8 | 9 | # Activate pytorch conda env at entry by default. 10 | # TODO: This should not be needed as it is already sourced from the .bashrc above. 
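# Illustrative flow (a sketch, using the image tag from the example above):
#   docker run -it gcr.io/tpu-pytorch/xla:nightly python -c "import torch_xla"
# sources ~/.bashrc, activates the conda env below, then hands the command to
# `exec "$@"` at the bottom of this script.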
11 | conda activate pytorch 12 | 13 | exec "$@" 14 | -------------------------------------------------------------------------------- /infra/ansible/roles/fetch_srcs/tasks/tests.yaml: -------------------------------------------------------------------------------- 1 | - name: Retrieve status of setup.py files in XLA and PyTorch repos 2 | ansible.builtin.stat: 3 | path: "{{ item }}" 4 | register: _res 5 | loop: 6 | - "{{ (src_root, 'pytorch/setup.py') | path_join }}" 7 | - "{{ (src_root, 'pytorch/xla/setup.py') | path_join }}" 8 | 9 | - name: Assert that setup.py files exist 10 | ansible.builtin.assert: 11 | that: "{{ item.stat.exists }}" 12 | fail_msg: "{{ item.item }} doesn't exist" 13 | loop: "{{ _res.results }}" 14 | -------------------------------------------------------------------------------- /torch_xla/csrc/shape_helper.cpp: -------------------------------------------------------------------------------- 1 | #include "torch_xla/csrc/shape_helper.h" 2 | 3 | #include "xla/hlo/builder/xla_builder.h" 4 | 5 | #include "torch_xla/csrc/status.h" 6 | 7 | namespace torch_xla { 8 | 9 | const xla::Shape& ShapeHelper::ShapeOfXlaOp(xla::XlaOp op) { 10 | XLA_ASSIGN_OR_THROW(const xla::Shape* shape, GetShape(op)); 11 | return *shape; 12 | } 13 | 14 | absl::StatusOr GetShape(xla::XlaOp op) { 15 | return op.builder()->GetShapePtr(op); 16 | } 17 | 18 | } // namespace torch_xla 19 | -------------------------------------------------------------------------------- /experimental/reference_models/sdxl_inference/sdxl_beginning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import StableDiffusionPipeline 3 | 4 | import torch_xla2 5 | env = torch_xla2.default_env() 6 | 7 | # this now contains torch.Tensor 8 | pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base") 9 | 10 | with env: 11 | pipe.to('jax') 12 | prompt = "a photograph of an astronaut riding a horse" 13 | image = pipe(prompt, num_inference_steps=10).images[0] 14 | image.save(f"astronaut_rides_horse_orig.png") 15 | -------------------------------------------------------------------------------- /infra/tpu-pytorch/provider.tf: -------------------------------------------------------------------------------- 1 | # Run `gcloud auth application-default login` before running Terraform. 2 | provider "google" { 3 | project = "tpu-pytorch" 4 | region = "us-central1" 5 | } 6 | 7 | terraform { 8 | required_providers { 9 | google = { 10 | source = "hashicorp/google" 11 | version = ">= 4.52.0" 12 | } 13 | } 14 | 15 | backend "gcs" { 16 | # Make sure that bucket name matches the one specified in ./misc.tf. 
17 | bucket = "tpu-pytorch-tfstate" 18 | prefix = "terraform/state" 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /torch_xla/csrc/generated_file_include.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_GENERATED_FILE_INCLUDE_H_ 2 | #define XLA_TORCH_XLA_CSRC_GENERATED_FILE_INCLUDE_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/aten_fallback.h" 7 | #include "torch_xla/csrc/aten_xla_bridge.h" 8 | #include "torch_xla/csrc/ir.h" 9 | #include "torch_xla/csrc/ops/ops_xla_shape_fn.h" 10 | #include "torch_xla/csrc/runtime/debug_macros.h" 11 | #include "torch_xla/csrc/runtime/metrics.h" 12 | 13 | #endif // XLA_TORCH_XLA_CSRC_GENERATED_FILE_INCLUDE_H_ 14 | -------------------------------------------------------------------------------- /torch_xla/csrc/runtime/env_hash.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace torch_xla { 4 | namespace runtime { 5 | namespace hash { 6 | 7 | // Take a hash of XLA flags which impact the compilation result. 8 | // TODO(jonbolin): We should move away from manually hashing the env vars and 9 | // instead hash the compilation environment directly when the functionality is 10 | // supported in the runtime. 11 | torch::lazy::hash_t HashXlaEnvVars(); 12 | 13 | } // namespace hash 14 | } // namespace runtime 15 | } // namespace torch_xla 16 | -------------------------------------------------------------------------------- /infra/ansible/roles/configure_env/tasks/main.yaml: -------------------------------------------------------------------------------- 1 | - name: Append environment variables required during runtime to ~/.bashrc 2 | ansible.builtin.lineinfile: 3 | path: ~/.bashrc 4 | line: "export {{ item }}={{ env_vars[item] }}" 5 | create: true 6 | loop: "{{ env_vars.keys() | list }}" 7 | 8 | - name: Append environment variables required during runtime to ~/.zshrc 9 | ansible.builtin.lineinfile: 10 | path: ~/.zshrc 11 | line: "export {{ item }}={{ env_vars[item] }}" 12 | create: true 13 | loop: "{{ env_vars.keys() | list }}" 14 | -------------------------------------------------------------------------------- /plugins/cpu/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "torch_xla_cpu_plugin" 7 | version = "0.0.1" 8 | authors = [ 9 | {name = "PyTorch/XLA Dev Team'", email = "pytorch-xla@googlegroups.com"}, 10 | ] 11 | description = "CPU PJRT Plugin for testing only" 12 | requires-python = ">=3.8" 13 | 14 | [tool.setuptools.package-data] 15 | torch_xla_cpu_plugin = ["lib/*.so"] 16 | 17 | [project.entry-points."torch_xla.plugins"] 18 | example = "torch_xla_cpu_plugin:CpuPlugin" 19 | -------------------------------------------------------------------------------- /infra/tpu-pytorch-releases/provider.tf: -------------------------------------------------------------------------------- 1 | # Run `gcloud auth application-default login` before running Terraform. 2 | provider "google" { 3 | project = "tpu-pytorch-releases" 4 | region = "us-central1" 5 | } 6 | 7 | terraform { 8 | required_providers { 9 | google = { 10 | source = "hashicorp/google" 11 | version = ">= 4.52.0" 12 | } 13 | } 14 | 15 | backend "gcs" { 16 | # Make sure that bucket name matches the one specified in ./misc.tf. 
17 | bucket = "tpu-pytorch-releases-tfstate" 18 | prefix = "terraform/state" 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /infra/tpu-pytorch/tpu_ci.tf: -------------------------------------------------------------------------------- 1 | module "v4_arc_cluster" { 2 | source = "../terraform_modules/arc_v4_container_cluster" 3 | project_id = "tpu-pytorch" 4 | cluster_name = "tpu-ci" 5 | cpu_nodepool_name = "cpu-nodepool" 6 | cpu_node_count = 32 7 | tpu_nodepool_name = "tpu-nodepool" 8 | min_tpu_nodes = 32 9 | max_tpu_nodes = 32 10 | github_repo_url = "https://github.com/pytorch/xla" 11 | # Dockerfile for this image can be found at test/tpu/Dockerfile 12 | runner_image = "gcr.io/tpu-pytorch/tpu-ci-runner:latest" 13 | } 14 | -------------------------------------------------------------------------------- /test/tpu/run_pallas_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -xue 3 | 4 | # Absolute path to the directory of this script. 5 | _TPU_DIR="$( 6 | cd "$(dirname "$0")" 7 | pwd -P 8 | )" 9 | 10 | # Absolute path to the test/ directory. 11 | _TEST_DIR="$(dirname "$_TPU_DIR")" 12 | 13 | # This test takes ~370 seconds to run 14 | python3 "$_TEST_DIR/test_pallas.py" -v 15 | 16 | # This test takes ~15 seconds to run 17 | python3 "$_TEST_DIR/test_pallas_spmd.py" 18 | 19 | # This test takes ~15 seconds to run 20 | XLA_DISABLE_FUNCTIONALIZATION=1 python3 "$_TEST_DIR/test_pallas_spmd.py" 21 | -------------------------------------------------------------------------------- /torch_xla/csrc/xla_backend_impl.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_XLA_BACKEND_IMPL_H_ 2 | #define XLA_TORCH_XLA_CSRC_XLA_BACKEND_IMPL_H_ 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | #include "torch_xla/csrc/device.h" 10 | #include "torch_xla/csrc/runtime/computation_client.h" 11 | 12 | namespace torch_xla { 13 | 14 | torch::lazy::BackendImplInterface* GetXlaBackendImpl(); 15 | 16 | bool InitXlaBackend(); 17 | 18 | } // namespace torch_xla 19 | 20 | #endif // XLA_TORCH_XLA_CSRC_XLA_BACKEND_IMPL_H_ 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | If you are asking a question, please preface the title with [question]. 2 | If you are submitting a feature request, please preface the title with [feature request]. 3 | If you are submitting a bug report, please fill in the following details. 4 | 5 | ## Issue description 6 | 7 | Provide a short description. 8 | 9 | ## Code example 10 | 11 | Please try to provide a minimal example to repro the bug. 12 | Error messages and stack traces are also helpful. 
13 | 14 | ## System Info 15 | 16 | - reproducible on XLA backend [CPU/TPU]: 17 | - torch_xla version: 18 | -------------------------------------------------------------------------------- /test/benchmarks/test_benchmark_model.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from benchmark_model import BenchmarkModel 4 | 5 | 6 | class BenchmarkModelTest(unittest.TestCase): 7 | 8 | def test_to_dict(self): 9 | bm = BenchmarkModel("torchbench or other", "super_deep_model", 10 | "placeholder") 11 | actual = bm.to_dict() 12 | self.assertEqual(2, len(actual)) 13 | self.assertEqual("torchbench or other", actual["suite_name"]) 14 | self.assertEqual("super_deep_model", actual["model_name"]) 15 | 16 | 17 | if __name__ == '__main__': 18 | unittest.main() 19 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/multinomial.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "torch_xla/csrc/ir.h" 4 | 5 | namespace torch_xla { 6 | 7 | class Multinomial : public XlaNode { 8 | public: 9 | Multinomial(const torch::lazy::Value& input, const torch::lazy::Value& seed, 10 | int64_t num_samples, bool replacement); 11 | 12 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 13 | 14 | XlaOpVector Lower(LoweringContext* loctx) const override; 15 | 16 | private: 17 | int64_t num_samples_; 18 | bool replacement_; 19 | }; 20 | 21 | } // namespace torch_xla 22 | -------------------------------------------------------------------------------- /torch_xla/csrc/token_handler.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_TOKEN_HANDLER_H_ 2 | #define XLA_TORCH_XLA_CSRC_TOKEN_HANDLER_H_ 3 | 4 | #include "xla/hlo/builder/xla_builder.h" 5 | 6 | namespace torch_xla { 7 | 8 | class TokenHandler { 9 | public: 10 | explicit TokenHandler(xla::XlaOp token) : token_(token) {} 11 | 12 | xla::XlaOp GetInput(xla::XlaOp input, const xla::Shape* input_shape); 13 | 14 | xla::XlaOp GetNewToken(xla::XlaOp result); 15 | 16 | private: 17 | xla::XlaOp token_; 18 | }; 19 | 20 | } // namespace torch_xla 21 | 22 | #endif // XLA_TORCH_XLA_CSRC_TOKEN_HANDLER_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/runtime/stablehlo_composite_helper.h: -------------------------------------------------------------------------------- 1 | #ifndef STABLEHLO_COMPOSITE_HELPER_H_ 2 | #define STABLEHLO_COMPOSITE_HELPER_H_ 3 | 4 | #include "mlir/Dialect/Func/IR/FuncOps.h" 5 | #include "mlir/IR/BuiltinOps.h" 6 | #include "mlir/Pass/Pass.h" 7 | 8 | namespace torch_xla { 9 | namespace runtime { 10 | 11 | std::unique_ptr> 12 | CreateBuildStableHLOCompositePass(); 13 | 14 | std::unique_ptr> 15 | CreateRemoveXlaMarkTensorOpsPass(); 16 | 17 | } // namespace runtime 18 | } // namespace torch_xla 19 | 20 | #endif 21 | -------------------------------------------------------------------------------- /examples/scan/decoder_with_scan.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import override 2 | from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel 3 | 4 | 5 | class DecoderWithScan(DecoderOnlyModel): 6 | 7 | def __init__(self, config: DecoderOnlyConfig): 8 | super().__init__(config) 9 | self.config = config 10 | 11 | @override 12 | def run_decoder_layers(self, hidden_states): 13 | 
from torch_xla.experimental.scan_layers import scan_layers 14 | return scan_layers( 15 | self.layers, 16 | hidden_states, 17 | is_layer_pure=self.config.is_decoder_layer_pure) 18 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/nms.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_NMS_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_NMS_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Nms : public XlaNode { 9 | public: 10 | Nms(const torch::lazy::Value& boxes, const torch::lazy::Value& scores, 11 | const torch::lazy::Value& iou_threshold); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | }; 17 | 18 | } // namespace torch_xla 19 | 20 | #endif // XLA_TORCH_XLA_CSRC_OPS_NMS_H_ 21 | -------------------------------------------------------------------------------- /infra/terraform_modules/arc_v4_container_cluster/arc-values.yaml: -------------------------------------------------------------------------------- 1 | githubConfigUrl: ${github_repo_url} 2 | githubConfigSecret: github-pat 3 | minRunners: ${min_tpu_nodes} 4 | maxRunners: ${max_tpu_nodes} 5 | template: 6 | spec: 7 | containers: 8 | - name: runner 9 | image: ${runner_image} 10 | command: ["/home/runner/run.sh"] 11 | resources: 12 | limits: 13 | google.com/tpu: 4 14 | requests: 15 | google.com/tpu: 4 16 | nodeSelector: 17 | cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice 18 | cloud.google.com/gke-tpu-topology: 2x2x1 19 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/normal.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_NORMAL_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_NORMAL_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Normal : public XlaNode { 9 | public: 10 | Normal(const torch::lazy::Value& mean, const torch::lazy::Value& std, 11 | const torch::lazy::Value& seed); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | }; 17 | 18 | } // namespace torch_xla 19 | 20 | #endif // XLA_TORCH_XLA_CSRC_OPS_NORMAL_H_ 21 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/optimization_barrier.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_OPTIMIZATION_BARRIER_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_OPTIMIZATION_BARRIER_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class OptimizationBarrier : public XlaNode { 9 | public: 10 | OptimizationBarrier(const torch::lazy::OpList& inputs); 11 | 12 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 13 | 14 | XlaOpVector Lower(LoweringContext* loctx) const override; 15 | }; 16 | 17 | } // namespace torch_xla 18 | 19 | #endif // XLA_TORCH_XLA_CSRC_OPS_OPTIMIZATION_BARRIER_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/bernoulli.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_BERNOULLI_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_BERNOULLI_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace 
torch_xla { 7 | 8 | class Bernoulli : public XlaNode { 9 | public: 10 | Bernoulli(const torch::lazy::Value& probability, 11 | const torch::lazy::Value& seed, xla::Shape shape); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | }; 17 | 18 | } // namespace torch_xla 19 | 20 | #endif // XLA_TORCH_XLA_CSRC_OPS_BERNOULLI_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/eigh.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_EIGH_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_EIGH_H_ 3 | 4 | #include 5 | 6 | #include "xla/types.h" 7 | 8 | #include "torch_xla/csrc/ir.h" 9 | 10 | namespace torch_xla { 11 | 12 | class Eigh : public XlaNode { 13 | public: 14 | Eigh(const torch::lazy::Value& input, std::string_view uplo); 15 | 16 | std::string ToString() const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | private: 21 | char uplo_; 22 | }; 23 | 24 | } // namespace torch_xla 25 | 26 | #endif // XLA_TORCH_XLA_CSRC_OPS_EIGH_H_ 27 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/exponential.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_EXPONENTIAL_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_EXPONENTIAL_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Exponential : public XlaNode { 9 | public: 10 | Exponential(const torch::lazy::Value& lambda, const torch::lazy::Value& seed, 11 | xla::Shape shape); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | }; 17 | 18 | } // namespace torch_xla 19 | 20 | #endif // XLA_TORCH_XLA_CSRC_OPS_EXPONENTIAL_H_ -------------------------------------------------------------------------------- /codegen/BUILD: -------------------------------------------------------------------------------- 1 | exports_files(["xla_native_functions.yaml"]) 2 | 3 | # Requires `torchgen` locally, via pip `torch` package or local builds. 
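# Built through Bazel, e.g. `bazel build //codegen:lazy_tensor_generator`
# (a sketch; the generator is presumably driven by the code-generation build
# rules with repository-specific arguments rather than run by hand).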
4 | py_binary( 5 | name = "lazy_tensor_generator", 6 | srcs = ["lazy_tensor_generator.py"], 7 | data = [ 8 | "//torch_xla/csrc:aten_xla_type.cpp", 9 | "@torch//:torchgen_deps", 10 | ], 11 | deps = [ 12 | "@pypi_torch//:pkg", 13 | "@pypi_pyyaml//:pkg", 14 | ], 15 | tags = [ 16 | "local", 17 | "no-remote-exec", 18 | ], 19 | ) 20 | 21 | sh_binary( 22 | name = "fix_includes", 23 | srcs = ["fix_includes.sh"], 24 | ) 25 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/uniform.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_UNIFORM_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_UNIFORM_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Uniform : public XlaNode { 9 | public: 10 | Uniform(const torch::lazy::Value& from, const torch::lazy::Value& to, 11 | const torch::lazy::Value& seed, const xla::Shape& rng_shape); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | }; 17 | 18 | } // namespace torch_xla 19 | 20 | #endif // XLA_TORCH_XLA_CSRC_OPS_UNIFORM_H_ -------------------------------------------------------------------------------- /examples/eager/train_decoder_only_eager.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))) 4 | sys.path.append(example_folder) 5 | from train_decoder_only_base import TrainDecoderOnlyBase 6 | 7 | import torch_xla 8 | 9 | 10 | class TrainDecoderOnlyEager(TrainDecoderOnlyBase): 11 | 12 | def __init__(self): 13 | super().__init__() 14 | # We want to run the step fn eagerly. 
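    # In the base class, compiled_step_fn presumably wraps step_fn with
    # torch_xla.compile (see train_decoder_only_eager_with_compile.py);
    # assigning the raw step_fn below skips that wrapper so every step
    # executes eagerly.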
15 | self.compiled_step_fn = self.step_fn 16 | 17 | 18 | if __name__ == '__main__': 19 | torch_xla.experimental.eager_mode(True) 20 | base = TrainDecoderOnlyEager() 21 | base.start_training() 22 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/qr.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_QR_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_QR_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class QR : public XlaNode { 9 | public: 10 | QR(const torch::lazy::Value& input, bool some); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | bool some() const { return some_; } 19 | 20 | private: 21 | bool some_; 22 | }; 23 | 24 | } // namespace torch_xla 25 | 26 | #endif // XLA_TORCH_XLA_CSRC_OPS_QR_H_ 27 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/randperm.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class RandPerm : public XlaNode { 11 | public: 12 | RandPerm(int64_t n, const at::ScalarType dtype, const at::Layout layout, 13 | const at::Device device, bool pin_memory); 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | std::string ToString() const override; 17 | 18 | private: 19 | int64_t n_; 20 | }; 21 | 22 | } // namespace torch_xla 23 | 24 | #endif // XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_ 25 | -------------------------------------------------------------------------------- /scripts/tf_log_filter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import os 5 | import re 6 | import sys 7 | 8 | 9 | def normalize(args): 10 | fd = sys.stdin 11 | if args.input: 12 | fd = open(args.input) 13 | # 2019-04-06 02:51:26.397580: I torch_xla/csrc/forward.cpp:168] 14 | for line in fd: 15 | line.rstrip('\n') 16 | m = re.match(r'.*:\d+\] (.*)', line) 17 | if m: 18 | print(m.group(1)) 19 | else: 20 | print(line) 21 | 22 | 23 | if __name__ == '__main__': 24 | arg_parser = argparse.ArgumentParser() 25 | arg_parser.add_argument('--input', type=str) 26 | args = arg_parser.parse_args() 27 | normalize(args) 28 | -------------------------------------------------------------------------------- /infra/terraform_modules/worker_pool/worker_pool.tf: -------------------------------------------------------------------------------- 1 | variable "name" { 2 | default = "main" 3 | } 4 | 5 | variable "location" { 6 | default = "us-central1" 7 | } 8 | 9 | variable "machine_type" { 10 | default = "e2-standard-32" 11 | } 12 | 13 | variable "disk_size_gb" { 14 | default = 500 15 | } 16 | 17 | resource "google_cloudbuild_worker_pool" "worker_pool" { 18 | name = var.name 19 | location = var.location 20 | 21 | worker_config { 22 | disk_size_gb = var.disk_size_gb 23 | machine_type = var.machine_type 24 | no_external_ip = false 25 | } 26 | } 27 | 28 | output "id" { 29 | value = google_cloudbuild_worker_pool.worker_pool.id 30 | } 31 | -------------------------------------------------------------------------------- /test/cpp/test_runtime.cpp: 
-------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "torch_xla/csrc/runtime/runtime.h" 4 | 5 | namespace torch_xla::runtime { 6 | 7 | TEST(RuntimeTest, ComputationClientInitialization) { 8 | ComputationClient* client; 9 | 10 | client = GetComputationClientIfInitialized(); 11 | EXPECT_EQ(client, nullptr); 12 | 13 | // Initialize the ComputationClient. 14 | // Check all the APIs return the same valid ComputationClient. 15 | 16 | auto status = GetComputationClient(); 17 | ASSERT_TRUE(status.ok()); 18 | 19 | client = status.value(); 20 | EXPECT_EQ(GetComputationClientIfInitialized(), client); 21 | } 22 | 23 | } // namespace torch_xla::runtime 24 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/mark_tensor.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MARK_TENSOR_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MARK_TENSOR_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class MarkTensor : public XlaNode { 9 | public: 10 | MarkTensor(const torch::lazy::Value& input, const std::string& info); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | private: 19 | std::string info_; 20 | }; 21 | 22 | } // namespace torch_xla 23 | 24 | #endif // XLA_TORCH_XLA_CSRC_OPS_MARK_TENSOR_H_ 25 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/discrete_uniform.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_DISCRETE_UNIFORM_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_DISCRETE_UNIFORM_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class DiscreteUniform : public XlaNode { 9 | public: 10 | DiscreteUniform(const torch::lazy::Value& from, const torch::lazy::Value& to, 11 | const torch::lazy::Value& seed, const xla::Shape& rng_shape); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | }; 17 | 18 | } // namespace torch_xla 19 | 20 | #endif // XLA_TORCH_XLA_CSRC_OPS_DISCRETE_UNIFORM_H_ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ 3 | *.egg-info/ 4 | torch_xla/lib/ 5 | torch_xla/pb/cpp/* 6 | torch_xla/version.py 7 | torch_xla/csrc/version.cpp 8 | */**/__pycache__ 9 | */**/*MNIST/ 10 | torchax/**/runs/ 11 | *.swp 12 | *.pyc 13 | *.so 14 | 15 | # BEGIN NOT-CLEAN-FILES (setup.py handles this marker. Do not change.) 16 | # 17 | # Files below are not deleted by "setup.py clean". 18 | 19 | # Visual Studio Code files. 20 | .vs 21 | .vscode/ 22 | 23 | # Files autogenerated by docs/docs_build.sh. 24 | /core 25 | /docs/src/* 26 | 27 | # Local terraform state. 28 | .terraform 29 | 30 | 31 | # Bazel temporary files. 32 | bazel-* 33 | MODULE.bazel 34 | MODULE.bazel.lock 35 | 36 | # Clangd cache directory. 
37 | .cache/* 38 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/cast_int4.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CAST_INT4 2 | #define XLA_TORCH_XLA_CSRC_OPS_CAST_INT4 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class CastInt4 : public XlaNode { 9 | public: 10 | CastInt4(const torch::lazy::Value& weight, 11 | const std::vector& int4_weight_values); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | private: 20 | std::vector int4_vals_; 21 | }; 22 | 23 | } // namespace torch_xla 24 | 25 | #endif // XLA_TORCH_XLA_CSRC_OPS_CAST_INT4 26 | -------------------------------------------------------------------------------- /scripts/normalize_graph_text.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import os 5 | import re 6 | import sys 7 | 8 | 9 | def normalize(args): 10 | fd = sys.stdin 11 | if args.input: 12 | fd = open(args.input) 13 | # %397 = f32[128]{0} xla::cross_replica_sum(%396), scale=0.125, groups=() 14 | for line in fd: 15 | line.rstrip('\n') 16 | m = re.match(r'(\s*)%\d+\s*=\s*(.*::[^(]+\()[^)]*(.*)', line) 17 | if m: 18 | line = m.group(1) + m.group(2) + m.group(3) 19 | print(line) 20 | 21 | 22 | if __name__ == '__main__': 23 | arg_parser = argparse.ArgumentParser() 24 | arg_parser.add_argument('--input', type=str) 25 | args = arg_parser.parse_args() 26 | normalize(args) 27 | -------------------------------------------------------------------------------- /torch_xla/experimental/pjrt_backend.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.distributed as dist 4 | from torch_xla.distributed import xla_backend 5 | from torch_xla._internal import rendezvous 6 | from torch_xla._internal import tpu 7 | 8 | if tpu.num_available_chips() > 0 and tpu.version() <= 3: 9 | from torch.testing._internal.distributed import multi_threaded_pg 10 | logging.warning('Patching torch.distributed state to support multithreading.') 11 | logging.warning('torch.distributed support on TPU v2 and v3 is experimental ' 12 | 'and does not support torchrun.') 13 | multi_threaded_pg._install_threaded_pg() 14 | 15 | dist.register_rendezvous_handler('pjrt', rendezvous.pjrt_rendezvous_handler) 16 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/stack.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_STACK_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_STACK_H_ 3 | 4 | #include "absl/types/span.h" 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class Stack : public XlaNode { 11 | public: 12 | Stack(c10::ArrayRef values, int64_t dim); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | int64_t dim() const { return dim_; }; 21 | 22 | private: 23 | int64_t dim_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_STACK_H_ -------------------------------------------------------------------------------- 
/torch_xla/csrc/thread_pool.cpp: -------------------------------------------------------------------------------- 1 | #include "torch_xla/csrc/thread_pool.h" 2 | 3 | #include 4 | 5 | #include "tsl/platform/env.h" 6 | #include "tsl/platform/threadpool.h" 7 | 8 | #include "torch_xla/csrc/runtime/sys_util.h" 9 | 10 | namespace torch_xla { 11 | namespace thread { 12 | 13 | void Schedule(std::function fn) { 14 | static size_t num_threads = torch_xla::runtime::sys_util::GetEnvInt( 15 | "XLA_THREAD_POOL_SIZE", std::thread::hardware_concurrency()); 16 | static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla", 17 | num_threads); 18 | pool.Schedule(std::move(fn)); 19 | } 20 | 21 | } // namespace thread 22 | } // namespace torch_xla 23 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/cummax.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class CumMax : public XlaNode { 11 | public: 12 | CumMax(const torch::lazy::Value& input, int64_t dim); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | int64_t dim() const { return dim_; } 21 | 22 | private: 23 | int64_t dim_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_ 29 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/gather.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_GATHER_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_GATHER_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Gather : public XlaNode { 9 | public: 10 | Gather(const torch::lazy::Value& input, int64_t dim, 11 | const torch::lazy::Value& index); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t dim() const { return dim_; }; 20 | 21 | private: 22 | int64_t dim_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_GATHER_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/einsum.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_EINSUM_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_EINSUM_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Einsum : public XlaNode { 9 | public: 10 | Einsum(const torch::lazy::OpList& operands, const std::string equation); 11 | 12 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 13 | 14 | XlaOpVector Lower(LoweringContext* loctx) const override; 15 | 16 | std::string ToString() const override; 17 | 18 | const std::string& equation() const { return equation_; } 19 | 20 | private: 21 | const std::string equation_; 22 | }; 23 | 24 | } // namespace torch_xla 25 | 26 | #endif // XLA_TORCH_XLA_CSRC_OPS_EINSUM_H_ -------------------------------------------------------------------------------- /infra/ansible/config/vars.yaml: 
-------------------------------------------------------------------------------- 1 | # Used for fetching clang from the right repo, see apt.yaml. 2 | llvm_debian_repo: bookworm 3 | clang_version: 17 4 | # PyTorch and PyTorch/XLA wheel versions. 5 | package_version: 2.9.0 6 | # If set to true, wheels will be renamed to $WHEEL_NAME-nightly-cp38-cp38-linux_x86_64.whl. 7 | nightly_release: false 8 | # Whether to preinstall libtpu in the PyTorch/XLA wheel. 9 | bundle_libtpu: 1 10 | # Suffix for bazel remote cache key 11 | cache_suffix: "" 12 | # Whether to build C++ tests with `torch_xla` wheel 13 | build_cpp_tests: 0 14 | # Whether to tag wheels with git hash, e.g. X.Y.Z+git123abc 15 | git_versioned_xla_build: false 16 | # Whether to use C++11 ABI when building torch and torch_xla. 17 | cxx11_abi: 1 18 | -------------------------------------------------------------------------------- /torch_xla/csrc/matrix.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_MATRIX_H_ 2 | #define XLA_TORCH_XLA_CSRC_MATRIX_H_ 3 | 4 | #include "xla/hlo/builder/xla_builder.h" 5 | 6 | namespace torch_xla { 7 | 8 | xla::XlaOp BuildTriu(xla::XlaOp input, int64_t diagonal); 9 | 10 | xla::XlaOp BuildTril(xla::XlaOp input, int64_t diagonal); 11 | 12 | xla::XlaOp BuildDiagonal(xla::XlaOp input, int64_t offset, int64_t dim1, 13 | int64_t dim2); 14 | 15 | xla::XlaOp BuildDiagonalViewUpdate(xla::XlaOp target, xla::XlaOp input, 16 | int64_t offset, int64_t dim1, int64_t dim2); 17 | 18 | xla::XlaOp BuildInverse(xla::XlaOp input); 19 | 20 | } // namespace torch_xla 21 | 22 | #endif // XLA_TORCH_XLA_CSRC_MATRIX_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/resize.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_RESIZE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_RESIZE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Resize : public XlaNode { 9 | public: 10 | Resize(const torch::lazy::Value& input, std::vector size); 11 | 12 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 13 | 14 | XlaOpVector Lower(LoweringContext* loctx) const override; 15 | 16 | std::string ToString() const override; 17 | 18 | const std::vector& size() const { return size_; } 19 | 20 | private: 21 | std::vector size_; 22 | }; 23 | 24 | } // namespace torch_xla 25 | 26 | #endif // XLA_TORCH_XLA_CSRC_OPS_RESIZE_H_ 27 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/nonzero.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_NONZERO_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_NONZERO_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | // This node has no metadata, so it could have been implemented as generic-op in 9 | // ops.cpp, but since this might require special handling from upper IR layers, 10 | // it gets its own IR node class. 
11 | class NonZero : public XlaNode { 12 | public: 13 | NonZero(const torch::lazy::Value& input); 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | }; 19 | 20 | } // namespace torch_xla 21 | 22 | #endif // XLA_TORCH_XLA_CSRC_OPS_NONZERO_H_ 23 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/linspace.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_LINSPACE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_LINSPACE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Linspace : public XlaNode { 9 | public: 10 | Linspace(const torch::lazy::Value& start, const torch::lazy::Value& end, 11 | const int64_t steps); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t steps() const { return steps_; }; 20 | 21 | private: 22 | int64_t steps_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_LINSPACE_H_ -------------------------------------------------------------------------------- /infra/ansible/development.Dockerfile: -------------------------------------------------------------------------------- 1 | # Dockerfile for building a development image. 2 | # The built image contains all required pip and apt packages for building and 3 | # running PyTorch and PyTorch/XLA. The image doesn't contain any source code. 4 | ARG python_version=3.8 5 | ARG debian_version=bookworm 6 | 7 | FROM python:${python_version}-${debian_version} 8 | 9 | RUN pip install ansible 10 | 11 | COPY . /ansible 12 | WORKDIR /ansible 13 | 14 | # List Ansible tasks to apply for the dev image. 
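# An example local build (a sketch; the ansible_vars value is hypothetical and
# the tag list can be trimmed via TAGS below):
#   docker build -f development.Dockerfile \
#     --build-arg python_version=3.10 \
#     --build-arg ansible_vars="arch=amd64 accelerator=tpu" .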
15 | ENV TAGS="bazel,configure_env,install_deps" 16 | 17 | ARG ansible_vars 18 | RUN ansible-playbook playbook.yaml -e "stage=build" -e "${ansible_vars}" --tags "${TAGS}" 19 | RUN ansible-playbook playbook.yaml -e "stage=release" -e "${ansible_vars}" --tags "${TAGS}" 20 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/constant.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CONSTANT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_CONSTANT_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Constant : public XlaNode { 9 | public: 10 | Constant(xla::Literal value); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | const xla::Literal& value() const { return value_; } 19 | 20 | private: 21 | xla::Literal value_; 22 | }; 23 | 24 | torch::lazy::hash_t LiteralHash(const xla::Literal& l); 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_CONSTANT_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/expand.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_EXPAND_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_EXPAND_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class Expand : public XlaNode { 11 | public: 12 | Expand(const torch::lazy::Value& input, std::vector size); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | const std::vector& size() const { return size_; }; 21 | 22 | private: 23 | std::vector size_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_EXPAND_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/squeeze.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SQUEEZE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SQUEEZE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Squeeze : public XlaNode { 9 | public: 10 | // Squeeze out the specified dimension index, -1 for all trivial dimensions. 
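  // For example (illustrative): squeezing dim 1 of a [2, 1, 3] input yields
  // [2, 3], while dim == -1 drops every size-1 dimension, so [1, 2, 1]
  // becomes [2].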
11 | Squeeze(const torch::lazy::Value& input, int dim); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | int dim() const { return dim_; } 20 | 21 | private: 22 | int dim_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_SQUEEZE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/svd.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SVD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SVD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class SVD : public XlaNode { 9 | public: 10 | SVD(const torch::lazy::Value& input, bool some, bool compute_uv); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | bool some() const { return some_; } 19 | 20 | bool compute_uv() const { return compute_uv_; } 21 | 22 | private: 23 | bool some_; 24 | bool compute_uv_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_SVD_H_ -------------------------------------------------------------------------------- /plugins/cpu/BUILD: -------------------------------------------------------------------------------- 1 | load( 2 | "@xla//xla:xla.bzl", 3 | "xla_cc_binary", 4 | ) 5 | 6 | cc_library( 7 | name = "test_cpu_plugin", 8 | srcs = ["test_cpu_plugin.cpp"], 9 | hdrs = ["test_cpu_plugin.h"], 10 | visibility = ["//visibility:public"], 11 | deps = [ 12 | "@xla//xla/pjrt/c:pjrt_c_api_cpu_internal", 13 | "@xla//xla/pjrt/c:pjrt_c_api_hdrs", 14 | ], 15 | ) 16 | 17 | xla_cc_binary( 18 | name = "pjrt_c_api_cpu_plugin.so", 19 | linkopts = [ 20 | "-Wl,--version-script,$(location :pjrt_c_api_cpu_version_script.lds)", 21 | "-Wl,--no-undefined", 22 | ], 23 | linkshared = True, 24 | deps = [ 25 | ":pjrt_c_api_cpu_version_script.lds", 26 | ":test_cpu_plugin", 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /test/cpp/test_status_dont_show_cpp_stacktraces.cpp: -------------------------------------------------------------------------------- 1 | #include "test/cpp/test_status_common.h" 2 | 3 | using torch_xla::StatusTest; 4 | 5 | namespace torch_xla::cpp_test { 6 | namespace { 7 | 8 | // This file instantiates the parameterized tests defined in 9 | // `test_status_common.h`. It specifically configures the test environment by 10 | // explicitly setting the `TORCH_SHOW_CPP_STACKTRACES` environment variable to 11 | // 'false' in the test fixture's `SetUp` method. 12 | // 13 | // Any new `TEST_P` test cases added to `test_status_common.h` will 14 | // automatically be run in this mode (without C++ error context). 
15 | // 16 | INSTANTIATE_WITH_CPP_STACKTRACES_MODE(StatusTest, StatusTest, kHide); 17 | 18 | } // namespace 19 | } // namespace torch_xla::cpp_test 20 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/put.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_PUT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_PUT_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Put : public XlaNode { 9 | public: 10 | Put(const torch::lazy::Value& input, const torch::lazy::Value& index, 11 | const torch::lazy::Value& source, bool accumulate); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | bool accumulate() const { return accumulate_; } 20 | 21 | private: 22 | bool accumulate_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_PUT_H_ 28 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/index_select.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_INDEX_SELECT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_INDEX_SELECT_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class IndexSelect : public XlaNode { 9 | public: 10 | IndexSelect(const torch::lazy::Value& input, int64_t dim, 11 | const torch::lazy::Value& index); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t dim() const { return dim_; }; 20 | 21 | private: 22 | int64_t dim_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_INDEX_SELECT_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/scatter.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SCATTER_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SCATTER_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Scatter : public XlaNode { 9 | public: 10 | Scatter(const torch::lazy::Value& input, const torch::lazy::Value& index, 11 | const torch::lazy::Value& src, int64_t dim); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t dim() const { return dim_; }; 20 | 21 | private: 22 | int64_t dim_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_SCATTER_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/view.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_VIEW_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_VIEW_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class ViewOp : public XlaNode { 11 | public: 12 | ViewOp(const torch::lazy::Value& input, std::vector output_size); 13 | ViewOp(const torch::lazy::Value& input, xla::Shape output_shape); 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | 
std::string ToString() const override; 18 | 19 | const std::vector& output_size() const { return output_size_; } 20 | 21 | private: 22 | std::vector output_size_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_VIEW_H_ 28 | -------------------------------------------------------------------------------- /infra/ansible/ansible.cfg: -------------------------------------------------------------------------------- 1 | # See https://docs.ansible.com/ansible/latest/reference_appendices/config.html 2 | # for various configuration options. 3 | 4 | [defaults] 5 | # Displays tasks execution duration. 6 | callbacks_enabled = profile_tasks 7 | # The playbooks is only run on the implicit localhost. 8 | # Silence warning about empty hosts inventory. 9 | localhost_warning = False 10 | # Make output human-readable. 11 | stdout_callback = yaml 12 | # Ansible 2.19 requires this environment variable being set, so that we can use 13 | # string variables as boolean. 14 | allow_broken_conditionals = true 15 | 16 | [inventory] 17 | # Silence warning about no inventory. 18 | # This option is available since Ansible 2.14 (available only with Python 3.9+). 19 | inventory_unparsed_warning = False 20 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/count_nonzero.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_COUNT_NONZERO_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_COUNT_NONZERO_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class CountNonzero : public XlaNode { 9 | public: 10 | CountNonzero(const torch::lazy::Value& input, std::vector dims); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | std::optional> dims() const { return dims_; } 19 | 20 | private: 21 | std::vector dims_; 22 | }; 23 | 24 | } // namespace torch_xla 25 | 26 | #endif // XLA_TORCH_XLA_CSRC_OPS_COUNT_NONZERO_H_ 27 | -------------------------------------------------------------------------------- /torch_xla/utils/dlpack.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import enum 3 | from torch.utils.dlpack import DLDeviceType 4 | import torch 5 | import torch_xla 6 | import torch_xla.utils.utils as xu 7 | 8 | 9 | def to_dlpack(xla_tensor: Any): 10 | return torch_xla._XLAC._to_dlpack(xla_tensor) 11 | 12 | 13 | def from_dlpack(ext_tensor: Any): 14 | if hasattr(ext_tensor, '__dlpack_device__') and hasattr( 15 | ext_tensor, '__dlpack__'): 16 | device_type, _ = ext_tensor.__dlpack_device__() 17 | if device_type != DLDeviceType.kDLCPU: 18 | raise ValueError( 19 | "PyTorch/XLA DLPack implementation currently only supports CPU.") 20 | dlpack = ext_tensor.__dlpack__() 21 | else: 22 | dlpack = ext_tensor 23 | 24 | return torch_xla._XLAC._from_dlpack(dlpack) 25 | -------------------------------------------------------------------------------- /codegen/fix_includes.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # torchgen insists on including the tensor header as a system header 4 | sed -i 's##"torch_xla/csrc/tensor.h"#' $@ 5 | 6 | # remove the runfiles-prefix used in codegen for pytorch 7 | # `torchgen` generates relative includes path and does not support customizing the root, 8 | # so 
we have to fix them up. 9 | if [[ $(uname -m) == "x86_64" ]]; then 10 | sed -i 's#bazel-out/k8-[^/]*/bin/codegen/lazy_tensor_generator.runfiles/torch/##' $@ 11 | elif [[ $(uname -m) == "aarch64" ]]; then 12 | sed -i 's#bazel-out/aarch64-[^/]*/bin/codegen/lazy_tensor_generator.runfiles/torch/##' $@ 13 | fi 14 | 15 | # use the generated files that are in the compilation unit 16 | sed -i 's##"bazel-out/\1"#' $@ 17 | -------------------------------------------------------------------------------- /benchmarks/check_xla_device.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | assert len(sys.argv) in (2, 3) 4 | devkind = sys.argv[1] 5 | 6 | 7 | def use_torch_xla2(): 8 | use_xla2 = False 9 | if len(sys.argv) == 3 and sys.argv[2].lower() == 'true': 10 | use_xla2 = True 11 | return use_xla2 12 | 13 | 14 | import os 15 | 16 | os.environ["PJRT_DEVICE"] = devkind 17 | 18 | if not use_torch_xla2(): 19 | import torch_xla.core.xla_model as xm 20 | devlist = xm.get_xla_supported_devices() 21 | else: 22 | # torch_xla2 needs jax to detect the device 23 | os.environ["JAX_PLATFORMS"] = devkind.lower( 24 | ) # JAX_PLATFORMS only accepts lower case 25 | assert devkind.lower() in ('cpu', 'gpu', 'tpu') 26 | import jax 27 | devlist = jax.devices(devkind.lower()) 28 | 29 | if not devlist: 30 | sys.exit(1) 31 | -------------------------------------------------------------------------------- /scripts/update_torch_wheels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | DIST_BUCKET="gs://tpu-pytorch/wheels" 6 | TORCH_WHEEL="torch-nightly-cp36-cp36m-linux_x86_64.whl" 7 | TORCH_XLA_WHEEL="torch_xla-nightly-cp36-cp36m-linux_x86_64.whl" 8 | TORCHVISION_WHEEL="torchvision-nightly-cp36-cp36m-linux_x86_64.whl" 9 | 10 | [[ ! 
-z "$1" ]] && conda activate $1 11 | 12 | function update_wheels() { 13 | gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" /tmp/ 14 | gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" /tmp/ 15 | gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" /tmp/ 16 | pip install "/tmp/$TORCH_WHEEL" 17 | pip install "/tmp/$TORCH_XLA_WHEEL" 18 | pip install "/tmp/$TORCHVISION_WHEEL" 19 | 20 | rm -f "/tmp/$TORCH_WHEEL" "/tmp/$TORCH_XLA_WHEEL" "/tmp/$TORCHVISION_WHEEL" 21 | } 22 | 23 | update_wheels 24 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/max_in_dim.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MAX_IN_DIM_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MAX_IN_DIM_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class MaxInDim : public XlaNode { 9 | public: 10 | MaxInDim(const torch::lazy::Value& input, int64_t dim, bool keepdim); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | int64_t dim() const { return dim_; }; 19 | 20 | bool keepdim() const { return keepdim_; } 21 | 22 | private: 23 | int64_t dim_; 24 | bool keepdim_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_MAX_IN_DIM_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/min_in_dim.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MIN_IN_DIM_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MIN_IN_DIM_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class MinInDim : public XlaNode { 9 | public: 10 | MinInDim(const torch::lazy::Value& input, int64_t dim, bool keepdim); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | int64_t dim() const { return dim_; }; 19 | 20 | bool keepdim() const { return keepdim_; } 21 | 22 | private: 23 | int64_t dim_; 24 | bool keepdim_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_MIN_IN_DIM_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/not_supported.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_NOT_SUPPORTED_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_NOT_SUPPORTED_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class NotSupported : public XlaNode { 11 | public: 12 | NotSupported(std::string description, xla::Shape shape); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | const std::string& description() const { return description_; } 21 | 22 | private: 23 | std::string description_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_NOT_SUPPORTED_H_ 29 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/scatter_add.h: -------------------------------------------------------------------------------- 1 | #ifndef 
XLA_TORCH_XLA_CSRC_OPS_SCATTER_ADD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SCATTER_ADD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class ScatterAdd : public XlaNode { 9 | public: 10 | ScatterAdd(const torch::lazy::Value& input, const torch::lazy::Value& index, 11 | const torch::lazy::Value& src, int64_t dim); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t dim() const { return dim_; }; 20 | 21 | private: 22 | int64_t dim_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_SCATTER_ADD_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/symeig.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SYMEIG_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SYMEIG_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class SymEig : public XlaNode { 9 | public: 10 | SymEig(const torch::lazy::Value& input, bool eigenvectors, bool lower); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | bool eigenvectors() const { return eigenvectors_; } 19 | 20 | bool lower() const { return lower_; } 21 | 22 | private: 23 | bool eigenvectors_; 24 | bool lower_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_SYMEIG_H_ 30 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/unsqueeze.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_UNSQUEEZE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_UNSQUEEZE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Unsqueeze : public XlaNode { 9 | public: 10 | // Insert a dimension of size one at the specified position. 11 | Unsqueeze(const torch::lazy::Value& input, int dim); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | int dim() const { return dim_; } 20 | 21 | private: 22 | // Position to unsqueeze. 23 | int dim_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_UNSQUEEZE_H_#pragma once -------------------------------------------------------------------------------- /torch_xla/csrc/ops/flip.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_FLIP_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_FLIP_H_ 3 | 4 | #include "absl/types/span.h" 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class Flip : public XlaNode { 11 | public: 12 | Flip(const torch::lazy::Value& input, std::vector dims); 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | std::string ToString() const override; 19 | 20 | const std::vector& dims() const { return dims_; } 21 | 22 | private: 23 | // The dimensions which are flipped. 
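// Illustrative example (added note, not from the original header): dims = {0, 2}
// reverses the first and third dimensions of the input, mirroring torch.flip.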
24 | std::vector dims_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_FLIP_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/masked_select.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MASKED_SELECT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MASKED_SELECT_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | // This node has no metadata, so it could have been implemented as generic-op in 9 | // ops.cpp, but since this might require special handling from upper IR layers, 10 | // it gets its own IR node class. 11 | class MaskedSelect : public XlaNode { 12 | public: 13 | MaskedSelect(const torch::lazy::Value& input, const torch::lazy::Value& mask); 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | }; 19 | 20 | } // namespace torch_xla 21 | 22 | #endif // XLA_TORCH_XLA_CSRC_OPS_MASKED_SELECT_H_ -------------------------------------------------------------------------------- /infra/tpu-pytorch/misc.tf: -------------------------------------------------------------------------------- 1 | # Docker registry for private images. 2 | module "docker_registry" { 3 | source = "../terraform_modules/docker_registry" 4 | name = "docker" 5 | description = join(" ", [ 6 | "Private docker images for PyTorch/XLA.", 7 | "Managed by Terraform setup in docker/experimental/tpu-pytorch/misc.tf.", 8 | ]) 9 | } 10 | 11 | # Storage bucket for Terraform state of this project. 12 | module "tfstate_storage_bucket" { 13 | source = "../terraform_modules/storage_bucket" 14 | name = "tpu-pytorch-tfstate" 15 | } 16 | 17 | # Private worker pool for Cloud Builds. 18 | module "worker_pool" { 19 | source = "../terraform_modules/worker_pool" 20 | # See https://cloud.google.com/compute/docs/machine-resource#machine_type_comparison. 21 | machine_type = "e2-standard-32" 22 | } 23 | -------------------------------------------------------------------------------- /test/cpp/test_status_show_cpp_stacktraces.cpp: -------------------------------------------------------------------------------- 1 | #include "test/cpp/test_status_common.h" 2 | 3 | using torch_xla::StatusTest; 4 | 5 | namespace torch_xla::cpp_test { 6 | namespace { 7 | 8 | // This file instantiates the parameterized tests defined in 9 | // `test_status_common.h`. It specifically configures the test environment by 10 | // explicitly setting the `TORCH_SHOW_CPP_STACKTRACES` environment variable to 11 | // 'true' in the test fixture's `SetUp` method. 12 | // 13 | // Any new `TEST_P` test cases added to `test_status_common.h` will 14 | // automatically be run in this mode (with C++ error context). 15 | INSTANTIATE_WITH_CPP_STACKTRACES_MODE(StatusWithCppErrorContextTest, StatusTest, 16 | kShow); 17 | 18 | } // namespace 19 | } // namespace torch_xla::cpp_test 20 | -------------------------------------------------------------------------------- /torch_xla/csrc/shape_helper.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_SHAPE_HELPER_H_ 2 | #define XLA_TORCH_XLA_SHAPE_HELPER_H_ 3 | 4 | #include "absl/base/attributes.h" 5 | #include "absl/base/nullability.h" 6 | #include "xla/hlo/builder/xla_builder.h" 7 | 8 | namespace torch_xla { 9 | 10 | class ShapeHelper { 11 | public: 12 | // Returns the shape of the given XLA operation. 
13 | ABSL_DEPRECATED( 14 | "Use GetShape(op) instead. ShapeOfXlaOp() blindly " 15 | "de-references StatusOr returned by XLA, which is unsafe.") 16 | static const xla::Shape& ShapeOfXlaOp(xla::XlaOp op); 17 | }; 18 | 19 | // Returns the shape of the given XLA operation. 20 | absl::StatusOr GetShape(xla::XlaOp op); 21 | 22 | } // namespace torch_xla 23 | 24 | #endif // XLA_TORCH_XLA_SHAPE_HELPER_H_ 25 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/tpu_custom_call.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_TPU_CUSTOM_CALL_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_TPU_CUSTOM_CALL_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class TpuCustomCall : public XlaNode { 9 | public: 10 | // Make a TPU custom call with payload, e.g., Mosaic. 11 | TpuCustomCall(torch::lazy::OpList inputs, xla::Shape output_shape, 12 | const std::string& payload); 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | std::string ToString() const override; 19 | 20 | private: 21 | std::string payload_; 22 | }; 23 | 24 | } // namespace torch_xla 25 | 26 | #endif // XLA_TORCH_XLA_CSRC_OPS_TPU_CUSTOM_CALL_H_ 27 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 30 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 7 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - high priority 8 | - nostale 9 | # Label to use when marking an issue as stale 10 | staleLabel: stale 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. Set to `false` to disable 17 | closeComment: false 18 | # Limit to only `issues` or `pulls` 19 | only: issues 20 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/index_get.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_INDEX_GET_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_INDEX_GET_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class IndexGet : public XlaNode { 9 | public: 10 | IndexGet(const torch::lazy::Value& base, const torch::lazy::Value& indices, 11 | int64_t start_dim); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t start_dim() const { return start_dim_; } 20 | 21 | private: 22 | // The dimension number at which indexing starts. 
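// Illustrative note (added, not from the original header): with start_dim == 1,
// dimension 0 of `base` is carried through unchanged and the gathered `indices`
// apply from dimension 1 onward.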
23 | int64_t start_dim_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_INDEX_GET_H_ -------------------------------------------------------------------------------- /infra/ansible/roles/fetch_srcs/tasks/main.yaml: -------------------------------------------------------------------------------- 1 | - name: "Create source root directory at {{ src_root }}" 2 | ansible.builtin.file: 3 | path: "{{ src_root }}" 4 | state: directory 5 | mode: '0755' 6 | 7 | - name: "Clone git PyTorch and XLA git repos" 8 | ansible.builtin.git: 9 | repo: "{{ item.repo }}" 10 | dest: "{{ item.dest }}" 11 | version: "{{ item.version }}" 12 | depth: 1 13 | force: true 14 | loop: 15 | - repo: https://github.com/pytorch/pytorch 16 | dest: "{{ (src_root, 'pytorch') | path_join }}" 17 | version: "{{ pytorch_git_rev }}" 18 | 19 | - repo: https://github.com/pytorch/xla 20 | dest: "{{ (src_root, 'pytorch/xla') | path_join }}" 21 | version: "{{ xla_git_rev }}" 22 | 23 | - name: "Tests" 24 | include_tasks: tests.yaml 25 | tags: 26 | - tests 27 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/cdist.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CDIST_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_CDIST_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class CdistForward : public XlaNode { 9 | public: 10 | CdistForward(const torch::lazy::Value& x1, const torch::lazy::Value& x2, 11 | const torch::lazy::Value& p, bool use_hamming, 12 | bool use_chebyshev); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | private: 21 | bool use_hamming_; // handle p == 0 22 | bool use_chebyshev_; // handle p == +inf 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_CDIST_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/linear_interpolation.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_LINEAR_INTERPOLATION_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_LINEAR_INTERPOLATION_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class LinearInterpolation : public XlaNode { 9 | public: 10 | LinearInterpolation(const torch::lazy::Value& value, 11 | const torch::lazy::Value& new_value, double alpha); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | double alpha() const { return alpha_; } 20 | 21 | private: 22 | double alpha_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_LINEAR_INTERPOLATION_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/threshold_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_THRESHOLD_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_THRESHOLD_BACKWARD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class ThresholdBackward : public XlaNode { 9 | public: 10 | ThresholdBackward(const torch::lazy::Value& grad_output, 11 | const torch::lazy::Value& input, 
float threshold); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | float threshold() const { return threshold_; } 20 | 21 | private: 22 | float threshold_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_THRESHOLD_BACKWARD_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/softmax_builder.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_SOFTMAX_BUILDER_H_ 2 | #define XLA_TORCH_XLA_CSRC_SOFTMAX_BUILDER_H_ 3 | 4 | #include "xla/hlo/builder/xla_builder.h" 5 | 6 | namespace torch_xla { 7 | 8 | // Computes log(softmax(logits)) along the dimension specified by "dim". 9 | xla::XlaOp BuildLogSoftmax(xla::XlaOp logits, int64_t dim); 10 | 11 | // Computes the gradient of the input of the LogSoftmax function. 12 | xla::XlaOp BuildLogSoftmaxGrad(xla::XlaOp grad_output, xla::XlaOp output, 13 | int64_t dim); 14 | 15 | xla::XlaOp BuildSoftmax(xla::XlaOp logits, int64_t dim); 16 | 17 | xla::XlaOp BuildSoftmaxGrad(xla::XlaOp grad_output, xla::XlaOp output, 18 | int64_t dim); 19 | 20 | } // namespace torch_xla 21 | 22 | #endif // XLA_TORCH_XLA_CSRC_SOFTMAX_BUILDER_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/replication_pad.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_REPLICATION_PAD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_REPLICATION_PAD_H_ 3 | 4 | #include "absl/types/span.h" 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class ReplicationPad : public XlaNode { 11 | public: 12 | ReplicationPad(const torch::lazy::Value& input, std::vector padding); 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | std::string ToString() const override; 19 | 20 | const std::vector& padding() const { return padding_; } 21 | 22 | private: 23 | std::vector padding_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_REPLICATION_PAD_H_ 29 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # This is a copy of the requirements.txt file from the PyTorch repository v2.7.0, 2 | # with some lines commented out. 3 | sphinx==5.0.0 4 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 5 | # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering 6 | # but it doesn't seem to work and hangs around idly. The initial thought is probably 7 | # something related to Docker setup. We can investigate this later 8 | sphinxcontrib.katex==0.8.6 9 | # matplotlib==3.6.0 10 | # tensorboard==2.10.0 11 | # required to build torch.distributed.elastic.rendezvous.etcd* docs 12 | # python-etcd==0.4.5 13 | sphinx-copybutton==0.5.0 14 | # sphinx-panels==0.4.1 15 | myst-parser==0.18.1 16 | 17 | # This is an additional requirement for the PyTorch XLA documentation. 
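# myst-nb is the Sphinx extension that executes and renders notebook-style
# documentation pages (assumed rationale for this extra pin).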
18 | myst-nb==0.16 -------------------------------------------------------------------------------- /torch_xla/csrc/ops/einsum_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_EINSUM_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_EINSUM_BACKWARD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class EinsumBackward : public XlaNode { 9 | public: 10 | EinsumBackward(const torch::lazy::Value& grad_output, 11 | const torch::lazy::OpList& inputs, const std::string equation); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | const std::string& equation() const { return equation_; } 20 | 21 | private: 22 | const std::string equation_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_EINSUM_BACKWARD_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/threshold.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_THRESHOLD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_THRESHOLD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | // IR node for the threshold operation. 9 | class Threshold : public XlaNode { 10 | public: 11 | Threshold(const torch::lazy::Value& input, float threshold, float value); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | float threshold() const { return threshold_; } 20 | 21 | float value() const { return value_; } 22 | 23 | private: 24 | float threshold_; 25 | float value_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_THRESHOLD_H_ -------------------------------------------------------------------------------- /test/benchmarks/Makefile: -------------------------------------------------------------------------------- 1 | TEST_ARGS = $(shell echo $@ | perl -pe 's/([^.]*)\.([^.]*)\.([^.]*).*\.test/--accelerator=$$1 --test=$$2 --report=$$3/') 2 | EMBEDDED_TEST_ARGS = $(shell cat $@ | grep '^# ARGS: ' | perl -pe 's/^# ARGS: (.*)/$$1/') 3 | 4 | TESTS := $(wildcard *.test) 5 | all: $(TESTS) 6 | .PHONY: $(TESTS) all 7 | 8 | ifndef V 9 | QUIET_AGGREGATE = @echo ' ' AGGREGATE $(TEST_ARGS) $(EMBEDDED_TEST_ARGS); 10 | QUIET_DIFF = @echo ' ' DIFF $@; 11 | QUIET_RM = @echo ' ' RM $@.tmp; 12 | endif 13 | 14 | $(TESTS): 15 | $(QUIET_AGGREGATE)python3 ../../benchmarks/aggregate.py \ 16 | --format=csv \ 17 | $(TEST_ARGS) $(EMBEDDED_TEST_ARGS) \ 18 | $(wildcard *.jsonl) > $@.tmp 19 | $(QUIET_DIFF)git diff -I'^# ARGS: ' --no-index $@ $@.tmp 20 | $(QUIET_RM)$(RM) $@.tmp 21 | 22 | clean: 23 | $(RM) *.tmp 24 | .PHONY: clean 25 | -------------------------------------------------------------------------------- /torch_xla/csrc/unwrap_data.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_UNWRAP_DATA_H_ 2 | #define XLA_TORCH_XLA_CSRC_UNWRAP_DATA_H_ 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "absl/types/span.h" 10 | 11 | #include "torch_xla/csrc/runtime/computation_client.h" 12 | 13 | namespace torch_xla { 14 | 15 | runtime::ComputationClient::DataPtr UnwrapXlaData( 16 | const 
torch::lazy::BackendDataPtr& data); 17 | 18 | std::vector<runtime::ComputationClient::DataPtr> UnwrapXlaData( 19 | absl::Span<const torch::lazy::BackendDataPtr> datas); 20 | 21 | std::vector<torch::lazy::BackendDataPtr> WrapXlaData( 22 | absl::Span<const runtime::ComputationClient::DataPtr> xla_datas); 23 | 24 | } // namespace torch_xla 25 | 26 | #endif // XLA_TORCH_XLA_CSRC_UNWRAP_DATA_H_ 27 |
-------------------------------------------------------------------------------- /torch_xla/csrc/ops/arithmetic_ir_ops.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_ARITHMETIC_IR_OPS_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_ARITHMETIC_IR_OPS_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | torch::lazy::NodePtr operator+(const torch::lazy::Value& node1, 9 | const torch::lazy::Value& node2); 10 | torch::lazy::NodePtr operator-(const torch::lazy::Value& node1, 11 | const torch::lazy::Value& node2); 12 | torch::lazy::NodePtr operator*(const torch::lazy::Value& node1, 13 | const torch::lazy::Value& node2); 14 | torch::lazy::NodePtr operator/(const torch::lazy::Value& node1, 15 | const torch::lazy::Value& node2); 16 | 17 | } // namespace torch_xla 18 | 19 | #endif // XLA_TORCH_XLA_CSRC_OPS_ARITHMETIC_IR_OPS_H_
-------------------------------------------------------------------------------- /torch_xla/csrc/ops/mse_loss.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MSE_LOSS_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MSE_LOSS_H_ 3 | 4 | #include "xla/types.h" 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | #include "torch_xla/csrc/reduction.h" 8 | 9 | namespace torch_xla { 10 | 11 | class MseLoss : public XlaNode { 12 | public: 13 | MseLoss(const torch::lazy::Value& input, const torch::lazy::Value& target, 14 | ReductionMode reduction); 15 | 16 | std::string ToString() const override; 17 | 18 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 19 | 20 | XlaOpVector Lower(LoweringContext* loctx) const override; 21 | 22 | ReductionMode reduction() const { return reduction_; } 23 | 24 | private: 25 | ReductionMode reduction_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_MSE_LOSS_H_
-------------------------------------------------------------------------------- /torch_xla/csrc/ops/reflection_pad2d.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_REFLECTION_PAD2D_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_REFLECTION_PAD2D_H_ 3 | 4 | #include <vector> 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class ReflectionPad2d : public XlaNode { 11 | public: 12 | ReflectionPad2d(const torch::lazy::Value& input, 13 | std::vector<int64_t> padding); 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | std::string ToString() const override; 20 | 21 | const std::vector<int64_t>& padding() const { return padding_; } 22 | 23 | private: 24 | std::vector<int64_t> padding_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_REFLECTION_PAD2D_H_ 30 |
-------------------------------------------------------------------------------- /test/dynamo/test_dynamo_config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_xla 3 | import unittest 4 | from torch_xla._dynamo import config 5 | 6 | 7 | class DynamoconfigTest(unittest.TestCase): 8 | 9 | def dummy_test(self, a): 10 | return
a.cos().sin() 11 | 12 | def test_config_skip_input_data_check(self): 13 | device = torch_xla.device() 14 | print(config.skip_input_data_check) 15 | config.skip_input_data_check = True 16 | compiled_dummy = torch.compile(self.dummy_test, backend="openxla") 17 | t1 = torch.randn(3, 4, device=device) 18 | compiled_dummy(t1) 19 | t2 = torch.randn(3, 4, device=device) 20 | t2 += 5 21 | with self.assertRaisesRegex( 22 | RuntimeError, r'input data to dynamo graph can not be a pending ir'): 23 | compiled_dummy(t2) 24 | 25 | 26 | if __name__ == '__main__': 27 | test = unittest.main() 28 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/softmax_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SOFTMAX_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SOFTMAX_BACKWARD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class SoftmaxBackward : public XlaNode { 9 | public: 10 | SoftmaxBackward(const torch::lazy::Value& grad_output, 11 | const torch::lazy::Value& output, int64_t dim); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | int64_t dim() const { return dim_; } 20 | 21 | private: 22 | // The dimension along which the result is computed. 23 | int64_t dim_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_SOFTMAX_BACKWARD_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/masked_scatter.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MASKED_SCATTER_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MASKED_SCATTER_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | // This node has no metadata, so it could have been implemented as generic-op in 9 | // ops.cpp, but since this might require special handling from upper IR layers, 10 | // it gets its own IR node class. 
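// Intended semantics (added sketch, mirroring torch.Tensor.masked_scatter_):
// elements of `source` are copied into `input` at the positions where `mask`
// is true.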
11 | class MaskedScatter : public XlaNode { 12 | public: 13 | MaskedScatter(const torch::lazy::Value& input, const torch::lazy::Value& mask, 14 | const torch::lazy::Value& source); 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | }; 20 | 21 | } // namespace torch_xla 22 | 23 | #endif // XLA_TORCH_XLA_CSRC_OPS_MASKED_SCATTER_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/max_unpool_nd.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MAX_UNPOOL_ND_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MAX_UNPOOL_ND_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class MaxUnpoolNd : public XlaNode { 9 | public: 10 | MaxUnpoolNd(const torch::lazy::Value& input, 11 | const torch::lazy::Value& indices, 12 | std::vector output_size); 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | std::string ToString() const override; 19 | 20 | const std::vector& output_size() const { return output_size_; } 21 | 22 | private: 23 | std::vector output_size_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_MAX_UNPOOL_ND_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_AMP_FOREACH_NON_FINITE_CHECK_AND_UNSCALE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_AMP_FOREACH_NON_FINITE_CHECK_AND_UNSCALE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class AmpForachNonFiniteCheckAndUnscale : public XlaNode { 9 | public: 10 | AmpForachNonFiniteCheckAndUnscale(const torch::lazy::OpList& inputs, 11 | const torch::lazy::Value& found_inf, 12 | const torch::lazy::Value& inv_scale); 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | }; 18 | 19 | } // namespace torch_xla 20 | 21 | #endif // XLA_TORCH_XLA_CSRC_OPS_AMP_FOREACH_NON_FINITE_CHECK_AND_UNSCALE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/roll.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_ROLL_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_ROLL_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Roll : public XlaNode { 9 | public: 10 | Roll(const torch::lazy::Value& input, std::vector shifts, 11 | std::vector dims); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | const std::vector& shifts() const { return shifts_; } 20 | 21 | const std::vector& dims() const { return dims_; } 22 | 23 | private: 24 | std::vector shifts_; 25 | std::vector dims_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_ROLL_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/upsample_nearest2d.h: 
-------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_UPSAMPLE_NEAREST2D_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_UPSAMPLE_NEAREST2D_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class UpsampleNearest : public XlaNode { 11 | public: 12 | UpsampleNearest(const torch::lazy::Value& input, 13 | std::vector output_size); 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | std::string ToString() const override; 20 | 21 | const std::vector& output_size() const { return output_size_; } 22 | 23 | private: 24 | std::vector output_size_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_UPSAMPLE_NEAREST2D_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/cat.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CAT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_CAT_H_ 3 | 4 | #include 5 | 6 | #include "absl/types/span.h" 7 | 8 | #include "torch_xla/csrc/ir.h" 9 | 10 | namespace torch_xla { 11 | 12 | class Cat : public XlaNode { 13 | public: 14 | Cat(c10::ArrayRef values, int64_t dim, 15 | at::ScalarType dtype); 16 | 17 | std::string ToString() const override; 18 | 19 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 20 | 21 | XlaOpVector Lower(LoweringContext* loctx) const override; 22 | 23 | int64_t dim() const { return dim_; }; 24 | 25 | at::ScalarType dtype() const { return dtype_; }; 26 | 27 | private: 28 | int64_t dim_; 29 | at::ScalarType dtype_; 30 | }; 31 | 32 | } // namespace torch_xla 33 | 34 | #endif // XLA_TORCH_XLA_CSRC_OPS_CAT_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/get_dimensions_size.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_GET_DIMENSIONS_SIZE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_GET_DIMENSIONS_SIZE_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class GetDimensionsSize : public XlaNode { 11 | public: 12 | GetDimensionsSize(const torch::lazy::Value& input, 13 | std::vector dimensions); 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | std::string ToString() const override; 20 | 21 | const std::vector& dimensions() const { return dimensions_; } 22 | 23 | private: 24 | std::vector dimensions_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_GET_DIMENSIONS_SIZE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/kth_value.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_KTH_VALUE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_KTH_VALUE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class KthValue : public XlaNode { 9 | public: 10 | KthValue(const torch::lazy::Value& input, int64_t k, int64_t dim, 11 | bool keepdim); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | 
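// Note (added; assumption based on torch.kthvalue): k is 1-based, so k == 1
// selects the smallest value along `dim`.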
int64_t k() const { return k_; }; 20 | 21 | int64_t dim() const { return dim_; }; 22 | 23 | bool keepdim() const { return keepdim_; } 24 | 25 | private: 26 | int64_t k_; 27 | int64_t dim_; 28 | bool keepdim_; 29 | }; 30 | 31 | } // namespace torch_xla 32 | 33 | #endif // XLA_TORCH_XLA_CSRC_OPS_KTH_VALUE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/tensor_common.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_TENSOR_COMMON_H_ 2 | #define XLA_TORCH_XLA_CSRC_TENSOR_COMMON_H_ 3 | 4 | #include 5 | 6 | #include "xla/hlo/builder/xla_builder.h" 7 | 8 | namespace torch_xla { 9 | 10 | // XLA SPMD sharding spec annoation. The XLA tensor uses this to create 11 | // HloSharding for replication, manual and tile shardings. 12 | struct ShardingSpec { 13 | ShardingSpec(const xla::OpSharding& sharding) : sharding(sharding) {} 14 | ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape) 15 | : sharding(sharding), shape(shape) {} 16 | 17 | xla::OpSharding sharding; 18 | // Optional source tensor shape unpartitioned. 19 | std::optional shape; 20 | }; 21 | 22 | using ShardingSpecPtr = std::shared_ptr; 23 | } // namespace torch_xla 24 | 25 | #endif // XLA_TORCH_XLA_CSRC_TENSOR_COMMON_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/adaptive_max_pool2d.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_ADAPTIVE_MAX_POOL2D_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_ADAPTIVE_MAX_POOL2D_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class AdaptiveMaxPool2d : public XlaNode { 11 | public: 12 | AdaptiveMaxPool2d(const torch::lazy::Value& input, 13 | std::vector output_size); 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | std::string ToString() const override; 20 | 21 | const std::vector& output_size() const { return output_size_; } 22 | 23 | private: 24 | std::vector output_size_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_ADAPTIVE_MAX_POOL2D_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/send.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SEND_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SEND_H_ 3 | 4 | #include "torch_xla/csrc/cross_replica_reduces.h" 5 | #include "torch_xla/csrc/ir.h" 6 | 7 | namespace torch_xla { 8 | namespace ir { 9 | namespace ops { 10 | 11 | class Send : public XlaNode { 12 | public: 13 | Send(const torch::lazy::Value& input, const torch::lazy::Value& token, 14 | int64_t channel_id); 15 | 16 | std::string ToString() const override; 17 | 18 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 19 | 20 | XlaOpVector Lower(LoweringContext* loctx) const override; 21 | 22 | int64_t channel_id() const { return channel_id_; } 23 | 24 | private: 25 | int64_t channel_id_; 26 | }; 27 | 28 | } // namespace ops 29 | } // namespace ir 30 | } // namespace torch_xla 31 | 32 | #endif // XLA_TORCH_XLA_CSRC_OPS_SEND_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/runtime/tf_logging.cpp: -------------------------------------------------------------------------------- 
1 | #include "torch_xla/csrc/runtime/tf_logging.h" 2 | 3 | #include <torch/csrc/utils/cpp_stacktraces.h> 4 | 5 | #include <ostream> 6 | #include <sstream> 7 | 8 | #include "tsl/platform/stacktrace.h" 9 | 10 | #include "torch_xla/csrc/status.h" 11 | 12 | namespace torch_xla { 13 | namespace runtime { 14 | namespace internal { 15 | 16 | void ErrorGenerator::operator&(const std::basic_ostream<char>& oss) const { 17 | const ErrorSink& sink = dynamic_cast<const ErrorSink&>(oss); 18 | 19 | std::stringstream ess; 20 | ess << sink.str(); 21 | 22 | if (torch::get_cpp_stacktraces_enabled()) { 23 | ess << " (at " << file_ << ":" << line_ << ")\n"; 24 | } 25 | 26 | TF_VLOG(1) << ess.str(); 27 | TORCH_CHECK(false, ess.str()); 28 | } 29 | 30 | } // namespace internal 31 | } // namespace runtime 32 | } // namespace torch_xla 33 |
-------------------------------------------------------------------------------- /torch_xla/csrc/ops/update_slice.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_UPDATE_SLICE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_UPDATE_SLICE_H_ 3 | 4 | #include "absl/types/span.h" 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class UpdateSlice : public XlaNode { 11 | public: 12 | UpdateSlice(const torch::lazy::Value& input, const torch::lazy::Value& source, 13 | absl::Span<const int64_t> base_indices); 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | std::string ToString() const override; 20 | 21 | const std::vector<int64_t>& base_indices() const { return base_indices_; } 22 | 23 | private: 24 | std::vector<int64_t> base_indices_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_UPDATE_SLICE_H_
-------------------------------------------------------------------------------- /torch_xla/csrc/ops/dynamic_expand.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_DYNAMIC_EXPAND_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_DYNAMIC_EXPAND_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class DynamicExpand : public XlaNode { 9 | public: 10 | DynamicExpand(const torch::lazy::Value& input, 11 | const std::vector<int64_t>& size, 12 | const torch::lazy::Value& src_tensor, int src_index, 13 | int target_index); 14 | 15 | std::string ToString() const override; 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | private: 22 | std::vector<int64_t> size_; 23 | int src_index_; 24 | int target_index_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_DYNAMIC_EXPAND_H_ 30 |
-------------------------------------------------------------------------------- /torch_xla/csrc/ops/log_softmax_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_LOG_SOFTMAX_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_LOG_SOFTMAX_BACKWARD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class LogSoftmaxBackward : public XlaNode { 9 | public: 10 | LogSoftmaxBackward(const torch::lazy::Value& grad_output, 11 | const torch::lazy::Value& output, int64_t dim); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | int64_t dim() const {
return dim_; } 20 | 21 | private: 22 | // The dimension along which the result is computed. 23 | int64_t dim_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_LOG_SOFTMAX_BACKWARD_H_#pragma once -------------------------------------------------------------------------------- /torch_xla/csrc/dtype.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_DTYPE_H_ 2 | #define XLA_TORCH_XLA_CSRC_DTYPE_H_ 3 | 4 | #include "xla/shape.h" 5 | 6 | #include "torch_xla/csrc/device.h" 7 | 8 | namespace torch_xla { 9 | 10 | at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type); 11 | 12 | xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type); 13 | 14 | // Downcast type to be compatible with device if necessary. 15 | xla::PrimitiveType MaybeDowncastToXlaDeviceType( 16 | xla::PrimitiveType type, const torch::lazy::BackendDevice& device); 17 | 18 | xla::PrimitiveType MaybeDowncastToXlaDeviceType( 19 | at::ScalarType scalar_type, const torch::lazy::BackendDevice& device); 20 | 21 | // Upcast type to original PyTorch type. 22 | at::ScalarType MaybeUpcastToHostTorchType(xla::PrimitiveType xla_type); 23 | 24 | } // namespace torch_xla 25 | 26 | #endif // XLA_TORCH_XLA_CSRC_DTYPE_H_ 27 | -------------------------------------------------------------------------------- /.circleci/test_xrt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ex 4 | 5 | source ./xla_env 6 | source .circleci/common.sh 7 | 8 | PYTORCH_DIR=/tmp/pytorch 9 | XLA_DIR=$PYTORCH_DIR/xla 10 | USE_COVERAGE="${USE_COVERAGE:-0}" 11 | 12 | # Needs to be kept in sync with .jenkins/pytorch/common_utils.sh in pytorch/pytorch. 13 | TORCHVISION_COMMIT="$(cat $PYTORCH_DIR/.github/ci_commit_pins/vision.txt)" 14 | 15 | function pip_install() { 16 | # retry 3 times 17 | # old versions of pip don't have the "--progress-bar" flag 18 | pip install --progress-bar off "$@" || pip install --progress-bar off "$@" || pip install --progress-bar off "$@" ||\ 19 | pip install "$@" || pip install "$@" || pip install "$@" 20 | } 21 | 22 | function install_torchvision() { 23 | pip_install --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$TORCHVISION_COMMIT" 24 | } 25 | 26 | install_torchvision 27 | 28 | ./test/run_tests.sh 29 | -------------------------------------------------------------------------------- /.devcontainer/tpu-contributor/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tpu-contributor", 3 | "image": "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu", 4 | "runArgs": [ 5 | "--privileged", 6 | "--net=host", 7 | "--shm-size=16G" 8 | ], 9 | "initializeCommand": "docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu", 10 | "customizations": { 11 | "vscode": { 12 | "extensions": [ 13 | "llvm-vs-code-extensions.vscode-clangd", 14 | "ms-vscode.cpptools-themes", 15 | "BazelBuild.vscode-bazel", 16 | "DevonDCarew.bazel-code", 17 | "StackBuild.bazel-stack-vscode", 18 | "StackBuild.bazel-stack-vscode-cc", 19 | "xaver.clang-format", 20 | "ryanluker.vscode-coverage-gutters", 21 | "ms-azuretools.vscode-docker", 22 | "ms-python.python" 23 | ] 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/dot_general.h: 
-------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_DOT_GENERAL_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_DOT_GENERAL_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class DotGeneral : public XlaNode { 9 | public: 10 | DotGeneral(const torch::lazy::Value& lhs, const torch::lazy::Value& rhs, 11 | const std::vector>& dim_vectors, 12 | std::optional preferred_element_type); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | private: 21 | std::vector> dim_vectors_; 22 | std::optional preferred_element_type_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_DOT_GENERAL_H_ 28 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/recv.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_RECV_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_RECV_H_ 3 | 4 | #include "torch_xla/csrc/cross_replica_reduces.h" 5 | #include "torch_xla/csrc/ir.h" 6 | 7 | namespace torch_xla { 8 | namespace ir { 9 | namespace ops { 10 | 11 | class Recv : public XlaNode { 12 | public: 13 | Recv(const torch::lazy::Value& token, const xla::Shape& recv_shape, 14 | int64_t channel_id); 15 | 16 | std::string ToString() const override; 17 | 18 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 19 | 20 | XlaOpVector Lower(LoweringContext* loctx) const override; 21 | 22 | int64_t channel_id() const { return channel_id_; } 23 | 24 | private: 25 | xla::Shape recv_shape_; 26 | int64_t channel_id_; 27 | }; 28 | 29 | } // namespace ops 30 | } // namespace ir 31 | } // namespace torch_xla 32 | 33 | #endif // XLA_TORCH_XLA_CSRC_OPS_RECV_H_ -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680Feature Request" 3 | about: Submit a proposal/request for a new feature for PyTorch/XLA integration 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 🚀 Feature 11 | 12 | 13 | ## Motivation 14 | 15 | 16 | 17 | ## Pitch 18 | 19 | 20 | 21 | ## Alternatives 22 | 23 | 24 | 25 | ## Additional context 26 | 27 | 28 | -------------------------------------------------------------------------------- /torch_xla/experimental/eager.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from contextlib import contextmanager 3 | 4 | import torch_xla 5 | import logging 6 | 7 | 8 | def eager_mode(enable: bool): 9 | """Configure torch_xla's default execution mode. 10 | 11 | Under eager mode only functions that was `torch_xla.compile`d will be 12 | traced and compiled. Other torch ops will be executed eagerly. 13 | """ 14 | torch_xla._XLAC._set_use_eager_mode(enable) 15 | 16 | 17 | def is_eager_mode() -> bool: 18 | """Return True if torch_xla is currently under eager mode 19 | """ 20 | return torch_xla._XLAC._get_use_eager_mode() 21 | 22 | 23 | @contextmanager 24 | def eager_mode_context(enable: bool): 25 | """Context manager to enable/disable the eager mode. 
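  Illustrative usage (added sketch; `model` and `batch` are placeholders):

    with eager_mode_context(True):
      out = model(batch)  # ops in this block execute eagerly

  The previously active mode is restored when the block exits.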
26 | """ 27 | saved_eager_mode = is_eager_mode() 28 | eager_mode(enable) 29 | try: 30 | yield saved_eager_mode 31 | finally: 32 | eager_mode(saved_eager_mode) 33 | -------------------------------------------------------------------------------- /infra/terraform_modules/trigger_schedule_account/service_account.tf: -------------------------------------------------------------------------------- 1 | resource "google_service_account" "build_runner" { 2 | account_id = "build-triggers-scheduler" 3 | description = "Service account for Scheduled Jobs. Has permissions to trigger Cloud Builds." 4 | } 5 | 6 | resource "google_project_iam_custom_role" "build_runner" { 7 | role_id = "build_runner" 8 | title = "Build Runner" 9 | description = "Grants permissions to trigger Cloud Builds." 10 | permissions = ["cloudbuild.builds.create"] 11 | } 12 | 13 | data "google_project" "project" {} 14 | 15 | resource "google_project_iam_member" "build_runner" { 16 | role = google_project_iam_custom_role.build_runner.name 17 | project = data.google_project.project.project_id 18 | member = "serviceAccount:${google_service_account.build_runner.email}" 19 | } 20 | 21 | output "email" { 22 | value = google_service_account.build_runner.email 23 | } 24 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/expand_symint.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_EXPAND_SYMINT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_EXPAND_SYMINT_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | #include "torch_xla/csrc/torch_util.h" 8 | 9 | namespace torch_xla { 10 | 11 | class ExpandSymInt : public XlaNode { 12 | public: 13 | ExpandSymInt(const torch::lazy::Value& input, 14 | const SymIntElements& size_elements); 15 | 16 | std::string ToString() const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | const std::vector& size() const { return upper_bounds_; }; 21 | 22 | const bool IsDynamic(int index) const { return dynamic_dims_[index]; }; 23 | 24 | private: 25 | std::vector upper_bounds_; 26 | std::vector dynamic_dims_; 27 | }; 28 | 29 | } // namespace torch_xla 30 | 31 | #endif // XLA_TORCH_XLA_CSRC_OPS_EXPAND_SYMINT_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/random.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_RANDOM_H_ 2 | #define XLA_TORCH_XLA_CSRC_RANDOM_H_ 3 | 4 | #include "xla/hlo/builder/xla_builder.h" 5 | 6 | namespace torch_xla { 7 | 8 | // Set downcast to true if the caller knows the |maxval - minval| is appropriate 9 | // for f16 dtype. We avoid computing the range on-the-fly since it incurs an XLA 10 | // computation. 
11 | xla::XlaOp RngUniform(xla::XlaOp seed, const xla::Shape& shape, 12 | xla::XlaOp minval, xla::XlaOp maxval, 13 | bool downcast = false); 14 | 15 | xla::XlaOp RngDiscreteUniform(xla::XlaOp seed, const xla::Shape& shape, 16 | xla::XlaOp minval, xla::XlaOp maxval); 17 | 18 | xla::XlaOp RngNormal(xla::XlaOp seed, const xla::Shape& shape, xla::XlaOp mean, 19 | xla::XlaOp std); 20 | 21 | } // namespace torch_xla 22 | 23 | #endif // XLA_TORCH_XLA_CSRC_RANDOM_H_ 24 | -------------------------------------------------------------------------------- /torch_xla/csrc/runtime/metrics_reader.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_CLIENT_METRICS_READER_H_ 2 | #define XLA_CLIENT_METRICS_READER_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "torch_xla/csrc/runtime/metrics.h" 9 | #include "torch_xla/csrc/runtime/types.h" 10 | 11 | namespace torch_xla { 12 | namespace runtime { 13 | namespace metrics_reader { 14 | 15 | // Creates a report with the current metrics statistics. 16 | std::string CreateMetricReport( 17 | const std::map& xrt_metrics); 18 | 19 | // Creates a report with the selected metrics statistics. 20 | std::string CreateMetricReport(const std::vector& counter_names, 21 | const std::vector& metric_names); 22 | 23 | } // namespace metrics_reader 24 | } // namespace runtime 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_CLIENT_METRICS_READER_H_ 28 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/cumsum.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CUMSUM_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_CUMSUM_H_ 3 | 4 | #include 5 | 6 | #include 7 | 8 | #include "torch_xla/csrc/ir.h" 9 | 10 | namespace torch_xla { 11 | 12 | class CumSum : public XlaNode { 13 | public: 14 | CumSum(const torch::lazy::Value& input, int64_t dim, 15 | std::optional dtype); 16 | 17 | std::string ToString() const override; 18 | 19 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 20 | 21 | XlaOpVector Lower(LoweringContext* loctx) const override; 22 | 23 | int64_t dim() const { return dim_; } 24 | 25 | const std::optional& dtype() const { return dtype_; } 26 | 27 | private: 28 | int64_t dim_; 29 | std::optional dtype_; 30 | }; 31 | 32 | } // namespace torch_xla 33 | 34 | #endif // XLA_TORCH_XLA_CSRC_OPS_CUMSUM_H_ 35 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/scatter_reduce.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SCATTER_REDUCE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SCATTER_REDUCE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class ScatterReduce : public XlaNode { 9 | public: 10 | ScatterReduce(const torch::lazy::Value& input, 11 | const torch::lazy::Value& index, const torch::lazy::Value& src, 12 | std::string_view reduce, bool include_self, int64_t dim); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | int64_t dim() const { return dim_; }; 21 | 22 | private: 23 | std::string reduce_; 24 | bool include_self_; 25 | int64_t dim_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_SCATTER_REDUCE_H_ 31 | 
-------------------------------------------------------------------------------- /torch_xla/csrc/ops/native_dropout.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_NATIVE_DROPOUT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_NATIVE_DROPOUT_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | // This node has no metadata, so it could have been implemented as generic-op in 9 | // ops.cpp, but since this might require special handling from upper IR layers, 10 | // it gets its own IR node class. 11 | class NativeDropout : public XlaNode { 12 | public: 13 | NativeDropout(const torch::lazy::Value& input, const torch::lazy::Value& seed, 14 | float p, std::optional train); 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | private: 21 | float p_; 22 | std::optional train_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_NATIVE_DROPOUT_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/cumprod.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CUMPROD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_CUMPROD_H_ 3 | 4 | #include 5 | 6 | #include 7 | 8 | #include "torch_xla/csrc/ir.h" 9 | 10 | namespace torch_xla { 11 | 12 | class CumProd : public XlaNode { 13 | public: 14 | CumProd(const torch::lazy::Value& input, int64_t dim, 15 | std::optional dtype); 16 | 17 | std::string ToString() const override; 18 | 19 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 20 | 21 | XlaOpVector Lower(LoweringContext* loctx) const override; 22 | 23 | int64_t dim() const { return dim_; } 24 | 25 | const std::optional& dtype() const { return dtype_; } 26 | 27 | private: 28 | int64_t dim_; 29 | std::optional dtype_; 30 | }; 31 | 32 | } // namespace torch_xla 33 | 34 | #endif // XLA_TORCH_XLA_CSRC_OPS_CUMPROD_H_ 35 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/dynamic_view.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_DYNAMIC_VIEW_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_DYNAMIC_VIEW_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class DynamicView : public XlaNode { 9 | public: 10 | DynamicView(const torch::lazy::Value& input, const std::vector& size, 11 | const torch::lazy::Value& src_tensor, int src_index, 12 | int target_index, float mul_scaler); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | private: 21 | std::vector size_; 22 | int src_index_; 23 | int target_index_; 24 | float mul_scaler_; 25 | xla::Shape complete_output_shape_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_DYNAMIC_VIEW_H_ 31 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/softmax.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SOFTMAX_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SOFTMAX_H_ 3 | 4 | #include 5 | 6 | #include 7 | 8 | #include "torch_xla/csrc/ir.h" 9 | 10 | namespace torch_xla { 11 | 12 | class 
Softmax : public XlaNode { 13 | public: 14 | Softmax(const torch::lazy::Value& input, int64_t dim, 15 | std::optional dtype); 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | std::string ToString() const override; 22 | 23 | int64_t dim() const { return dim_; } 24 | 25 | const std::optional& dtype() const { return dtype_; } 26 | 27 | private: 28 | int64_t dim_; 29 | std::optional dtype_; 30 | }; 31 | 32 | } // namespace torch_xla 33 | 34 | #endif // XLA_TORCH_XLA_CSRC_OPS_SOFTMAX_H_ 35 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/amp_update_scale.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_AMP_UPDATE_SCALE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_AMP_UPDATE_SCALE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class AmpUpdateScale : public XlaNode { 9 | public: 10 | AmpUpdateScale(const torch::lazy::Value& current_scale, 11 | const torch::lazy::Value& growth_tracker, 12 | const torch::lazy::Value& found_inf, 13 | double scale_growth_factor, double scale_backoff_factor, 14 | int growth_interval); 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | private: 21 | double scale_growth_factor_; 22 | double scale_backoff_factor_; 23 | int growth_interval_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_AMP_UPDATE_SCALE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/constant_pad_nd.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CONSTANT_PAD_ND_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_CONSTANT_PAD_ND_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class ConstantPadNd : public XlaNode { 11 | public: 12 | ConstantPadNd(const torch::lazy::Value& input, std::vector pad, 13 | const at::Scalar& value); 14 | 15 | std::string ToString() const override; 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | const at::Scalar& value() const { return value_; } 22 | 23 | const std::vector& pad() const { return pad_; } 24 | 25 | private: 26 | std::vector pad_; 27 | at::Scalar value_; 28 | }; 29 | 30 | } // namespace torch_xla 31 | 32 | #endif // XLA_TORCH_XLA_CSRC_OPS_CONSTANT_PAD_ND_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/split.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SPLIT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SPLIT_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | // Split the tensor into chunks along a given dimension. 
11 | class Split : public XlaNode { 12 | public: 13 | Split(const torch::lazy::Value& input, std::vector<int64_t> split_sizes, 14 | int64_t dim); 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | std::string ToString() const override; 21 | 22 | const std::vector<int64_t>& split_sizes() const { return split_sizes_; } 23 | 24 | int64_t dim() const { return dim_; } 25 | 26 | private: 27 | std::vector<int64_t> split_sizes_; 28 | int64_t dim_; 29 | }; 30 | 31 | } // namespace torch_xla 32 | 33 | #endif // XLA_TORCH_XLA_CSRC_OPS_SPLIT_H_ -------------------------------------------------------------------------------- /test/test_mp_mesh_reduce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_xla 3 | from torch_xla import runtime as xr 4 | import torch_xla.core.xla_model as xm 5 | 6 | 7 | def _test_scalar(): 8 | 9 | def reduce_add(vlist): 10 | return sum(vlist) 11 | 12 | svalue = 1.25 13 | rvalue = xm.mesh_reduce('test_mp_mesh_reduce._test_scalar', svalue, 14 | reduce_add) 15 | assert rvalue == svalue * xr.world_size() 16 | 17 | 18 | def _test_tensor(): 19 | 20 | def reduce_add(vlist): 21 | return torch.stack(vlist).sum(dim=0) 22 | 23 | tvalue = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32) 24 | rvalue = xm.mesh_reduce('test_mp_mesh_reduce._test_tensor', tvalue, 25 | reduce_add) 26 | assert rvalue.allclose(tvalue * xr.world_size()) 27 | 28 | 29 | def _mp_fn(index): 30 | _test_scalar() 31 | _test_tensor() 32 | 33 | 34 | if __name__ == '__main__': 35 | torch_xla.launch(_mp_fn, args=()) 36 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/embedding_bag.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class EmbeddingBag : public XlaNode { 11 | public: 12 | EmbeddingBag(const torch::lazy::Value& weight, 13 | const torch::lazy::Value& indices, 14 | const torch::lazy::Value& offsets, int64_t mode, 15 | const torch::lazy::Value& per_sample_weights, 16 | bool include_last_offset); 17 | 18 | std::string ToString() const override; 19 | 20 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 21 | 22 | XlaOpVector Lower(LoweringContext* loctx) const override; 23 | 24 | private: 25 | int64_t mode_; 26 | bool include_last_offset_; 27 | }; 28 | 29 | } // namespace torch_xla 30 | 31 | #endif // XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/shape_builder.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_SHAPE_BUILDER_H_ 2 | #define XLA_TORCH_XLA_CSRC_SHAPE_BUILDER_H_ 3 | 4 | #include 5 | 6 | #include "absl/types/span.h" 7 | #include "xla/shape.h" 8 | #include "xla/types.h" 9 | 10 | namespace torch_xla { 11 | 12 | class ShapeBuilder { 13 | public: 14 | explicit ShapeBuilder(xla::PrimitiveType type) : type_(type) {} 15 | 16 | ShapeBuilder& Add(const xla::Shape& shape, int64_t dim); 17 | 18 | ShapeBuilder& Add(const xla::Shape& shape, 19 | absl::Span<const int64_t> dimensions); 20 | 21 | ShapeBuilder& Add(int64_t size); 22 | 23 | xla::Shape Build() const; 24 | 25 | private: 26 | struct ShapeDim {
27 | const xla::Shape* shape = nullptr; 28 | int64_t dim_or_size = -1; 29 | }; 30 | 31 | xla::PrimitiveType type_; 32 | std::vector<ShapeDim> dims_; 33 | }; 34 | 35 | } // namespace torch_xla 36 | 37 | #endif // XLA_TORCH_XLA_CSRC_SHAPE_BUILDER_H_ -------------------------------------------------------------------------------- /examples/eager/train_decoder_only_eager_multi_process.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))) 4 | sys.path.append(example_folder) 5 | from train_decoder_only_base import TrainDecoderOnlyBase 6 | 7 | import torch_xla 8 | import torch_xla.core.xla_model as xm 9 | 10 | 11 | class TrainDecoderXLADDP(TrainDecoderOnlyBase): 12 | 13 | def __init__(self): 14 | super().__init__() 15 | # We want to run the step fn eagerly. 16 | self.compiled_step_fn = self.step_fn 17 | 18 | def run_optimizer(self): 19 | # optimizer_step will call `optimizer.step()` and all_reduce the gradient 20 | xm.optimizer_step(self.optimizer) 21 | 22 | 23 | def _mp_fn(index): 24 | import torch_xla 25 | torch_xla.experimental.eager_mode(True) 26 | xla_ddp = TrainDecoderXLADDP() 27 | xla_ddp.start_training() 28 | 29 | 30 | if __name__ == '__main__': 31 | torch_xla.launch(_mp_fn, args=()) 32 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/infer_output_shape.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_INFER_OUTPUT_SHAPE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_INFER_OUTPUT_SHAPE_H_ 3 | 4 | #include "absl/types/span.h" 5 | #include "xla/hlo/builder/xla_builder.h" 6 | 7 | namespace torch_xla { 8 | 9 | using LowerForShapeFn = 10 | std::function<xla::XlaOp(absl::Span<const xla::XlaOp> operands)>; 11 | using LowerForShapesFn = std::function<std::vector<xla::XlaOp>( 12 | absl::Span<const xla::XlaOp> operands)>; 13 | 14 | // Compute the output shape for the given input shapes and lowering. 15 | xla::Shape InferOutputShape(absl::Span<const xla::Shape> input_shapes, 16 | const LowerForShapeFn& core_lowering_fn); 17 | 18 | xla::Shape InferOutputShapes(absl::Span<const xla::Shape> input_shapes, 19 | const LowerForShapesFn& core_lowering_fn); 20 | 21 | } // namespace torch_xla 22 | 23 | #endif // XLA_TORCH_XLA_CSRC_OPS_INFER_OUTPUT_SHAPE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/runtime/runtime.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_CLIENT_RUNTIME_H_ 2 | #define XLA_CLIENT_RUNTIME_H_ 3 | 4 | #include "absl/base/attributes.h" 5 | #include "absl/status/statusor.h" 6 | 7 | #include "torch_xla/csrc/runtime/computation_client.h" 8 | 9 | namespace torch_xla::runtime { 10 | 11 | // Returns the ComputationClient singleton. 12 | const absl::StatusOr<ComputationClient*>& GetComputationClient(); 13 | 14 | // Returns the ComputationClient singleton if it was successfully initialized. 15 | // Returns a nullptr if the ComputationClient wasn't initialized yet. 16 | // Throws an exception if the ComputationClient was initialized but the 17 | // initialization failed. 18 | ComputationClient* GetComputationClientIfInitialized(); 19 | 20 | // Runs the XRT local service. This will block the caller until the server 21 | // is stopped.
22 | void RunLocalService(uint64_t service_port); 23 | 24 | } // namespace torch_xla::runtime 25 | 26 | #endif 27 | -------------------------------------------------------------------------------- /examples/data_parallel/train_resnet_xla_ddp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))) 4 | sys.path.append(example_folder) 5 | from train_resnet_base import TrainResNetBase 6 | 7 | import torch_xla 8 | import torch_xla.core.xla_model as xm 9 | import torch_xla.runtime as xr 10 | 11 | 12 | class TrainResNetXLADDP(TrainResNetBase): 13 | 14 | def run_optimizer(self): 15 | # optimizer_step will call `optimizer.step()` and all_reduce the gradident 16 | xm.optimizer_step(self.optimizer) 17 | 18 | 19 | def _mp_fn(index): 20 | # cache init needs to happens inside the mp_fn. 21 | xr.initialize_cache(f'/tmp/xla_cache_{index}', readonly=False) 22 | xla_ddp = TrainResNetXLADDP() 23 | xla_ddp.start_training() 24 | 25 | 26 | if __name__ == '__main__': 27 | print('consider using train_resnet_spmd_data_parallel.py instead to get better performance') 28 | torch_xla.launch(_mp_fn, args=()) 29 | -------------------------------------------------------------------------------- /scripts/dump_stacks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # The following command is needed (as root) in order to enable GDB to attach 4 | # existing user processes: 5 | # 6 | # echo 0 > /proc/sys/kernel/yama/ptrace_scope 7 | # 8 | 9 | import argparse 10 | import stack_trace_parse as stp 11 | import subprocess 12 | 13 | 14 | def get_stacks(pid): 15 | return subprocess.check_output([ 16 | 'gdb', '-p', 17 | str(pid), '-batch', '-ex', 'thread apply all bt', '-ex', 'quit' 18 | ]).decode('utf-8') 19 | 20 | 21 | def dump_stacks(args): 22 | stacks = get_stacks(args.pid) 23 | stp.process_stack_lines(stacks.splitlines(True), args) 24 | 25 | 26 | if __name__ == '__main__': 27 | arg_parser = argparse.ArgumentParser() 28 | arg_parser.add_argument( 29 | 'pid', 30 | type=int, 31 | metavar='PID', 32 | help='The process ID whose stacks need to be dumped') 33 | args, files = arg_parser.parse_known_args() 34 | dump_stacks(args) 35 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/logsumexp.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_LOGSUMEXP_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_LOGSUMEXP_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class Logsumexp : public XlaNode { 11 | public: 12 | Logsumexp(const torch::lazy::Value& input, std::vector dimensions, 13 | bool keep_reduced_dimensions); 14 | 15 | std::string ToString() const override; 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | const std::vector& dimensions() const { return dimensions_; } 22 | 23 | bool keep_reduced_dimensions() const { return keep_reduced_dimensions_; } 24 | 25 | private: 26 | std::vector dimensions_; 27 | bool keep_reduced_dimensions_; 28 | }; 29 | 30 | } // namespace torch_xla 31 | 32 | #endif // XLA_TORCH_XLA_CSRC_OPS_LOGSUMEXP_H_ 33 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/mse_loss_backward.h: 
-------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MSE_LOSS_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MSE_LOSS_BACKWARD_H_ 3 | 4 | #include "xla/types.h" 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | #include "torch_xla/csrc/reduction.h" 8 | 9 | namespace torch_xla { 10 | 11 | class MseLossBackward : public XlaNode { 12 | public: 13 | MseLossBackward(const torch::lazy::Value& grad_output, 14 | const torch::lazy::Value& input, 15 | const torch::lazy::Value& target, ReductionMode reduction); 16 | 17 | std::string ToString() const override; 18 | 19 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 20 | 21 | XlaOpVector Lower(LoweringContext* loctx) const override; 22 | 23 | ReductionMode reduction() const { return reduction_; } 24 | 25 | private: 26 | ReductionMode reduction_; 27 | }; 28 | 29 | } // namespace torch_xla 30 | 31 | #endif // XLA_TORCH_XLA_CSRC_OPS_MSE_LOSS_BACKWARD_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/permute.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_PERMUTE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_PERMUTE_H_ 3 | 4 | #include "absl/types/span.h" 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class Permute : public XlaNode { 11 | public: 12 | Permute(const torch::lazy::Value& input, std::vector dims); 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | std::string ToString() const override; 19 | 20 | const std::vector& dims() const { return dims_; } 21 | 22 | static xla::Shape MakePermuteShape(const xla::Shape& source_shape, 23 | absl::Span permutation); 24 | 25 | private: 26 | // The permutation of dimensions. 
27 | std::vector dims_; 28 | }; 29 | 30 | } // namespace torch_xla 31 | 32 | #endif // XLA_TORCH_XLA_CSRC_OPS_PERMUTE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/replication_pad_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_REPLICATION_PAD_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_REPLICATION_PAD_BACKWARD_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class ReplicationPadBackward : public XlaNode { 11 | public: 12 | ReplicationPadBackward(const torch::lazy::Value& gard_output, 13 | const torch::lazy::Value& input, 14 | std::vector padding); 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | std::string ToString() const override; 21 | 22 | const std::vector& padding() const { return padding_; } 23 | 24 | private: 25 | std::vector padding_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_REPLICATION_PAD_BACKWARD_H_ 31 | -------------------------------------------------------------------------------- /torch_xla/debug/graph_saver.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import threading 4 | import torch_xla 5 | 6 | _SAVE_GRAPH_LOCK = threading.Lock() 7 | _SAVE_GRAPH_IDS = collections.defaultdict(dict) 8 | 9 | 10 | def save_tensors_graph(save_dir, name, tensors): 11 | fmt = os.environ.get('SAVE_GRAPH_FMT', 'text') 12 | if fmt == 'text': 13 | graph = torch_xla._XLAC._get_xla_tensors_text(tensors) 14 | elif fmt == 'dot': 15 | graph = torch_xla._XLAC._get_xla_tensors_dot(tensors) 16 | elif fmt == 'hlo': 17 | graph = torch_xla._XLAC._get_xla_tensors_hlo(tensors) 18 | else: 19 | raise RuntimeError('Invalid save graph format: {}'.format(fmt)) 20 | tid = threading.current_thread().ident 21 | with _SAVE_GRAPH_LOCK: 22 | tdict = _SAVE_GRAPH_IDS[tid] 23 | id = tdict.get(name, 0) 24 | tdict[name] = id + 1 25 | fname = '{}-{}-{}.{}'.format(name, tid, id, fmt) 26 | with open(os.path.join(save_dir, fname), 'w') as fd: 27 | fd.write(graph) 28 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/reflection_pad2d_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_REFLECTION_PAD2D_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_REFLECTION_PAD2D_BACKWARD_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class ReflectionPad2dBackward : public XlaNode { 11 | public: 12 | ReflectionPad2dBackward(const torch::lazy::Value& gard_output, 13 | const torch::lazy::Value& input, 14 | std::vector padding); 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | std::string ToString() const override; 21 | 22 | const std::vector& padding() const { return padding_; } 23 | 24 | private: 25 | std::vector padding_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_REFLECTION_PAD2D_BACKWARD_H_ 31 | -------------------------------------------------------------------------------- /benchmarks/patches/mismatched_batch_size.patch: 
-------------------------------------------------------------------------------- 1 | diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py 2 | index 8593ba4c..57fef507 100644 3 | --- a/torchbenchmark/util/model.py 4 | +++ b/torchbenchmark/util/model.py 5 | @@ -182,6 +182,8 @@ class BenchmarkModel(metaclass=PostInitProcessor): 6 | 7 | # use the device suggestion on CUDA inference tests, key should be either eval_batch_size or train_batch_size 8 | device_batch_size_key = f"{self.test}_batch_size" 9 | + # A patch to making sure batch sizes are comparable. It's needed because xla device string is unrecognized. 10 | + current_device_name = 'NVIDIA A100-SXM4-40GB' 11 | if self.metadata and "devices" in self.metadata and current_device_name in self.metadata["devices"] \ 12 | and device_batch_size_key in self.metadata["devices"][current_device_name]: 13 | batch_size = self.metadata["devices"][current_device_name][device_batch_size_key] 14 | -------------------------------------------------------------------------------- /plugins/cpu/test_cpu_plugin.cpp: -------------------------------------------------------------------------------- 1 | #include "plugins/cpu/test_cpu_plugin.h" 2 | 3 | #include 4 | 5 | #include "xla/pjrt/c/pjrt_c_api.h" 6 | #include "xla/pjrt/c/pjrt_c_api_cpu_internal.h" 7 | 8 | // Use `test` as the platform name instead of `cpu` so torch_xla treats this 9 | // as an unknown device. 10 | PJRT_Error* test_platform_name(PJRT_Client_PlatformName_Args* args) { 11 | static const std::string platform_name = "test"; 12 | args->platform_name = platform_name.c_str(); 13 | args->platform_name_size = platform_name.size(); 14 | return nullptr; 15 | } 16 | 17 | const PJRT_Api* GetPjrtApi() { 18 | // HACK: The CPU client is created as a constexpr, so const-casting is 19 | // undefined behavior. Make a non-const copy of the struct so we can override 20 | // methods. Don't do this for a real plugin. 
21 | static PJRT_Api pjrt_api = *pjrt::cpu_plugin::GetCpuPjrtApi(); 22 | pjrt_api.PJRT_Client_PlatformName = test_platform_name; 23 | 24 | return &pjrt_api; 25 | } 26 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.speedup.tab.test: -------------------------------------------------------------------------------- 1 | # ARGS: --format=tab 2 | ╒════════════════════════════╤════════════╤══════════╤══════════════╤══════════╕ 3 | │ Datetime(UTC) │ Speedup │ StdDev │ Speedup │ StdDev │ 4 | │ │ Inductor │ │ XLA+Dynamo │ │ 5 | │ │ over │ │ over │ │ 6 | │ │ Oldest │ │ Oldest │ │ 7 | │ │ Inductor │ │ Inductor │ │ 8 | ╞════════════════════════════╪════════════╪══════════╪══════════════╪══════════╡ 9 | │ 2023-11-11 05:32:18.723407 │ 1.00 │ 0.03 │ 0.85 │ 0.02 │ 10 | ├────────────────────────────┼────────────┼──────────┼──────────────┼──────────┤ 11 | │ 2023-11-12 05:32:18 │ 1.40 │ 0.03 │ 1.10 │ 0.02 │ 12 | ╘════════════════════════════╧════════════╧══════════╧══════════════╧══════════╛ 13 | -------------------------------------------------------------------------------- /scripts/metrics_to_tensorboard.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Add metrics images to tensorboard summary for easy viewing/comparisons 3 | 4 | import argparse 5 | import os 6 | 7 | from PIL import Image 8 | from pathlib import Path 9 | 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torchvision.transforms import ToTensor 12 | 13 | 14 | def generate_tensorboard_img_summary(logdir, imgdir): 15 | writer = SummaryWriter(logdir) 16 | all_metrics_graphs = list(Path(imgdir).rglob("*.png")) 17 | 18 | for img in all_metrics_graphs: 19 | tag = os.path.basename(img) 20 | img_tensor = ToTensor()(Image.open(img)) 21 | writer.add_image(tag, img_tensor, 0) 22 | writer.close() 23 | 24 | 25 | if __name__ == '__main__': 26 | arg_parser = argparse.ArgumentParser() 27 | arg_parser.add_argument('--logdir', type=str) 28 | arg_parser.add_argument('--imgdir', type=str) 29 | args = arg_parser.parse_args() 30 | generate_tensorboard_img_summary(args.logdir, args.imgdir) 31 | -------------------------------------------------------------------------------- /test/benchmarks/test_benchmark_experiment.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from benchmark_experiment import BenchmarkExperiment 4 | 5 | 6 | class BenchmarkExperimentTest(unittest.TestCase): 7 | 8 | def test_to_dict(self): 9 | be = BenchmarkExperiment("cpu", "PJRT", "some xla_flags", "openxla", None, 10 | "train", "123", False) 11 | actual = be.to_dict() 12 | self.assertEqual(9, len(actual)) 13 | self.assertEqual("cpu", actual["accelerator"]) 14 | self.assertTrue("accelerator_model" in actual) 15 | self.assertEqual("PJRT", actual["xla"]) 16 | self.assertEqual("some xla_flags", actual["xla_flags"]) 17 | self.assertEqual("openxla", actual["dynamo"]) 18 | self.assertEqual(None, actual["torch_xla2"]) 19 | self.assertEqual("train", actual["test"]) 20 | self.assertEqual("123", actual["batch_size"]) 21 | self.assertEqual(False, actual["enable_functionalization"]) 22 | 23 | 24 | if __name__ == '__main__': 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/quant_tensor.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_QUANT_TENSOR_H_ 2 
| #define XLA_TORCH_XLA_CSRC_OPS_QUANT_TENSOR_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class QuantizeTensor : public XlaNode { 9 | public: 10 | QuantizeTensor(const torch::lazy::Value& input, 11 | const std::vector& scale, 12 | const std::vector& zero_point, int quant_min, 13 | int quant_max, const std::string& dtype, int axis); 14 | 15 | std::string ToString() const override; 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | private: 22 | int quant_min_; 23 | int quant_max_; 24 | int axis_; 25 | std::string dtype_; 26 | std::vector scale_; 27 | std::vector zero_point_; 28 | }; 29 | 30 | } // namespace torch_xla 31 | 32 | #endif // XLA_TORCH_XLA_CSRC_OPS_QUANT_TENSOR_H_ 33 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/upsample_bilinear2d.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_UPSAMPLE_BILINEAR2D_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_UPSAMPLE_BILINEAR2D_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class UpsampleBilinear : public XlaNode { 11 | public: 12 | UpsampleBilinear(const torch::lazy::Value& input, 13 | std::vector output_size, bool align_corners); 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | std::string ToString() const override; 20 | 21 | const std::vector& output_size() const { return output_size_; } 22 | 23 | bool align_corners() const { return align_corners_; } 24 | 25 | private: 26 | std::vector output_size_; 27 | bool align_corners_; 28 | }; 29 | 30 | } // namespace torch_xla 31 | 32 | #endif // XLA_TORCH_XLA_CSRC_OPS_UPSAMPLE_BILINEAR2D_H_ 33 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.histogram.tab.test: -------------------------------------------------------------------------------- 1 | # ARGS: --format=tab 2 | ╒════════════════════════════╤════════════╤════════════╤════════════╤══════════════╤══════════════╤══════════════╕ 3 | │ Datetime(UTC) │ Inductor │ Inductor │ Inductor │ XLA+Dynamo │ XLA+Dynamo │ XLA+Dynamo │ 4 | │ │ p95 │ p50 │ p5 │ p95 │ p50 │ p5 │ 5 | ╞════════════════════════════╪════════════╪════════════╪════════════╪══════════════╪══════════════╪══════════════╡ 6 | │ 2023-11-11 05:32:18.723407 │ 1.00 │ 1.00 │ 1.00 │ 0.98 │ 0.86 │ 0.74 │ 7 | ├────────────────────────────┼────────────┼────────────┼────────────┼──────────────┼──────────────┼──────────────┤ 8 | │ 2023-11-12 05:32:18 │ 1.51 │ 1.41 │ 1.31 │ 1.53 │ 1.17 │ 0.81 │ 9 | ╘════════════════════════════╧════════════╧════════════╧════════════╧══════════════╧══════════════╧══════════════╛ 10 | -------------------------------------------------------------------------------- /test/pjrt/test_dynamic_plugin_tpu.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | 3 | from absl.testing import absltest 4 | import torch_xla 5 | import torch_xla.core.xla_model as xm 6 | from torch_xla.experimental import plugins 7 | import torch_xla.runtime as xr 8 | from torch_xla._internal import tpu 9 | 10 | plugins.register_plugin('TPU', tpu.TpuPlugin()) 11 | plugins.use_dynamic_plugins() 12 | 13 | 14 | class TestDynamicTpuPlugin(absltest.TestCase): 
15 | 16 | @classmethod 17 | def setUpClass(cls): 18 | xr.set_device_type('TPU') 19 | 20 | @staticmethod 21 | def _assert_tpus_exist(index=0): 22 | del index 23 | assert xm.xla_device_hw(torch_xla.device()) == 'TPU' 24 | 25 | def test_single_process(self): 26 | with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: 27 | executor.submit(self._assert_tpus_exist).result() 28 | 29 | def test_spawn(self): 30 | torch_xla.launch(self._assert_tpus_exist) 31 | 32 | 33 | if __name__ == '__main__': 34 | absltest.main() 35 | -------------------------------------------------------------------------------- /test/test_torch_distributed_fsdp_frozen_weight.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | import torch_xla 5 | import torch_xla.core.xla_model as xm 6 | from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP 7 | 8 | 9 | def _mp_fn(index): 10 | dev = torch_xla.device() 11 | if xm.xla_device_hw(dev) not in ('TPU',): 12 | print('Default device {} is not a TPU device'.format(dev), file=sys.stderr) 13 | return 14 | 15 | model = nn.Linear(1024, 1024) 16 | model.weight.requires_grad = False # the weight param is frozen 17 | 18 | model = FSDP(model) # wrapping the linear module with FSDP 19 | 20 | input = torch.rand((2, 1024), device='xla') 21 | 22 | output = model(input) 23 | loss = torch.sum(output) 24 | loss.backward() 25 | assert not any(p._has_full_param for p in model.full_params), \ 26 | 'Expecting all the full params to be freed at this moment.' 27 | 28 | 29 | if __name__ == "__main__": 30 | torch_xla.launch(_mp_fn, args=()) 31 | -------------------------------------------------------------------------------- /torch_xla/csrc/nll_loss.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_NLL_LOSS_H_ 2 | #define XLA_TORCH_XLA_CSRC_NLL_LOSS_H_ 3 | 4 | #include "absl/types/optional.h" 5 | #include "xla/hlo/builder/xla_builder.h" 6 | 7 | #include "torch_xla/csrc/reduction.h" 8 | 9 | namespace torch_xla { 10 | 11 | // Builds the NLLLoss for log-probabilities "logits" and class indices "labels". 12 | xla::XlaOp BuildNllLoss(xla::XlaOp logits, xla::XlaOp labels, xla::XlaOp weight, 13 | int ignore_index, ReductionMode reduction_mode); 14 | 15 | // Builds the NLLLoss gradient for log-probabilities "logits" and class indices 16 | // "labels". 
17 | xla::XlaOp BuildNllLossBackward(xla::XlaOp grad_output, xla::XlaOp logits, 18 | xla::XlaOp labels, xla::XlaOp weight, 19 | xla::XlaOp total_weight, int ignore_index, 20 | ReductionMode reduction_mode); 21 | 22 | } // namespace torch_xla 23 | 24 | #endif // XLA_TORCH_XLA_CSRC_NLL_LOSS_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/user_computation.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_USER_COMPUTATION_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_USER_COMPUTATION_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | #include "torch_xla/csrc/runtime/computation_client.h" 6 | 7 | namespace torch_xla { 8 | 9 | class UserComputation : public XlaNode { 10 | public: 11 | UserComputation(torch::lazy::OpKind op, torch::lazy::OpList operands, 12 | runtime::ComputationClient::ComputationPtr computation); 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | std::string ToString() const override; 19 | 20 | const runtime::ComputationClient::ComputationPtr& computation() const { 21 | return computation_; 22 | } 23 | 24 | private: 25 | runtime::ComputationClient::ComputationPtr computation_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_USER_COMPUTATION_H_ 31 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/unselect.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_UNSELECT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_UNSELECT_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Unselect : public XlaNode { 9 | public: 10 | Unselect(const torch::lazy::Value& target, const torch::lazy::Value& source, 11 | int64_t dim, int64_t start, int64_t end, int64_t stride); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | int64_t dim() const { return dim_; } 20 | 21 | int64_t start() const { return start_; } 22 | 23 | int64_t end() const { return end_; } 24 | 25 | int64_t stride() const { return stride_; } 26 | 27 | private: 28 | int64_t dim_; 29 | int64_t start_; 30 | int64_t end_; 31 | int64_t stride_; 32 | }; 33 | 34 | } // namespace torch_xla 35 | 36 | #endif // XLA_TORCH_XLA_CSRC_OPS_UNSELECT_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/hardtanh_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_HARDTANH_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_HARDTANH_BACKWARD_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class HardtanhBackward : public XlaNode { 11 | public: 12 | HardtanhBackward(const torch::lazy::Value& grad_output, 13 | const torch::lazy::Value& input, const at::Scalar& min_val, 14 | const at::Scalar& max_val); 15 | 16 | std::string ToString() const override; 17 | 18 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 19 | 20 | XlaOpVector Lower(LoweringContext* loctx) const override; 21 | 22 | at::Scalar min_val() const { return min_val_; } 23 | 24 | at::Scalar max_val() const { return max_val_; } 25 | 26 | private: 27 | 
at::Scalar min_val_; 28 | at::Scalar max_val_; 29 | }; 30 | 31 | } // namespace torch_xla 32 | 33 | #endif // XLA_TORCH_XLA_CSRC_OPS_HARDTANH_BACKWARD_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/dequant_tensor.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_DEQUANT_TENSOR_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_DEQUANT_TENSOR_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class DequantizeTensor : public XlaNode { 9 | public: 10 | DequantizeTensor(const torch::lazy::Value& input, 11 | const std::vector& scale, 12 | const std::vector& zero_point, int quant_min, 13 | int quant_max, const std::string& dtype, int axis); 14 | 15 | std::string ToString() const override; 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | private: 22 | int quant_min_; 23 | int quant_max_; 24 | int axis_; 25 | std::string dtype_; 26 | std::vector scale_; 27 | std::vector zero_point_; 28 | }; 29 | 30 | } // namespace torch_xla 31 | 32 | #endif // XLA_TORCH_XLA_CSRC_OPS_QUANT_TENSOR_H_ 33 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/topk.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_TOPK_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_TOPK_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class TopK : public XlaNode { 9 | public: 10 | TopK(const torch::lazy::Value& input, int64_t k, int64_t dim, bool largest, 11 | bool sorted, bool stable); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t k() const { return k_; }; 20 | 21 | int64_t dim() const { return dim_; }; 22 | 23 | bool largest() const { return largest_; } 24 | 25 | bool sorted() const { return sorted_; } 26 | 27 | bool stable() const { return stable_; } 28 | 29 | private: 30 | int64_t k_; 31 | int64_t dim_; 32 | bool largest_; 33 | bool sorted_; 34 | bool stable_; 35 | }; 36 | 37 | } // namespace torch_xla 38 | 39 | #endif // XLA_TORCH_XLA_CSRC_OPS_TOPK_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/generic_slice.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_GENERIC_SLICE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_GENERIC_SLICE_H_ 3 | 4 | #include "absl/types/span.h" 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class GenericSlice : public XlaNode { 11 | public: 12 | GenericSlice(const torch::lazy::Value& input, 13 | absl::Span base_indices, 14 | absl::Span sizes); 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | std::string ToString() const override; 21 | 22 | const std::vector& base_indices() const { return base_indices_; } 23 | 24 | const std::vector& sizes() const { return sizes_; } 25 | 26 | private: 27 | std::vector base_indices_; 28 | std::vector sizes_; 29 | }; 30 | 31 | } // namespace torch_xla 32 | 33 | #endif // XLA_TORCH_XLA_CSRC_OPS_GENERIC_SLICE_H_ 
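The op headers above (TopK, GenericSlice, and the other XlaNode subclasses) are only built indirectly, when PyTorch ops run on a lazy XLA tensor. The following is a minimal illustrative sketch, not a file from this repository, of how one of these nodes can be observed in the pending graph; it reuses the `_get_xla_tensors_text` dump hook that torch_xla/debug/graph_saver.py calls, and the exact node names in the printout depend on the lowering.

import torch
import torch_xla

# torch.topk on a lazy XLA tensor records an IR node (see topk.h above) instead of
# computing immediately; the node carries k/dim/largest/sorted as metadata.
t = torch.randn(4, 16).to('xla')
values, indices = torch.topk(t, k=3, dim=1, largest=True, sorted=True)

# Dump the pending lazy IR before it is compiled and executed.
print(torch_xla._XLAC._get_xla_tensors_text([values, indices]))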
-------------------------------------------------------------------------------- /torch_xla/csrc/ops/index_put.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_INDEX_PUT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_INDEX_PUT_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class IndexPut : public XlaNode { 9 | public: 10 | IndexPut(const torch::lazy::Value& base, const torch::lazy::Value& indices, 11 | int64_t start_dim, const torch::lazy::Value& values, 12 | bool accumulate); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | int64_t start_dim() const { return start_dim_; } 21 | 22 | bool accumulate() const { return accumulate_; } 23 | 24 | private: 25 | // The dimension number at which indexing starts. 26 | int64_t start_dim_; 27 | // Whether to accumulate instead of set. 28 | bool accumulate_; 29 | }; 30 | 31 | } // namespace torch_xla 32 | 33 | #endif // XLA_TORCH_XLA_CSRC_OPS_INDEX_PUT_H_ -------------------------------------------------------------------------------- /.devcontainer/tpu-internal/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tpu-internal", 3 | "image": "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu", 4 | "runArgs": [ 5 | "--privileged", 6 | "--net=host", 7 | "--shm-size=16G" 8 | ], 9 | "containerEnv": { 10 | "BAZEL_REMOTE_CACHE": "1", 11 | "SILO_NAME": "cache-silo-${localEnv:USER}-tpuvm-312" 12 | }, 13 | "initializeCommand": "docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu", 14 | "customizations": { 15 | "vscode": { 16 | "extensions": [ 17 | "llvm-vs-code-extensions.vscode-clangd", 18 | "ms-vscode.cpptools-themes", 19 | "BazelBuild.vscode-bazel", 20 | "StackBuild.bazel-stack-vscode", 21 | "StackBuild.bazel-stack-vscode-cc", 22 | "xaver.clang-format", 23 | "ryanluker.vscode-coverage-gutters", 24 | "ms-azuretools.vscode-docker", 25 | "ms-python.python", 26 | "eeyore.yapf" 27 | ] 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/diagonal_view_update.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_DIAGONAL_VIEW_UPDATE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_DIAGONAL_VIEW_UPDATE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class DiagonalViewUpdate : public XlaNode { 9 | public: 10 | DiagonalViewUpdate(const torch::lazy::Value& target, 11 | const torch::lazy::Value& input, int64_t offset, 12 | int64_t dim1, int64_t dim2); 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | std::string ToString() const override; 19 | 20 | int64_t offset() const { return offset_; } 21 | 22 | int64_t dim1() const { return dim1_; } 23 | 24 | int64_t dim2() const { return dim2_; } 25 | 26 | private: 27 | int64_t offset_; 28 | int64_t dim1_; 29 | int64_t dim2_; 30 | }; 31 | 32 | } // namespace torch_xla 33 | 34 | #endif // XLA_TORCH_XLA_CSRC_OPS_DIAGONAL_VIEW_UPDATE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/xla_op_builder.h: 
-------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_XLA_OP_BUILDER_H_ 2 | #define XLA_TORCH_XLA_CSRC_XLA_OP_BUILDER_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | #include "xla/hlo/builder/xla_builder.h" 11 | 12 | namespace torch_xla { 13 | namespace op_builder { 14 | 15 | using BuilderPtr = std::shared_ptr; 16 | 17 | struct Op { 18 | Op(BuilderPtr builder, xla::XlaOp op) 19 | : builder(std::move(builder)), op(std::move(op)) {} 20 | 21 | BuilderPtr builder; 22 | xla::XlaOp op; 23 | }; 24 | 25 | using OpPtr = std::shared_ptr; 26 | 27 | py::object ShapeToPyShape(const xla::Shape& shape); 28 | 29 | xla::Shape PyShapeToShape(py::object shape); 30 | 31 | OpPtr CreateOp(BuilderPtr builder, const std::string& opname, 32 | const std::vector& operands, py::dict args); 33 | 34 | } // namespace op_builder 35 | } // namespace torch_xla 36 | 37 | #endif // XLA_TORCH_XLA_CSRC_XLA_OP_BUILDER_H_ -------------------------------------------------------------------------------- /scripts/run_bazel_coverage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0" 3 | export XRT_WORKERS="localservice:0;grpc://localhost:40934" 4 | export XLA_EXPERIMENTAL="nonzero:masked_select" 5 | 6 | BAZEL_REMOTE_CACHE_CONFIG="--config=remote_cache --remote_default_exec_properties=cache-silo-key=cache-silo-coverage" 7 | if [ ! -z "$GCLOUD_SERVICE_KEY_FILE" ]; then 8 | file_size=$(stat -c%s "$GCLOUD_SERVICE_KEY_FILE") 9 | if [ "$file_size" -le 1 ]; then 10 | BAZEL_REMOTE_CACHE_CONFIG="" 11 | fi 12 | fi 13 | 14 | bazel coverage $BAZEL_REMOTE_CACHE_CONFIG //... 15 | cp "$(bazel info output_path)/_coverage/_coverage_report.dat" /tmp/cov_xrt.dat 16 | 17 | export PJRT_DEVICE="CPU" 18 | bazel coverage $BAZEL_REMOTE_CACHE_CONFIG //test/... 
19 | cp "$(bazel info output_path)/_coverage/_coverage_report.dat" /tmp/cov_pjrt.dat 20 | 21 | # requires `apt-get install lcov` 22 | lcov --add-tracefile /tmp/cov_xrt.dat -a /tmp/cov_pjrt.dat -o /tmp/merged.dat 23 | genhtml /tmp/merged.dat -o CodeCoveragn m -------------------------------------------------------------------------------- /torch_xla/csrc/ops/collective_permute.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_COLLECTIVE_PERMUTE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_COLLECTIVE_PERMUTE_H_ 3 | 4 | #include "torch_xla/csrc/cross_replica_reduces.h" 5 | #include "torch_xla/csrc/ir.h" 6 | 7 | namespace torch_xla { 8 | 9 | class CollectivePermute : public XlaNode { 10 | public: 11 | CollectivePermute( 12 | const torch::lazy::Value& input, const torch::lazy::Value& token, 13 | std::vector> source_target_pairs); 14 | 15 | std::string ToString() const override; 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | const std::vector>& source_target_pairs() const { 22 | return source_target_pairs_; 23 | } 24 | 25 | private: 26 | std::vector> source_target_pairs_; 27 | }; 28 | 29 | } // namespace torch_xla 30 | 31 | #endif // XLA_TORCH_XLA_CSRC_OPS_COLLECTIVE_PERMUTE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/log_softmax.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_LOG_SOFTMAX_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_LOG_SOFTMAX_H_ 3 | 4 | #include 5 | 6 | #include 7 | 8 | #include "torch_xla/csrc/ir.h" 9 | 10 | namespace torch_xla { 11 | 12 | // IR node for log(softmax) operation. 13 | class LogSoftmax : public XlaNode { 14 | public: 15 | LogSoftmax(const torch::lazy::Value& input, int64_t dim, 16 | std::optional dtype, 17 | std::vector&& shapes); 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | std::string ToString() const override; 22 | 23 | int64_t dim() const { return dim_; } 24 | 25 | const std::optional& dtype() const { return dtype_; } 26 | 27 | private: 28 | // The dimension along which the result is computed. 
29 | int64_t dim_; 30 | std::optional dtype_; 31 | }; 32 | 33 | } // namespace torch_xla 34 | 35 | #endif // XLA_TORCH_XLA_CSRC_OPS_LOG_SOFTMAX_H_ 36 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/nll_loss.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_NLL_LOSS_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_NLL_LOSS_H_ 3 | 4 | #include "absl/types/optional.h" 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | #include "torch_xla/csrc/reduction.h" 8 | 9 | namespace torch_xla { 10 | 11 | class NllLoss : public XlaNode { 12 | public: 13 | NllLoss(const torch::lazy::Value& logits, const torch::lazy::Value& labels, 14 | const absl::optional& weight, 15 | ReductionMode reduction, int ignore_index); 16 | 17 | std::string ToString() const override; 18 | 19 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 20 | 21 | XlaOpVector Lower(LoweringContext* loctx) const override; 22 | 23 | ReductionMode reduction() const { return reduction_; } 24 | 25 | int ignore_index() const { return ignore_index_; } 26 | 27 | private: 28 | ReductionMode reduction_; 29 | int ignore_index_; 30 | }; 31 | 32 | } // namespace torch_xla 33 | 34 | #endif // XLA_TORCH_XLA_CSRC_OPS_NLL_LOSS_H_ 35 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/not_supported.cpp: -------------------------------------------------------------------------------- 1 | #include "torch_xla/csrc/ops/not_supported.h" 2 | 3 | #include "torch_xla/csrc/lowering_context.h" 4 | #include "torch_xla/csrc/ops/xla_ops.h" 5 | #include "torch_xla/csrc/runtime/debug_macros.h" 6 | 7 | namespace torch_xla { 8 | 9 | NotSupported::NotSupported(std::string description, xla::Shape shape) 10 | : XlaNode(xla_not_supported, std::move(shape), /*num_outputs=*/1, 11 | torch::lazy::MHash(description)), 12 | description_(std::move(description)) {} 13 | 14 | torch::lazy::NodePtr NotSupported::Clone(torch::lazy::OpList operands) const { 15 | return torch_xla::MakeNode(description_, xla_shape()); 16 | } 17 | 18 | XlaOpVector NotSupported::Lower(LoweringContext* /* loctx */) const { 19 | XLA_ERROR() << "Node not supported: " << ToString(); 20 | } 21 | 22 | std::string NotSupported::ToString() const { 23 | std::stringstream ss; 24 | ss << XlaNode::ToString() << ", description=" << description_; 25 | return ss.str(); 26 | } 27 | 28 | } // namespace torch_xla 29 | -------------------------------------------------------------------------------- /torch_xla/csrc/runtime/pjrt_registry.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_ 2 | #define XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_ 3 | 4 | #include "xla/pjrt/pjrt_client.h" 5 | #include "xla/pjrt/pjrt_common.h" 6 | 7 | #include "torch_xla/csrc/runtime/xla_coordinator.h" 8 | 9 | namespace torch_xla { 10 | namespace runtime { 11 | 12 | class PjRtPlugin { 13 | public: 14 | virtual std::string library_path() const = 0; 15 | 16 | virtual const std::unordered_map 17 | client_create_options() const = 0; 18 | 19 | virtual bool requires_xla_coordinator() const = 0; 20 | }; 21 | 22 | void RegisterPjRtPlugin(std::string name, 23 | std::shared_ptr plugin); 24 | 25 | absl::StatusOr, 26 | std::unique_ptr>> 27 | InitializePjRt(const std::string& device_type); 28 | 29 | } // namespace runtime 30 | } // namespace torch_xla 31 | 32 | #endif // XLA_CLIENT_INITIALIZE_PJRT_H_ 33 | 
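pjrt_registry.h above is the C++ half of plugin registration (RegisterPjRtPlugin / InitializePjRt); on the Python side, test/pjrt/test_dynamic_plugin_tpu.py registers a device type through torch_xla.experimental.plugins. The sketch below is an illustration rather than repository code: the DevicePlugin base class and its library_path method are assumed to mirror the C++ PjRtPlugin interface shown above, and the .so path is hypothetical.

import torch_xla.runtime as xr
from torch_xla.experimental import plugins

class TestDevicePlugin(plugins.DevicePlugin):

  def library_path(self):
    # Hypothetical path to a PJRT C API plugin library, e.g. one built from
    # plugins/cpu/test_cpu_plugin.cpp earlier in this listing.
    return '/path/to/pjrt_c_api_test_plugin.so'

# Same registration calls as in test/pjrt/test_dynamic_plugin_tpu.py.
plugins.register_plugin('TEST', TestDevicePlugin())
plugins.use_dynamic_plugins()
xr.set_device_type('TEST')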
-------------------------------------------------------------------------------- /torch_xla/csrc/ops/diagonal.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_DIAGONAL_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_DIAGONAL_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Diagonal : public XlaNode { 9 | public: 10 | Diagonal(const torch::lazy::Value& input, int64_t offset, int64_t dim1, 11 | int64_t dim2); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | int64_t offset() const { return offset_; } 20 | 21 | int64_t dim1() const { return dim1_; } 22 | 23 | int64_t dim2() const { return dim2_; } 24 | 25 | static xla::Shape MakeDiagonalShape(const xla::Shape& shape, int64_t offset, 26 | int64_t dim1, int64_t dim2); 27 | 28 | private: 29 | int64_t offset_; 30 | int64_t dim1_; 31 | int64_t dim2_; 32 | }; 33 | 34 | } // namespace torch_xla 35 | 36 | #endif // XLA_TORCH_XLA_CSRC_OPS_DIAGONAL_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/flip.cpp: -------------------------------------------------------------------------------- 1 | #include "torch_xla/csrc/ops/flip.h" 2 | 3 | #include "xla/hlo/builder/xla_builder.h" 4 | 5 | #include "torch_xla/csrc/lowering_context.h" 6 | 7 | namespace torch_xla { 8 | 9 | Flip::Flip(const torch::lazy::Value& input, std::vector dims) 10 | : XlaNode(torch::lazy::OpKind(at::aten::flip), {input}, GetXlaShape(input), 11 | /*num_outputs=*/1, torch::lazy::MHash(dims)), 12 | dims_(std::move(dims)) {} 13 | 14 | torch::lazy::NodePtr Flip::Clone(torch::lazy::OpList operands) const { 15 | return torch_xla::MakeNode(operands.at(0), dims_); 16 | } 17 | 18 | XlaOpVector Flip::Lower(LoweringContext* loctx) const { 19 | xla::XlaOp input = loctx->GetOutputOp(operand(0)); 20 | xla::XlaOp output = xla::Rev(input, dims_); 21 | return ReturnOp(output, loctx); 22 | } 23 | 24 | std::string Flip::ToString() const { 25 | std::stringstream ss; 26 | ss << XlaNode::ToString() << ", dims=(" << absl::StrJoin(dims_, ", ") << ")"; 27 | return ss.str(); 28 | } 29 | 30 | } // namespace torch_xla 31 | -------------------------------------------------------------------------------- /bazel/rules_def.bzl: -------------------------------------------------------------------------------- 1 | """Rules that simplify deps and compiler configuration for PyTorch/XLA.""" 2 | def ptxla_cc_library( 3 | deps = [], 4 | copts = [], 5 | **kwargs): 6 | native.cc_library( 7 | copts = copts + ["-isystemexternal/torch"], # Required for system includes. 8 | deps = deps + [ 9 | "@torch//:headers", 10 | "@torch//:runtime_headers", 11 | ], 12 | **kwargs 13 | ) 14 | 15 | def ptxla_cc_test( 16 | deps, 17 | copts = [], 18 | **kwargs): 19 | native.cc_test( 20 | linkstatic = True, 21 | copts = copts + [ 22 | "-isystemexternal/torch", # Required for system includes. 
23 | ], 24 | deps = deps + [ 25 | "@pybind11//:pybind11_embed", # libpython 26 | "@torch//:headers", 27 | "@torch//:libc10", 28 | "@torch//:libtorch", 29 | "@torch//:libtorch_cpu", 30 | "@torch//:libtorch_python", 31 | ], 32 | **kwargs 33 | ) 34 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/nll_loss2d.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_NLL_LOSS2D_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_NLL_LOSS2D_H_ 3 | 4 | #include "absl/types/optional.h" 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | #include "torch_xla/csrc/reduction.h" 8 | 9 | namespace torch_xla { 10 | 11 | class NllLoss2d : public XlaNode { 12 | public: 13 | NllLoss2d(const torch::lazy::Value& logits, const torch::lazy::Value& labels, 14 | const absl::optional& weight, 15 | ReductionMode reduction, int ignore_index); 16 | 17 | std::string ToString() const override; 18 | 19 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 20 | 21 | XlaOpVector Lower(LoweringContext* loctx) const override; 22 | 23 | ReductionMode reduction() const { return reduction_; } 24 | 25 | int ignore_index() const { return ignore_index_; } 26 | 27 | private: 28 | ReductionMode reduction_; 29 | int ignore_index_; 30 | }; 31 | 32 | } // namespace torch_xla 33 | 34 | #endif // XLA_TORCH_XLA_CSRC_OPS_NLL_LOSS2D_H_ 35 | -------------------------------------------------------------------------------- /contrib/vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "bsv.bazel.buildFlags": [ 3 | "--config=compdb", 4 | ], 5 | "bsv.cc.compdb.targets": [ 6 | "//torch_xla/csrc/runtime:all", 7 | "//torch_xla/csrc:all", 8 | "//test/cpp:all", 9 | ], 10 | "coverage-gutters.coverageBaseDir": ".", 11 | "coverage-gutters.showLineCoverage": false, 12 | "coverage-gutters.showGutterCoverage": true, 13 | "coverage-gutters.coverageReportFileName": "./genhtml/index.html", 14 | "coverage-gutters.coverageFileNames": [ 15 | "./bazel-out/_coverage/_coverage_report.dat" 16 | ], 17 | "git.detectSubmodules": false, 18 | "[python]": { 19 | "editor.defaultFormatter": "eeyore.yapf", 20 | "editor.formatOnSave": true, 21 | }, 22 | "python.analysis.exclude": [ 23 | "**/third_party", 24 | "**/build", 25 | "**/__pycache__", 26 | "**/.git", 27 | ], 28 | "[cpp]": { 29 | "editor.defaultFormatter": "xaver.clang-format", 30 | "editor.formatOnSave": true, 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/std.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_STD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_STD_H_ 3 | 4 | #include 5 | 6 | #include "xla/types.h" 7 | 8 | #include "torch_xla/csrc/ir.h" 9 | 10 | namespace torch_xla { 11 | 12 | class Std : public XlaNode { 13 | public: 14 | Std(const torch::lazy::Value& input, std::vector dimensions, 15 | bool keep_reduced_dimensions, double correction); 16 | 17 | std::string ToString() const override; 18 | 19 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 20 | 21 | XlaOpVector Lower(LoweringContext* loctx) const override; 22 | 23 | const std::vector& dimensions() const { return dimensions_; } 24 | 25 | bool keep_reduced_dimensions() const { return keep_reduced_dimensions_; } 26 | 27 | double correction() const { return correction_; } 28 | 29 | private: 30 | std::vector dimensions_; 31 | bool 
keep_reduced_dimensions_; 32 | double correction_; 33 | }; 34 | 35 | } // namespace torch_xla 36 | 37 | #endif // XLA_TORCH_XLA_CSRC_OPS_STD_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/var.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_VAR_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_VAR_H_ 3 | 4 | #include 5 | 6 | #include "xla/types.h" 7 | 8 | #include "torch_xla/csrc/ir.h" 9 | 10 | namespace torch_xla { 11 | 12 | class Var : public XlaNode { 13 | public: 14 | Var(const torch::lazy::Value& input, std::vector dimensions, 15 | double correction, bool keep_reduced_dimensions); 16 | 17 | std::string ToString() const override; 18 | 19 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 20 | 21 | XlaOpVector Lower(LoweringContext* loctx) const override; 22 | 23 | const std::vector& dimensions() const { return dimensions_; } 24 | 25 | bool keep_reduced_dimensions() const { return keep_reduced_dimensions_; } 26 | 27 | double correction() const { return correction_; } 28 | 29 | private: 30 | std::vector dimensions_; 31 | double correction_; 32 | bool keep_reduced_dimensions_; 33 | }; 34 | 35 | } // namespace torch_xla 36 | 37 | #endif // XLA_TORCH_XLA_CSRC_OPS_VAR_H_ -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.speedup.lazytensor_tab.test: -------------------------------------------------------------------------------- 1 | # ARGS: --backends inductor openxla+lazytensor --format=tab 2 | ╒════════════════════════════╤════════════╤══════════╤══════════════════╤══════════╕ 3 | │ Datetime(UTC) │ Speedup │ StdDev │ Speedup │ StdDev │ 4 | │ │ Inductor │ │ XLA+LazyTensor │ │ 5 | │ │ over │ │ over │ │ 6 | │ │ Oldest │ │ Oldest │ │ 7 | │ │ Inductor │ │ Inductor │ │ 8 | ╞════════════════════════════╪════════════╪══════════╪══════════════════╪══════════╡ 9 | │ 2023-11-11 05:32:18.723407 │ 1.00 │ 0.03 │ │ │ 10 | ├────────────────────────────┼────────────┼──────────┼──────────────────┼──────────┤ 11 | │ 2023-11-12 05:32:18 │ 1.40 │ 0.03 │ 0.41 │ 0.00 │ 12 | ╘════════════════════════════╧════════════╧══════════╧══════════════════╧══════════╛ 13 | -------------------------------------------------------------------------------- /.circleci/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ex 4 | 5 | source ./xla_env 6 | source .circleci/common.sh 7 | 8 | PYTORCH_DIR=/tmp/pytorch 9 | XLA_DIR=$PYTORCH_DIR/xla 10 | USE_COVERAGE="${USE_COVERAGE:-0}" 11 | 12 | # Needs to be kept in sync with .jenkins/pytorch/common_utils.sh in pytorch/pytorch. 
13 | TORCHVISION_COMMIT="$(cat $PYTORCH_DIR/.github/ci_commit_pins/vision.txt)" 14 | 15 | function pip_install() { 16 | # retry 3 times 17 | # old versions of pip don't have the "--progress-bar" flag 18 | pip install --progress-bar off "$@" || pip install --progress-bar off "$@" || pip install --progress-bar off "$@" ||\ 19 | pip install "$@" || pip install "$@" || pip install "$@" 20 | } 21 | 22 | function install_torchvision() { 23 | pip_install --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$TORCHVISION_COMMIT" 24 | } 25 | 26 | install_torchvision 27 | 28 | export GCLOUD_SERVICE_KEY_FILE="$XLA_DIR/default_credentials.json" 29 | export SILO_NAME='cache-silo-ci-dev-3.8_cuda_12.1' # cache bucket for CI 30 | run_torch_xla_tests $PYTORCH_DIR $XLA_DIR $USE_COVERAGE 31 | -------------------------------------------------------------------------------- /test/spmd/test_xla_sharding_hlo.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import unittest 4 | from unittest.mock import patch 5 | import os 6 | import sys 7 | 8 | import torch 9 | import torch_xla 10 | import torch_xla.runtime as xr 11 | import torch_xla.core.xla_model as xm 12 | import torch_xla.distributed.spmd as xs 13 | 14 | import test_xla_sharding_base 15 | 16 | 17 | class XlaShardingHloTest(test_xla_sharding_base.XlaShardingTest): 18 | 19 | @classmethod 20 | def setUpClass(cls): 21 | super().setUpClass() 22 | 23 | @patch.dict(os.environ, {"XLA_DUMP_POST_OPTIMIZATIONS": "1"}) 24 | def test_xla_sharded_hlo_dump_post_optimizations(self): 25 | t1 = torch.randn(1, 128).to('xla') 26 | t2 = torch.randn(128, 1).to('xla') 27 | xs.mark_sharding(t1, self._get_mesh((1, self.n_devices)), (0, 1)) 28 | 29 | t3 = t1 @ t2 30 | hlo = torch_xla._XLAC._get_xla_tensors_hlo([t3]) 31 | if self.n_devices > 1: 32 | self.assertIn('all-reduce', hlo) 33 | 34 | 35 | if __name__ == '__main__': 36 | test = unittest.main() 37 | sys.exit(0 if test.result.wasSuccessful() else 1) 38 | -------------------------------------------------------------------------------- /test/test_mp_collective_permute.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch_xla 4 | from torch_xla import runtime as xr 5 | import torch_xla.core.xla_model as xm 6 | 7 | 8 | def _mp_fn(index): 9 | device = torch_xla.device() 10 | if xm.xla_device_hw(device) in ['TPU', 'NEURON']: 11 | world_size = xr.world_size() 12 | ordinal = xr.global_ordinal() 13 | value = torch.tensor([ordinal] * 100, dtype=torch.int32, device=device) 14 | pairs = [] 15 | for i in range(1, world_size): 16 | pairs.append([i - 1, i]) 17 | pairs.append([world_size - 1, 0]) 18 | result_tensor = xm.collective_permute(value, pairs) 19 | 20 | result = result_tensor.cpu().tolist() 21 | expected = [ordinal - 1] * 100 if ordinal != 0 else [world_size - 1] * 100 22 | 23 | if result != expected: 24 | print(f"Wrong result from core {ordinal}: {result}", file=sys.stderr) 25 | sys.exit(1) 26 | else: 27 | print(f"Default device {device} is not a supported device", file=sys.stderr) 28 | 29 | 30 | if __name__ == '__main__': 31 | torch_xla.launch(_mp_fn, args=()) 32 | -------------------------------------------------------------------------------- /torch_xla/_internal/c10d_registration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_xla.runtime as xr 3 | import torch_xla.core.xla_model as xm 4 | from typing import 
List, Optional, TypedDict 5 | 6 | 7 | # "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", 8 | @torch.library.impl("_c10d_functional::broadcast", "XLA") 9 | def broadcast_xla(self: torch.Tensor, 10 | src: int, 11 | tag: str, 12 | ranks: Optional[List] = None, 13 | group_size: Optional[int] = None) -> torch.Tensor: 14 | assert group_size == None, "currently does not support group_size" 15 | # xm.collective_broadcast performs an in-place update, but 16 | # we want a functional implementation here. 17 | with torch.no_grad(): 18 | scale = torch.tensor( 19 | 1 if xr.global_ordinal() == src else 0, dtype=self.dtype) 20 | # Transfer the scale tensor as device data instead of a constant 1 or 0. 21 | xscale = xm.send_cpu_data_to_device(scale, self.device)[0] 22 | return xm.all_reduce(xm.REDUCE_SUM, xscale * self, groups=ranks) 23 | -------------------------------------------------------------------------------- /experimental/reference_models/README.md: -------------------------------------------------------------------------------- 1 | This directory will contain a list of reference models that 2 | we have optimized and that run well on TPU. 3 | 4 | The contents of this directory are organized in the following way: 5 | 6 | * Every subdirectory is a self-contained model, packaged as a separate pip package. 7 | 8 | * Each subdirectory must have a README indicating: 9 |   * whether it is for training or inference 10 |   * which devices it has been tested / developed on 11 |   * instructions for running it. 12 | 13 | * Every subdirectory contains its own set of shell scripts with all the flags 14 | set for the best performance that we have tuned, be it training or inference. 15 | 16 | * Each subdirectory can specify its own dependencies, and can depend on models / layers 17 | defined in well-known OSS libraries, such as HuggingFace transformers, but subdirectories should ideally not depend on each other. 18 | 19 | * (Optional) Each model can also have a GPU "original" version that illustrates and attributes where this model code came from, if any. This also helps to showcase what changes we have made to make it performant on TPU. 20 | 21 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/custom_sharding.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CUSTOM_SHARDING_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_CUSTOM_SHARDING_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class CustomSharding : public XlaNode { 9 | public: 10 | // The following enum represents the custom_call_target being 11 | // passed to xla builder. The actual sharding will still be 12 | // attached to the XLATensor. 13 | enum class Type { 14 | kSharding, 15 | kSPMDFullToShardShape, 16 | kSPMDShardToFullShape, 17 | }; 18 | 19 | // Make a custom call to Sharding. 
20 | CustomSharding(const torch::lazy::Value& input, 21 | const xla::Shape& output_shape, const Type& type); 22 | 23 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 24 | 25 | XlaOpVector Lower(LoweringContext* loctx) const override; 26 | 27 | std::string ToString() const override; 28 | 29 | Type type; 30 | xla::Shape output_shape; 31 | }; 32 | 33 | } // namespace torch_xla 34 | 35 | #endif // XLA_TORCH_XLA_CSRC_OPS_CUSTOM_SHARDING_H_ 36 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/std_mean.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_STD_MEAN_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_STD_MEAN_H_ 3 | 4 | #include <vector> 5 | 6 | #include "xla/types.h" 7 | 8 | #include "torch_xla/csrc/ir.h" 9 | 10 | namespace torch_xla { 11 | 12 | class StdMean : public XlaNode { 13 | public: 14 | StdMean(const torch::lazy::Value& input, std::vector<int64_t> dimensions, 15 | double correction, bool keep_reduced_dimensions); 16 | 17 | std::string ToString() const override; 18 | 19 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 20 | 21 | XlaOpVector Lower(LoweringContext* loctx) const override; 22 | 23 | const std::vector<int64_t>& dimensions() const { return dimensions_; } 24 | 25 | double correction() const { return correction_; } 26 | 27 | bool keep_reduced_dimensions() const { return keep_reduced_dimensions_; } 28 | 29 | private: 30 | std::vector<int64_t> dimensions_; 31 | double correction_; 32 | bool keep_reduced_dimensions_; 33 | }; 34 | 35 | } // namespace torch_xla 36 | 37 | #endif // XLA_TORCH_XLA_CSRC_OPS_STD_MEAN_H_ --------------------------------------------------------------------------------
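As a quick usage sketch (not a file in the repository tree above): the reduction node headers such as Std, Var, and StdMean are reached from ordinary PyTorch calls on XLA tensors. The snippet below assumes a working torch_xla installation with at least one XLA device available; the tensor shape and the dim/correction/keepdim arguments are arbitrary example values.

```python
# Minimal sketch: exercising the Std / Var / StdMean lowerings from user code.
# Assumes torch and torch_xla are installed and an XLA device is available.
import torch
import torch_xla

x = torch.randn(4, 128).to('xla')

# Each call builds the corresponding lazy IR node, carrying the `dimensions`,
# `correction`, and `keep_reduced_dimensions` attributes declared in the headers.
std = torch.std(x, dim=1, correction=1, keepdim=True)
var = torch.var(x, dim=1, correction=1, keepdim=True)
std_val, mean_val = torch.std_mean(x, dim=1, correction=1, keepdim=True)

# Dump the HLO of the pending graph, as the SPMD sharding test above does.
print(torch_xla._XLAC._get_xla_tensors_hlo([std, var, std_val, mean_val]))
```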