├── .bazelrc ├── .bazelversion ├── .circleci ├── build.sh ├── common.sh ├── download_llvm_raw.sh ├── setup_ci_environment.sh ├── test.sh └── test_xrt.sh ├── .clang-format ├── .clangd ├── .devcontainer ├── gpu-internal │ └── devcontainer.json ├── tpu-contributor │ └── devcontainer.json └── tpu-internal │ └── devcontainer.json ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE.md ├── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── documentation.md │ ├── feature-request.md │ └── questions-help-support.md ├── ci.md ├── scripts │ └── run_tests.sh ├── stale.yml ├── upstream │ ├── Dockerfile │ ├── install_conda.sh │ └── install_valgrind.sh └── workflows │ ├── _build_torch_xla.yml │ ├── _check_code_changes.yml │ ├── _docs.yml │ ├── _test.yml │ ├── _tpu_ci.yml │ ├── build_and_test.yml │ ├── build_upstream_image.yml │ ├── lintercheck.yml │ ├── openxla_pin_update_weekly.yml │ ├── setup │ └── action.yml │ └── torchax.yml ├── .gitignore ├── .gitmodules ├── .style.yapf ├── API_GUIDE.md ├── BUILD ├── CODEGEN_MIGRATION_GUIDE.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── OP_LOWERING_GUIDE.md ├── README.md ├── WORKSPACE ├── bazel ├── BUILD ├── dependencies.bzl ├── nlohmann_json.BUILD ├── rules_def.bzl └── torch.BUILD ├── benchmarks ├── README.md ├── __init__.py ├── aggregate.py ├── bench.py ├── benchmark_experiment.py ├── benchmark_model.py ├── check_xla_device.py ├── experiment_runner.py ├── llama.py ├── matmul_bench.py ├── microbench.py ├── nightly.sh ├── patches │ └── mismatched_batch_size.patch ├── requirements.txt ├── result_analyzer.py ├── run_benchmark.sh ├── run_single_graph_bm.sh ├── run_top_tier_bm.sh ├── tiers.py ├── torchbench_model.py ├── util.py └── verifier.py ├── build_util.py ├── codegen ├── BUILD ├── fix_includes.sh ├── lazy_tensor_generator.py └── xla_native_functions.yaml ├── configuration.yaml ├── contrib ├── kaggle │ ├── distributed-pytorch-xla-basics-with-pjrt.ipynb │ └── pytorch-xla-2-0-on-kaggle.ipynb └── vscode │ └── settings.json ├── docker ├── Dockerfile ├── cloudbuild.yaml ├── common.sh ├── debug_cloudbuild.yaml ├── debug_image_cleanup.sh ├── docker-entrypoint.sh └── gcb_pool.yaml ├── docs ├── .gitattributes ├── README.md ├── _static │ └── img │ │ ├── IRgraph_markstep.png │ │ ├── IRgraph_no_markstep.png │ │ ├── bit_layout.svg │ │ ├── ci_test_dependency.png │ │ ├── ci_test_dependency_gpu.png │ │ ├── ddp_md_mnist_with_real_data.png │ │ ├── dynamic_shape_mlp_perf.png │ │ ├── gpt2_2b_step_time_vs_batch.png │ │ ├── gpt2_v4_8_mfu_batch.png │ │ ├── image-1.png │ │ ├── image-2.png │ │ ├── image-3.png │ │ ├── image-4.png │ │ ├── image.png │ │ ├── llama2_2b_bsz128.png │ │ ├── mesh_spmd2.png │ │ ├── perf_auto_vs_manual.png │ │ ├── pytorchXLA_flow.svg │ │ ├── spmd_debug_1.png │ │ ├── spmd_debug_1_light.png │ │ ├── spmd_debug_2.png │ │ ├── spmd_debug_2_light.png │ │ ├── spmd_mode.png │ │ ├── torchbench_pjrt_vs_xrt.svg │ │ └── torchbench_tfrt_vs_se.svg ├── docs_build.sh ├── jupyter_execute │ └── tutorials │ │ └── precision_tutorial.ipynb ├── requirements.txt └── source │ ├── _static │ ├── css │ │ └── pytorch_theme.css │ └── img │ │ ├── IRgraph_markstep.png │ │ ├── IRgraph_no_markstep.png │ │ ├── bit_layout.svg │ │ ├── ci_test_dependency.png │ │ ├── ci_test_dependency_gpu.png │ │ ├── ddp_md_mnist_with_real_data.png │ │ ├── debugger0_pack.png │ │ ├── debugger1_file.png │ │ ├── debugger2_breakpoint.png │ │ ├── debugger3_session.png │ │ ├── debugger4_active.png │ │ ├── debugger5_break.png │ │ ├── dist_op_stack.png │ │ ├── dynamic_shape_mlp_perf.png │ │ ├── gpt2_2b_step_time_vs_batch.png │ │ 
├── gpt2_v4_8_mfu_batch.png │ │ ├── image-1.png │ │ ├── image-2.png │ │ ├── image-3.png │ │ ├── image-4.png │ │ ├── image.png │ │ ├── llama2_2b_bsz128.png │ │ ├── mesh_spmd2.png │ │ ├── perf_auto_vs_manual.png │ │ ├── pytorch-logo-dark.svg │ │ ├── pytorchXLA_flow.svg │ │ ├── spmd_debug_1.png │ │ ├── spmd_debug_1_light.png │ │ ├── spmd_debug_2.png │ │ ├── spmd_debug_2_light.png │ │ ├── spmd_mode.png │ │ ├── torchbench_pjrt_vs_xrt.svg │ │ └── torchbench_tfrt_vs_se.svg │ ├── accelerators │ ├── gpu.md │ └── tpu.md │ ├── conf.py │ ├── contribute │ ├── bazel.md │ ├── codegen_migration.md │ ├── configure-environment.md │ ├── cpp_debugger.md │ ├── op_lowering.md │ └── plugins.md │ ├── features │ ├── pallas.md │ ├── scan.md │ ├── stablehlo.md │ ├── torch_distributed.md │ └── triton.md │ ├── index.rst │ ├── learn │ ├── _pjrt.md │ ├── api-guide.rst │ ├── dynamic_shape.md │ ├── eager.md │ ├── pytorch-on-xla-devices.md │ ├── trace-vs-execution-time.md │ ├── troubleshoot.md │ └── xla-overview.md │ ├── notes │ └── source_of_recompilation.md │ ├── perf │ ├── amp.md │ ├── assume_pure.md │ ├── ddp.md │ ├── dynamo.md │ ├── fori_loop.md │ ├── fsdp_collectives.md │ ├── fsdp_spmd.md │ ├── quantized_ops.md │ ├── recompilation.md │ ├── spmd_advanced.md │ ├── spmd_basic.md │ ├── spmd_distributed_checkpoint.md │ └── spmd_gpu.md │ └── tutorials │ ├── precision_tutorial.ipynb │ └── precision_tutorial.py ├── examples ├── README.md ├── data_parallel │ ├── README.md │ ├── train_resnet_ddp.py │ ├── train_resnet_spmd_data_parallel.py │ └── train_resnet_xla_ddp.py ├── debug │ ├── train_resnet_benchmark.py │ └── train_resnet_profile.py ├── decoder_only_model.py ├── eager │ ├── train_decoder_only_eager.py │ ├── train_decoder_only_eager_multi_process.py │ ├── train_decoder_only_eager_spmd_data_parallel.py │ └── train_decoder_only_eager_with_compile.py ├── flash_attention │ ├── train_decoder_only_flash_attention.py │ └── train_decoder_only_flash_attention_fsdp_v2.py ├── fsdp │ ├── README.md │ ├── train_decoder_only_fsdp_v2.py │ └── train_resnet_fsdp_auto_wrap.py ├── host_offloading │ └── README.md ├── scan │ ├── README.md │ ├── decoder_with_scan.py │ └── scan_examples.py ├── train_decoder_only_base.py ├── train_resnet_amp.py └── train_resnet_base.py ├── experimental └── reference_models │ ├── README.md │ └── sdxl_inference │ ├── README.md │ ├── astronaut_rides_horse.png │ ├── sdxl.py │ └── sdxl_beginning.py ├── external ├── infra ├── Terraform.md ├── ansible │ ├── .ansible-lint │ ├── Dockerfile │ ├── README.md │ ├── ansible.cfg │ ├── config │ │ ├── apt.yaml │ │ ├── cuda_deps.yaml │ │ ├── env.yaml │ │ ├── pip.yaml │ │ └── vars.yaml │ ├── development.Dockerfile │ ├── e2e_tests.Dockerfile │ ├── playbook.yaml │ └── roles │ │ ├── bazel │ │ ├── defaults │ │ │ └── main.yaml │ │ └── tasks │ │ │ ├── main.yaml │ │ │ └── tests.yaml │ │ ├── build_plugin │ │ └── tasks │ │ │ └── main.yaml │ │ ├── build_srcs │ │ └── tasks │ │ │ ├── main.yaml │ │ │ └── tests.yaml │ │ ├── configure_env │ │ └── tasks │ │ │ └── main.yaml │ │ ├── fetch_srcs │ │ ├── defaults │ │ │ └── main.yaml │ │ └── tasks │ │ │ ├── main.yaml │ │ │ └── tests.yaml │ │ └── install_deps │ │ └── tasks │ │ └── main.yaml ├── terraform_modules │ ├── apply_terraform_trigger │ │ └── apply_terraform_trigger.tf │ ├── arc_v4_container_cluster │ │ ├── README.md │ │ ├── arc-values.yaml │ │ ├── main.tf │ │ └── variables.tf │ ├── build_trigger │ │ └── build_trigger.tf │ ├── docker_registry │ │ └── docker_registry.tf │ ├── storage_bucket │ │ └── storage_bucket.tf │ ├── trigger_schedule_account │ 
│ └── service_account.tf │ ├── trigger_schedule_job │ │ └── job.tf │ ├── worker_pool │ │ └── worker_pool.tf │ └── xla_docker_build │ │ ├── variables.tf │ │ └── xla_docker_build.tf ├── tpu-pytorch-releases │ ├── .terraform.lock.hcl │ ├── README.md │ ├── artifacts.auto.tfvars │ ├── artifacts_builds.tf │ ├── dev_images.auto.tfvars │ ├── dev_images.tf │ ├── iam.auto.tfvars │ ├── iam.tf │ ├── infra_triggers.tf │ ├── misc.tf │ └── provider.tf └── tpu-pytorch │ ├── .terraform.lock.hcl │ ├── README.md │ ├── iam.auto.tfvars │ ├── iam.tf │ ├── infra_triggers.tf │ ├── misc.tf │ ├── provider.tf │ ├── test_triggers.tf │ └── tpu_ci.tf ├── openxla_patches ├── BUILD ├── count_down.diff ├── gpu_nvml.diff └── gpu_race_condition.diff ├── plugins ├── cpu │ ├── BUILD │ ├── README.md │ ├── pjrt_c_api_cpu_version_script.lds │ ├── pyproject.toml │ ├── setup.py │ ├── test_cpu_plugin.cpp │ ├── test_cpu_plugin.h │ └── torch_xla_cpu_plugin │ │ └── __init__.py └── cuda │ ├── README.md │ ├── pyproject.toml │ ├── setup.py │ └── torch_xla_cuda_plugin │ └── __init__.py ├── requirements.in ├── requirements_lock_3_10.txt ├── requirements_lock_3_11.txt ├── requirements_lock_3_8.txt ├── requirements_lock_3_9.txt ├── scripts ├── apply_patches.sh ├── bench_tensor_io.py ├── build_developer.sh ├── build_torch_wheels.sh ├── capture_profile.py ├── cond_patch.py ├── debug_run.py ├── dump_stacks.py ├── grab_graphs.py ├── grab_metrics.py ├── metrics_compare.py ├── metrics_to_tensorboard.py ├── normalize_graph_text.py ├── run_bazel_coverage.sh ├── stack_trace_parse.py ├── tf_log_filter.py ├── update_compile_commands.py ├── update_core_aten_opset_issue.py ├── update_deps.py ├── update_nightly_torch_wheels.sh └── update_torch_wheels.sh ├── setup.py ├── test ├── __init__.py ├── args_parse.py ├── bench.py ├── benchmarks │ ├── .gitignore │ ├── Makefile │ ├── a6000.inference.speedup.test │ ├── a6000.jsonl │ ├── a6000.training.latest.empty.test │ ├── a6000.training.latest.test │ ├── output-example.json │ ├── run_tests.sh │ ├── run_torchbench_tests.sh │ ├── test_benchmark_experiment.py │ ├── test_benchmark_model.py │ ├── test_experiment_runner.py │ ├── test_ragged_paged_attention_benchmark.py │ ├── test_result_analyzer.py │ ├── v100.inference.histogram.lazytensor.test │ ├── v100.inference.histogram.lazytensor_tab.test │ ├── v100.inference.histogram.tab.test │ ├── v100.inference.histogram.test │ ├── v100.inference.latest.openxla_baseline.test │ ├── v100.inference.latest.tab.test │ ├── v100.inference.latest.test │ ├── v100.inference.latest.tier1.test │ ├── v100.inference.latest_grouped.test │ ├── v100.inference.speedup.baseline_latest.test │ ├── v100.inference.speedup.lazytensor.test │ ├── v100.inference.speedup.lazytensor_tab.test │ ├── v100.inference.speedup.tab.test │ ├── v100.inference.speedup.test │ └── v100.jsonl ├── cpp │ ├── BUILD │ ├── cpp_test_util.cpp │ ├── cpp_test_util.h │ ├── get_coverage.sh │ ├── main.cpp │ ├── metrics_snapshot.cpp │ ├── metrics_snapshot.h │ ├── run_tests.sh │ ├── test_aten_xla_tensor_1.cpp │ ├── test_aten_xla_tensor_2.cpp │ ├── test_aten_xla_tensor_3.cpp │ ├── test_aten_xla_tensor_4.cpp │ ├── test_aten_xla_tensor_5.cpp │ ├── test_aten_xla_tensor_6.cpp │ ├── test_ir.cpp │ ├── test_lazy.cpp │ ├── test_replication.cpp │ ├── test_symint.cpp │ ├── test_tensor.cpp │ ├── test_xla_backend_intf.cpp │ ├── test_xla_sharding.cpp │ ├── torch_xla_test.cpp │ └── torch_xla_test.h ├── custom_debug_lowering.py ├── debug_tool │ ├── extract_debug_helper.py │ ├── test_mp_pt_xla_debug.py │ └── test_pt_xla_debug.py ├── 
distributed_util.py ├── ds │ ├── test_dynamic_shape_models.py │ └── test_dynamic_shapes.py ├── dynamo │ ├── test_bridge.py │ ├── test_dynamo.py │ ├── test_dynamo_aliasing.py │ ├── test_dynamo_config.py │ ├── test_dynamo_dynamic_shape.py │ ├── test_dynamo_graph_dump.py │ ├── test_dynamo_integrations_util.py │ ├── test_graph_input_matcher.py │ ├── test_none_remover.py │ ├── test_num_output.py │ └── test_traceable_collectives.py ├── eager │ ├── test_eager.py │ ├── test_eager_all_reduce_in_place.py │ ├── test_eager_spmd.py │ ├── test_eager_with_torch_compile.py │ └── test_eager_with_xla_compile.py ├── metrics_compare_utils_test.py ├── neuron │ ├── run_tests.sh │ ├── test_neuron_data_types.py │ └── test_neuron_utils.py ├── pjrt │ ├── __init__.py │ ├── args_parse.py │ ├── test_collective_ops_tpu.py │ ├── test_ddp.py │ ├── test_dtypes.py │ ├── test_dynamic_plugin_tpu.py │ ├── test_internal_tpu.py │ ├── test_mesh_service.py │ ├── test_metrics.py │ ├── test_profiler.py │ ├── test_runtime.py │ ├── test_runtime_multi_cpu.py │ ├── test_runtime_multi_gpu.py │ ├── test_runtime_single_proc_gpu.py │ ├── test_runtime_tpu.py │ ├── test_torchrun.py │ └── test_train_hf_transformer.py ├── pytorch_test_base.py ├── quantized_ops │ ├── test_dot_general.py │ └── test_quantized_matmul.py ├── run_tests.sh ├── scan │ ├── test_scan.py │ ├── test_scan_debug.py │ ├── test_scan_layers.py │ ├── test_scan_pallas.py │ └── test_scan_spmd.py ├── schedulers.py ├── spmd │ ├── __init__.py │ ├── args_parse.py │ ├── test_dtensor_integration.py │ ├── test_dtensor_integration2.py │ ├── test_dynamo_spmd.py │ ├── test_fsdp_v2.py │ ├── test_mp_input_sharding.py │ ├── test_sharding_strategies.py │ ├── test_spmd_debugging.py │ ├── test_spmd_graph_dump.py │ ├── test_spmd_lowering_context.py │ ├── test_spmd_parameter_wrapping.py │ ├── test_train_spmd_imagenet.py │ ├── test_train_spmd_linear_model.py │ ├── test_xla_auto_sharding.py │ ├── test_xla_distributed_checkpoint.py │ ├── test_xla_sharding.py │ ├── test_xla_sharding_base.py │ ├── test_xla_sharding_hlo.py │ ├── test_xla_spmd_python_api_interaction.py │ └── test_xla_virtual_device.py ├── stablehlo │ ├── __init__.py │ ├── export_tinyroberta_unbounded_dynamism.py │ ├── export_vit_unbounded_dynamism.py │ ├── export_wav2vec2.py │ ├── llama_model.py │ ├── llama_model2.py │ ├── test_composite.py │ ├── test_export_fx_passes.py │ ├── test_export_llama.py │ ├── test_exports.py │ ├── test_implicit_broadcasting.py │ ├── test_mlir_debuginfo.py │ ├── test_pt2e_qdq.py │ ├── test_saved_model.py │ ├── test_stablehlo_compile.py │ ├── test_stablehlo_custom_call.py │ ├── test_stablehlo_inference.py │ ├── test_stablehlo_save_load.py │ ├── test_unbounded_dynamism.py │ └── test_xla_export_interpreter.py ├── test_as_stride_use_slice.py ├── test_assume_pure.py ├── test_assume_pure_spmd.py ├── test_assume_pure_torch.py ├── test_async_closures.py ├── test_autocast.py ├── test_autocast_xla.py ├── test_callback.py ├── test_compilation_cache_utils.py ├── test_core_aten_ops.py ├── test_data_type.py ├── test_deprecation.py ├── test_devices.py ├── test_dynamic_shapes_detector.py ├── test_env_var_mapper.py ├── test_fp8.py ├── test_fsdp_auto_wrap.py ├── test_gmm.py ├── test_gpu_device_detection.py ├── test_grad_checkpoint.py ├── test_gradient_accumulation.py ├── test_gru.py ├── test_hlo_metadata.py ├── test_inplace_update.py ├── test_input_output_aliases.py ├── test_jax_interop.py ├── test_manual_xla_registration.py ├── test_mat_mul_precision.py ├── test_mat_mul_precision_get_and_set.py ├── test_metrics.py ├── 
test_mp_all_gather.py ├── test_mp_all_to_all.py ├── test_mp_collective_matmul.py ├── test_mp_collective_permute.py ├── test_mp_distributed_mm.py ├── test_mp_early_exit.py ├── test_mp_mesh_reduce.py ├── test_mp_reduce_scatter.py ├── test_mp_rendezvous.py ├── test_mp_replication.py ├── test_mp_save.py ├── test_mp_sync_batch_norm.py ├── test_multi_queries_paged_attention_kernel.py ├── test_operations.py ├── test_operations_hlo.py ├── test_ops.py ├── test_pallas.py ├── test_pallas_spmd.py ├── test_persistent_cache.py ├── test_placeholder.py ├── test_profile_mp_mnist.py ├── test_profiler.py ├── test_profiler_session.py ├── test_python_ops.py ├── test_ragged_paged_attention_kernel.py ├── test_splash_attention.py ├── test_syncfree_optimizers.py ├── test_torch_distributed_fsdp_frozen_weight.py ├── test_torch_distributed_xla_backend.py ├── test_train_mp_imagenet.py ├── test_train_mp_imagenet_amp.py ├── test_train_mp_imagenet_fsdp.py ├── test_train_mp_mnist.py ├── test_train_mp_mnist_amp.py ├── test_train_mp_mnist_fsdp_with_ckpt.py ├── test_train_mp_mnist_zero1.py ├── test_triton.py ├── test_user_computation_debug_cache.py ├── test_utils.py ├── test_while_loop.py ├── test_xla_graph_execution.py ├── test_zero1.py ├── torch_distributed │ ├── test_ddp.py │ ├── test_torch_distributed_all_gather_xla_backend.py │ ├── test_torch_distributed_all_reduce_xla_backend.py │ ├── test_torch_distributed_bucketed_all_reduce_xla_backend.py │ ├── test_torch_distributed_fsdp_meta.py │ ├── test_torch_distributed_multi_all_reduce_xla_backend.py │ └── test_torch_distributed_reduce_scatter_xla_backend.py ├── tpu │ ├── Dockerfile │ ├── run_expensive_test_1.sh │ ├── run_expensive_test_2.sh │ ├── run_pallas_test.sh │ ├── run_tests.sh │ ├── tpu_info │ │ └── test_cli.py │ └── xla_test_job.yaml └── utils │ ├── __init__.py │ ├── run_tests_utils.sh │ ├── train_mp_result_analyzer.py │ ├── train_spmd_linear_model.py │ └── train_spmd_linear_model_grad_acc.py ├── torch_xla ├── __init__.py ├── _dynamo │ ├── __init__.py │ ├── config.py │ ├── dynamo_backend2.py │ └── dynamo_bridge.py ├── _internal │ ├── __init__.py │ ├── c10d_registration.py │ ├── custom_kernel.py │ ├── decomp_registration.py │ ├── gpu.py │ ├── jax_workarounds.py │ ├── neuron.py │ ├── neuron_utils.py │ ├── pjrt.py │ ├── rendezvous.py │ ├── tpu.py │ ├── utils.py │ └── xpu.py ├── _patched_functions.py ├── amp │ ├── __init__.py │ ├── autocast_mode.py │ ├── grad_scaler.py │ └── syncfree │ │ ├── __init__.py │ │ ├── _functional.py │ │ ├── adam.py │ │ ├── adamw.py │ │ └── sgd.py ├── backends │ └── __init__.py ├── core │ ├── __init__.py │ ├── dynamo_bridge.py │ ├── functions.py │ ├── xla_builder.py │ ├── xla_env_vars.py │ ├── xla_model.py │ └── xla_op_registry.py ├── csrc │ ├── BUILD │ ├── aten_autograd_ops.cpp │ ├── aten_autograd_ops.h │ ├── aten_cuda_functions.cpp │ ├── aten_cuda_functions.h │ ├── aten_fallback.cpp │ ├── aten_fallback.h │ ├── aten_xla_bridge.cpp │ ├── aten_xla_bridge.h │ ├── aten_xla_type.cpp │ ├── autocast_mode.cpp │ ├── batch_norm.cpp │ ├── batch_norm.h │ ├── convert_ops.cpp │ ├── convert_ops.h │ ├── convolution.cpp │ ├── convolution.h │ ├── convolution_helper.cpp │ ├── convolution_helper.h │ ├── cross_replica_reduces.cpp │ ├── cross_replica_reduces.h │ ├── data_ops.cpp │ ├── data_ops.h │ ├── debug_util.cpp │ ├── debug_util.h │ ├── device.cpp │ ├── device.h │ ├── dl_convertor.cpp │ ├── dl_convertor.h │ ├── dtype.cpp │ ├── dtype.h │ ├── dynamic_shape_detector.cpp │ ├── dynamic_shape_detector.h │ ├── elementwise.cpp │ ├── elementwise.h │ ├── 
function_call_tracker.cpp │ ├── function_call_tracker.h │ ├── generated_file_include.h │ ├── helpers.cpp │ ├── helpers.h │ ├── init_python_bindings.cpp │ ├── ir.cpp │ ├── ir.h │ ├── ir_builder.h │ ├── ir_dump_util.cpp │ ├── ir_dump_util.h │ ├── layout_manager.cpp │ ├── layout_manager.h │ ├── lowering_context.cpp │ ├── lowering_context.h │ ├── matrix.cpp │ ├── matrix.h │ ├── nll_loss.cpp │ ├── nll_loss.h │ ├── ops │ │ ├── adam_optimizer_step.cpp │ │ ├── adam_optimizer_step.h │ │ ├── adaptive_max_pool2d.cpp │ │ ├── adaptive_max_pool2d.h │ │ ├── all_gather.cpp │ │ ├── all_gather.h │ │ ├── all_reduce.cpp │ │ ├── all_reduce.h │ │ ├── all_to_all.cpp │ │ ├── all_to_all.h │ │ ├── amp_foreach_non_finite_check_and_unscale.cpp │ │ ├── amp_foreach_non_finite_check_and_unscale.h │ │ ├── amp_update_scale.cpp │ │ ├── amp_update_scale.h │ │ ├── arithmetic_ir_ops.cpp │ │ ├── arithmetic_ir_ops.h │ │ ├── as_strided.cpp │ │ ├── as_strided.h │ │ ├── as_strided_view_update.cpp │ │ ├── as_strided_view_update.h │ │ ├── avg_pool_nd.cpp │ │ ├── avg_pool_nd.h │ │ ├── avg_pool_nd_backward.cpp │ │ ├── avg_pool_nd_backward.h │ │ ├── bernoulli.cpp │ │ ├── bernoulli.h │ │ ├── cast.cpp │ │ ├── cast.h │ │ ├── cast_int4.cpp │ │ ├── cast_int4.h │ │ ├── cat.cpp │ │ ├── cat.h │ │ ├── cdist.cpp │ │ ├── cdist.h │ │ ├── collective_permute.cpp │ │ ├── collective_permute.h │ │ ├── constant.cpp │ │ ├── constant.h │ │ ├── constant_pad_nd.cpp │ │ ├── constant_pad_nd.h │ │ ├── convolution_backward_overrideable.cpp │ │ ├── convolution_backward_overrideable.h │ │ ├── convolution_overrideable.cpp │ │ ├── convolution_overrideable.h │ │ ├── count_nonzero.cpp │ │ ├── count_nonzero.h │ │ ├── cummax.cpp │ │ ├── cummax.h │ │ ├── cumprod.cpp │ │ ├── cumprod.h │ │ ├── cumsum.cpp │ │ ├── cumsum.h │ │ ├── custom_call.cpp │ │ ├── custom_call.h │ │ ├── custom_sharding.cpp │ │ ├── custom_sharding.h │ │ ├── dequant_tensor.cpp │ │ ├── dequant_tensor.h │ │ ├── device_data.cpp │ │ ├── device_data.h │ │ ├── diagonal.cpp │ │ ├── diagonal.h │ │ ├── diagonal_view_update.cpp │ │ ├── diagonal_view_update.h │ │ ├── discrete_uniform.cpp │ │ ├── discrete_uniform.h │ │ ├── dot_general.cpp │ │ ├── dot_general.h │ │ ├── dynamic_expand.cpp │ │ ├── dynamic_expand.h │ │ ├── dynamic_ir.cpp │ │ ├── dynamic_ir.h │ │ ├── dynamic_view.cpp │ │ ├── dynamic_view.h │ │ ├── eigh.cpp │ │ ├── eigh.h │ │ ├── einsum.cpp │ │ ├── einsum.h │ │ ├── einsum_backward.cpp │ │ ├── einsum_backward.h │ │ ├── einsum_utilities.h │ │ ├── embedding_bag.cpp │ │ ├── embedding_bag.h │ │ ├── expand.cpp │ │ ├── expand.h │ │ ├── expand_symint.cpp │ │ ├── expand_symint.h │ │ ├── exponential.cpp │ │ ├── exponential.h │ │ ├── flip.cpp │ │ ├── flip.h │ │ ├── gather.cpp │ │ ├── gather.h │ │ ├── generic.cpp │ │ ├── generic.h │ │ ├── generic_slice.cpp │ │ ├── generic_slice.h │ │ ├── get_dimensions_size.cpp │ │ ├── get_dimensions_size.h │ │ ├── gpu_custom_call.cpp │ │ ├── gpu_custom_call.h │ │ ├── hardtanh_backward.cpp │ │ ├── hardtanh_backward.h │ │ ├── index_get.cpp │ │ ├── index_get.h │ │ ├── index_ops.cpp │ │ ├── index_ops.h │ │ ├── index_put.cpp │ │ ├── index_put.h │ │ ├── index_select.cpp │ │ ├── index_select.h │ │ ├── infer_output_shape.cpp │ │ ├── infer_output_shape.h │ │ ├── kth_value.cpp │ │ ├── kth_value.h │ │ ├── linear_interpolation.cpp │ │ ├── linear_interpolation.h │ │ ├── linspace.cpp │ │ ├── linspace.h │ │ ├── log_softmax.cpp │ │ ├── log_softmax.h │ │ ├── log_softmax_backward.cpp │ │ ├── log_softmax_backward.h │ │ ├── logsumexp.cpp │ │ ├── logsumexp.h │ │ ├── mark_tensor.cpp │ │ ├── 
mark_tensor.h │ │ ├── masked_scatter.cpp │ │ ├── masked_scatter.h │ │ ├── masked_select.cpp │ │ ├── masked_select.h │ │ ├── max_in_dim.cpp │ │ ├── max_in_dim.h │ │ ├── max_pool_nd.cpp │ │ ├── max_pool_nd.h │ │ ├── max_pool_nd_backward.cpp │ │ ├── max_pool_nd_backward.h │ │ ├── max_unpool_nd.cpp │ │ ├── max_unpool_nd.h │ │ ├── mean.cpp │ │ ├── mean.h │ │ ├── min_in_dim.cpp │ │ ├── min_in_dim.h │ │ ├── mse_loss.cpp │ │ ├── mse_loss.h │ │ ├── mse_loss_backward.cpp │ │ ├── mse_loss_backward.h │ │ ├── multinomial.cpp │ │ ├── multinomial.h │ │ ├── native_batch_norm_backward.cpp │ │ ├── native_batch_norm_backward.h │ │ ├── native_batch_norm_forward.cpp │ │ ├── native_batch_norm_forward.h │ │ ├── native_dropout.cpp │ │ ├── native_dropout.h │ │ ├── nll_loss.cpp │ │ ├── nll_loss.h │ │ ├── nll_loss2d.cpp │ │ ├── nll_loss2d.h │ │ ├── nll_loss2d_backward.cpp │ │ ├── nll_loss2d_backward.h │ │ ├── nll_loss_backward.cpp │ │ ├── nll_loss_backward.h │ │ ├── nms.cpp │ │ ├── nms.h │ │ ├── nonzero.cpp │ │ ├── nonzero.h │ │ ├── normal.cpp │ │ ├── normal.h │ │ ├── not_supported.cpp │ │ ├── not_supported.h │ │ ├── ops.cpp │ │ ├── ops.h │ │ ├── ops_lower_fn.cpp │ │ ├── ops_xla_shape_fn.cpp │ │ ├── ops_xla_shape_fn.h │ │ ├── optimization_barrier.cpp │ │ ├── optimization_barrier.h │ │ ├── permute.cpp │ │ ├── permute.h │ │ ├── prod.cpp │ │ ├── prod.h │ │ ├── put.cpp │ │ ├── put.h │ │ ├── qr.cpp │ │ ├── qr.h │ │ ├── quant_tensor.cpp │ │ ├── quant_tensor.h │ │ ├── randperm.cpp │ │ ├── randperm.h │ │ ├── recv.cpp │ │ ├── recv.h │ │ ├── reduce_scatter.cpp │ │ ├── reduce_scatter.h │ │ ├── reflection_pad2d.cpp │ │ ├── reflection_pad2d.h │ │ ├── reflection_pad2d_backward.cpp │ │ ├── reflection_pad2d_backward.h │ │ ├── replication_pad.cpp │ │ ├── replication_pad.h │ │ ├── replication_pad_backward.cpp │ │ ├── replication_pad_backward.h │ │ ├── resize.cpp │ │ ├── resize.h │ │ ├── roll.cpp │ │ ├── roll.h │ │ ├── rrelu_with_noise.cpp │ │ ├── rrelu_with_noise.h │ │ ├── rrelu_with_noise_backward.cpp │ │ ├── rrelu_with_noise_backward.h │ │ ├── scalar.cpp │ │ ├── scalar.h │ │ ├── scatter.cpp │ │ ├── scatter.h │ │ ├── scatter_add.cpp │ │ ├── scatter_add.h │ │ ├── scatter_reduce.cpp │ │ ├── scatter_reduce.h │ │ ├── select.cpp │ │ ├── select.h │ │ ├── send.cpp │ │ ├── send.h │ │ ├── sgd_optimizer_step.cpp │ │ ├── sgd_optimizer_step.h │ │ ├── softmax.cpp │ │ ├── softmax.h │ │ ├── softmax_backward.cpp │ │ ├── softmax_backward.h │ │ ├── split.cpp │ │ ├── split.h │ │ ├── squeeze.cpp │ │ ├── squeeze.h │ │ ├── stack.cpp │ │ ├── stack.h │ │ ├── std.cpp │ │ ├── std.h │ │ ├── std_mean.cpp │ │ ├── std_mean.h │ │ ├── sum.cpp │ │ ├── sum.h │ │ ├── svd.cpp │ │ ├── svd.h │ │ ├── symeig.cpp │ │ ├── symeig.h │ │ ├── threshold.cpp │ │ ├── threshold.h │ │ ├── threshold_backward.cpp │ │ ├── threshold_backward.h │ │ ├── topk.cpp │ │ ├── topk.h │ │ ├── tpu_custom_call.cpp │ │ ├── tpu_custom_call.h │ │ ├── triangular_solve.cpp │ │ ├── triangular_solve.h │ │ ├── uniform.cpp │ │ ├── uniform.h │ │ ├── unselect.cpp │ │ ├── unselect.h │ │ ├── unsqueeze.cpp │ │ ├── unsqueeze.h │ │ ├── update_slice.cpp │ │ ├── update_slice.h │ │ ├── upsample_bilinear2d.cpp │ │ ├── upsample_bilinear2d.h │ │ ├── upsample_bilinear2d_backward.cpp │ │ ├── upsample_bilinear2d_backward.h │ │ ├── upsample_nearest2d.cpp │ │ ├── upsample_nearest2d.h │ │ ├── upsample_nearest2d_backward.cpp │ │ ├── upsample_nearest2d_backward.h │ │ ├── user_computation.cpp │ │ ├── user_computation.h │ │ ├── var.cpp │ │ ├── var.h │ │ ├── var_mean.cpp │ │ ├── var_mean.h │ │ ├── view.cpp │ │ ├── view.h │ │ ├── 
xla_ops.cpp │ │ └── xla_ops.h │ ├── pooling.cpp │ ├── pooling.h │ ├── quant_util.cpp │ ├── quant_util.h │ ├── random.cpp │ ├── random.h │ ├── reduction.cpp │ ├── reduction.h │ ├── resize_ops.cpp │ ├── resize_ops.h │ ├── runtime │ │ ├── BUILD │ │ ├── cache.h │ │ ├── cache_test.cpp │ │ ├── computation_client.cpp │ │ ├── computation_client.h │ │ ├── debug_macros.h │ │ ├── env_hash.cpp │ │ ├── env_hash.h │ │ ├── env_hash_test.cpp │ │ ├── env_vars.cpp │ │ ├── env_vars.h │ │ ├── ifrt_computation_client.cpp │ │ ├── ifrt_computation_client.h │ │ ├── ifrt_computation_client_test.cpp │ │ ├── metrics.cpp │ │ ├── metrics.h │ │ ├── metrics_analysis.cpp │ │ ├── metrics_analysis.h │ │ ├── metrics_reader.cpp │ │ ├── metrics_reader.h │ │ ├── operation_manager.cpp │ │ ├── operation_manager.h │ │ ├── pjrt_computation_client.cpp │ │ ├── pjrt_computation_client.h │ │ ├── pjrt_computation_client_test.cpp │ │ ├── pjrt_registry.cpp │ │ ├── pjrt_registry.h │ │ ├── profiler.cpp │ │ ├── profiler.h │ │ ├── runtime.cpp │ │ ├── runtime.h │ │ ├── stablehlo_composite_helper.cpp │ │ ├── stablehlo_composite_helper.h │ │ ├── stablehlo_helper.cpp │ │ ├── stablehlo_helper.h │ │ ├── sys_util.cpp │ │ ├── sys_util.h │ │ ├── sys_util_test.cpp │ │ ├── tensor_source.h │ │ ├── tf_logging.cpp │ │ ├── tf_logging.h │ │ ├── types.h │ │ ├── util.h │ │ ├── util_test.cpp │ │ ├── xla_coordinator.cpp │ │ ├── xla_coordinator.h │ │ ├── xla_mlir_debuginfo_helper.cpp │ │ ├── xla_mlir_debuginfo_helper.h │ │ ├── xla_util.cpp │ │ ├── xla_util.h │ │ └── xla_util_test.cpp │ ├── shape_builder.cpp │ ├── shape_builder.h │ ├── shape_helper.cpp │ ├── shape_helper.h │ ├── softmax_builder.cpp │ ├── softmax_builder.h │ ├── stack_frame_index_builder.cpp │ ├── stack_frame_index_builder.h │ ├── tensor.cpp │ ├── tensor.h │ ├── tensor_common.h │ ├── tensor_impl.cpp │ ├── tensor_impl.h │ ├── tensor_methods.cpp │ ├── tensor_methods.h │ ├── tensor_ops.cpp │ ├── tensor_ops.h │ ├── tensor_util.cpp │ ├── tensor_util.h │ ├── thread_pool.cpp │ ├── thread_pool.h │ ├── token_handler.cpp │ ├── token_handler.h │ ├── torch_util.cpp │ ├── torch_util.h │ ├── unwrap_data.cpp │ ├── unwrap_data.h │ ├── version.h │ ├── view.cpp │ ├── view.h │ ├── xla_backend_impl.cpp │ ├── xla_backend_impl.h │ ├── xla_graph_executor.cpp │ ├── xla_graph_executor.h │ ├── xla_lower_util.cpp │ ├── xla_lower_util.h │ ├── xla_manual_registration.cpp │ ├── xla_op_builder.cpp │ ├── xla_op_builder.h │ ├── xla_sharding_util.cpp │ └── xla_sharding_util.h ├── debug │ ├── __init__.py │ ├── frame_parser_util.py │ ├── graph_saver.py │ ├── metrics.py │ ├── metrics_compare_utils.py │ ├── metrics_saver.py │ ├── model_comparator.py │ └── profiler.py ├── distributed │ ├── __init__.py │ ├── data_parallel.py │ ├── fsdp │ │ ├── __init__.py │ │ ├── _init_utils.py │ │ ├── consolidate_sharded_ckpts.py │ │ ├── state_dict_utils.py │ │ ├── utils.py │ │ ├── wrap.py │ │ ├── xla_flatten_params_wrapper.py │ │ └── xla_fully_sharded_data_parallel.py │ ├── parallel_loader.py │ ├── spmd │ │ ├── __init__.py │ │ ├── api.py │ │ ├── debugging.py │ │ ├── xla_sharded_tensor.py │ │ └── xla_sharding.py │ ├── xla_backend.py │ ├── xla_multiprocessing.py │ └── zero_redundancy_optimizer.py ├── experimental │ ├── __init__.py │ ├── assume_pure.py │ ├── callback.py │ ├── custom_kernel.py │ ├── deprecation.py │ ├── distributed_checkpoint │ │ ├── __init__.py │ │ ├── _helpers.py │ │ ├── manager.py │ │ ├── planners.py │ │ └── util.py │ ├── dynamo_mark_sharding.py │ ├── dynamo_set_buffer_donor.py │ ├── eager.py │ ├── fori_loop.py │ ├── 
gradient_accumulation.py │ ├── gru.py │ ├── mark_pattern_utils.py │ ├── pallas_kernels │ │ ├── __init__.py │ │ ├── multi_queries_paged_attention_kernel.py │ │ ├── ragged_paged_attention_kernel.py │ │ └── ragged_paged_attention_v2.py │ ├── pjrt_backend.py │ ├── plugins.py │ ├── pytreeify.py │ ├── quantized.py │ ├── scan.py │ ├── scan_layers.py │ ├── splash_attention.py │ ├── spmd_fully_sharded_data_parallel.py │ ├── stablehlo_custom_call.py │ ├── triton.py │ ├── unbounded_dynamism_export.py │ ├── xla_dynamic_reshape_ops.py │ ├── xla_marker.py │ ├── xla_mlir_debuginfo.py │ └── xla_quantized_matmul.py ├── runtime.py ├── stablehlo.py ├── test │ ├── __init__.py │ └── test_utils.py ├── tf_saved_model_integration.py ├── torch_xla.py └── utils │ ├── __init__.py │ ├── buffer_donor_context.py │ ├── checkpoint.py │ ├── checkpoint_tagger.py │ ├── closures.py │ ├── dlpack.py │ ├── keyd_queue.py │ ├── serialization.py │ ├── stablehlo_test_utils.py │ └── utils.py └── torchax ├── LICENSE ├── README.md ├── build_nightly.sh ├── dev-requirements.txt ├── docs ├── dispatch.png ├── fixing_op_info_test.md ├── how_it_works.md ├── ops_registry.md ├── support_a_new_model.md ├── torch_dispatch │ ├── README.md │ ├── example.py │ └── run_env.py ├── torch_xla2_dynamo.md └── understand_jax_jit │ ├── jax_jit.py │ └── torch_module.py ├── examples ├── README.md ├── __init__.py ├── _diffusion.py ├── _grad_of_attention.py ├── basic_training.py ├── basic_training_jax.py ├── eager_mode.py ├── lightning_training.py ├── mnist_tpu.ipynb ├── requirements.txt ├── torchbench_models │ └── BERT_pytorch.py ├── train_gpt │ ├── requirements.txt │ └── train_ddp.py ├── train_llama │ ├── README.md │ ├── __init__.py │ ├── model.py │ ├── train_llama_lightning.py │ └── utils.py └── train_llama_torchtitan │ ├── Dockerfile │ ├── README.md │ ├── __init__.py │ ├── helper.py │ ├── splash_attn.py │ └── train_llama.py ├── format.sh ├── pyproject.toml ├── test-requirements.txt ├── test ├── __init__.py ├── base_test_util.py ├── gemma │ ├── __init__.py │ ├── config.py │ ├── model.py │ ├── test_gemma.py │ └── tokenizer.py ├── llama │ ├── BUILD │ ├── __init__.py │ ├── llama_model.py │ ├── model_exportable.py │ └── test_llama.py ├── moe │ ├── __init__.py │ ├── model.py │ └── moe_test.py ├── test_context.py ├── test_conv.py ├── test_core_aten_ops.py ├── test_exports.py ├── test_flax.py ├── test_functions.py ├── test_image.py ├── test_interop.py ├── test_libraries.py ├── test_mutations.py ├── test_ops.py ├── test_symbolic_shapes.py ├── test_tf_integration.py ├── test_train.py ├── test_unbounded_dynamism.py ├── test_util.py └── test_view.py ├── test_dist ├── README.md ├── __init__.py ├── test_distributed.py └── test_mesh_util.py └── torchax ├── CONTRIBUTING.md ├── __init__.py ├── config.py ├── decompositions.py ├── device_module.py ├── distributed.py ├── export.py ├── flax.py ├── interop.py ├── mesh_util.py ├── ops ├── __init__.py ├── jaten.py ├── jax_reimplement.py ├── jc10d.py ├── jimage.py ├── jlibrary.py ├── jtorch.py ├── jtorchvision_nms.py ├── mappings.py ├── op_base.py └── ops_registry.py ├── tensor.py ├── tf_integration.py ├── train.py ├── types.py ├── util.py └── view.py /.bazelversion: -------------------------------------------------------------------------------- 1 | 6.5.0 2 | -------------------------------------------------------------------------------- /.circleci/test_xrt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ex 4 | 5 | source ./xla_env 6 | source .circleci/common.sh 7 
 8 | PYTORCH_DIR=/tmp/pytorch
 9 | XLA_DIR=$PYTORCH_DIR/xla
10 | USE_COVERAGE="${USE_COVERAGE:-0}"
11 | 
12 | # Needs to be kept in sync with .jenkins/pytorch/common_utils.sh in pytorch/pytorch.
13 | TORCHVISION_COMMIT="$(cat $PYTORCH_DIR/.github/ci_commit_pins/vision.txt)"
14 | 
15 | function pip_install() {
16 |   # retry 3 times
17 |   # old versions of pip don't have the "--progress-bar" flag
18 |   pip install --progress-bar off "$@" || pip install --progress-bar off "$@" || pip install --progress-bar off "$@" ||\
19 |     pip install "$@" || pip install "$@" || pip install "$@"
20 | }
21 | 
22 | function install_torchvision() {
23 |   pip_install --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$TORCHVISION_COMMIT"
24 | }
25 | 
26 | install_torchvision
27 | 
28 | ./test/run_tests.sh
29 | 
--------------------------------------------------------------------------------
/.clangd:
--------------------------------------------------------------------------------
1 | CompileFlags:
2 |   CompilationDatabase: build # Specifies that compile_commands.json is in this directory.
3 | 
--------------------------------------------------------------------------------
/.devcontainer/gpu-internal/devcontainer.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "name": "gpu-internal",
 3 |   "image": "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1",
 4 |   "runArgs": [
 5 |     "--gpus=all",
 6 |     "--net=host",
 7 |     "--shm-size=16G"
 8 |   ],
 9 |   "containerEnv": {
10 |     "BAZEL_REMOTE_CACHE": "1",
11 |     "SILO_NAME": "cache-silo-${localEnv:USER}-gpuvm"
12 |   },
13 |   "initializeCommand": "docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1",
14 |   "customizations": {
15 |     "vscode": {
16 |       "extensions": [
17 |         "llvm-vs-code-extensions.vscode-clangd",
18 |         "ms-vscode.cpptools-themes",
19 |         "BazelBuild.vscode-bazel",
20 |         "DevonDCarew.bazel-code",
21 |         "StackBuild.bazel-stack-vscode",
22 |         "StackBuild.bazel-stack-vscode-cc",
23 |         "xaver.clang-format",
24 |         "ryanluker.vscode-coverage-gutters",
25 |         "ms-azuretools.vscode-docker",
26 |         "ms-python.python"
27 |       ]
28 |     }
29 |   }
30 | }
--------------------------------------------------------------------------------
/.devcontainer/tpu-contributor/devcontainer.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "name": "tpu-contributor",
 3 |   "image": "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu",
 4 |   "runArgs": [
 5 |     "--privileged",
 6 |     "--net=host",
 7 |     "--shm-size=16G"
 8 |   ],
 9 |   "initializeCommand": "docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu",
10 |   "customizations": {
11 |     "vscode": {
12 |       "extensions": [
13 |         "llvm-vs-code-extensions.vscode-clangd",
14 |         "ms-vscode.cpptools-themes",
15 |         "BazelBuild.vscode-bazel",
16 |         "DevonDCarew.bazel-code",
17 |         "StackBuild.bazel-stack-vscode",
18 |         "StackBuild.bazel-stack-vscode-cc",
19 |         "xaver.clang-format",
20 |         "ryanluker.vscode-coverage-gutters",
21 |         "ms-azuretools.vscode-docker",
22 |         "ms-python.python"
23 |       ]
24 |     }
25 |   }
26 | }
--------------------------------------------------------------------------------
/.devcontainer/tpu-internal/devcontainer.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "name": "tpu-internal",
 3 |   "image": "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu",
 4 |   "runArgs": [
 5 |     "--privileged",
 6 |     "--net=host",
 7 |     "--shm-size=16G"
 8 |   ],
 9 |   "containerEnv": {
10 |     "BAZEL_REMOTE_CACHE": "1",
11 |     "SILO_NAME": "cache-silo-${localEnv:USER}-tpuvm"
12 |   },
13 |   "initializeCommand": "docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu",
14 |   "customizations": {
15 |     "vscode": {
16 |       "extensions": [
17 |         "llvm-vs-code-extensions.vscode-clangd",
18 |         "ms-vscode.cpptools-themes",
19 |         "BazelBuild.vscode-bazel",
20 |         "StackBuild.bazel-stack-vscode",
21 |         "StackBuild.bazel-stack-vscode-cc",
22 |         "xaver.clang-format",
23 |         "ryanluker.vscode-coverage-gutters",
24 |         "ms-azuretools.vscode-docker",
25 |         "ms-python.python",
26 |         "eeyore.yapf"
27 |       ]
28 |     }
29 |   }
30 | }
31 | 
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | /infra @lsy323 @ManfeiBai @zpcore @tengyifei @bhavya01 @qihqi
2 | /docs @mikegre-google @tengyifei
3 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
 1 | If you are asking a question, please preface the title with [question].
 2 | If you are submitting a feature request, please preface the title with [feature request].
 3 | If you are submitting a bug report, please fill in the following details.
 4 | 
 5 | ## Issue description
 6 | 
 7 | Provide a short description.
 8 | 
 9 | ## Code example
10 | 
11 | Please try to provide a minimal example to repro the bug.
12 | Error messages and stack traces are also helpful.
13 | 
14 | ## System Info
15 | 
16 | - reproducible on XLA backend [CPU/TPU/CUDA]:
17 | - torch_xla version:
18 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/documentation.md:
--------------------------------------------------------------------------------
 1 | ---
 2 | name: "\U0001F4DA Documentation"
 3 | about: Report an issue about missing/wrong documentation
 4 | title: ''
 5 | labels: ''
 6 | assignees: ''
 7 | 
 8 | ---
 9 | 
10 | ## 📚 Documentation
11 | 
12 | 
13 | 
14 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature-request.md:
--------------------------------------------------------------------------------
 1 | ---
 2 | name: "\U0001F680Feature Request"
 3 | about: Submit a proposal/request for a new feature for PyTorch/XLA integration
 4 | title: ''
 5 | labels: ''
 6 | assignees: ''
 7 | 
 8 | ---
 9 | 
10 | ## 🚀 Feature
11 | 
12 | 
13 | ## Motivation
14 | 
15 | 
16 | 
17 | ## Pitch
18 | 
19 | 
20 | 
21 | ## Alternatives
22 | 
23 | 
24 | 
25 | ## Additional context
26 | 
27 | 
28 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/questions-help-support.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: "❓Questions/Help/Support"
3 | about: Do you need support? We have resources.
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 | 
8 | ---
9 | 
10 | ## ❓ Questions and Help
11 | 
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
 1 | # Number of days of inactivity before an issue becomes stale
 2 | daysUntilStale: 30
 3 | # Number of days of inactivity before a stale issue is closed
 4 | daysUntilClose: 7
 5 | # Issues with these labels will never be considered stale
 6 | exemptLabels:
 7 |   - high priority
 8 |   - nostale
 9 | # Label to use when marking an issue as stale
10 | staleLabel: stale
11 | # Comment to post when marking an issue as stale. Set to `false` to disable
12 | markComment: >
13 |   This issue has been automatically marked as stale because it has not had
14 |   recent activity. It will be closed if no further activity occurs. Thank you
15 |   for your contributions.
16 | # Comment to post when closing a stale issue. Set to `false` to disable
17 | closeComment: false
18 | # Limit to only `issues` or `pulls`
19 | only: issues
20 | 
--------------------------------------------------------------------------------
/.github/upstream/install_valgrind.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | set -ex
 4 | 
 5 | mkdir valgrind_build && cd valgrind_build
 6 | VALGRIND_VERSION=3.16.1
 7 | wget https://ossci-linux.s3.amazonaws.com/valgrind-${VALGRIND_VERSION}.tar.bz2
 8 | tar -xjf valgrind-${VALGRIND_VERSION}.tar.bz2
 9 | cd valgrind-${VALGRIND_VERSION}
10 | ./configure --prefix=/usr/local
11 | make -j6
12 | make install
13 | cd ../../
14 | rm -rf valgrind_build
15 | alias valgrind="/usr/local/bin/valgrind"
16 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | build/
 2 | dist/
 3 | *.egg-info/
 4 | torch_xla/lib/
 5 | torch_xla/pb/cpp/*
 6 | torch_xla/version.py
 7 | torch_xla/csrc/version.cpp
 8 | */**/__pycache__
 9 | *.swp
10 | *.pyc
11 | *.so
12 | 
13 | # BEGIN NOT-CLEAN-FILES (setup.py handles this marker. Do not change.)
14 | #
15 | # Below files are not deleted by "setup.py clean".
16 | 
17 | # Visual Studio Code files
18 | .vs
19 | .vscode/
20 | 
21 | # Files autogenerated by docs/docs_build.sh
22 | /core
23 | /docs/src/*
24 | 
25 | # Local terraform state
26 | .terraform
27 | 
28 | 
29 | # Build system temporary files
30 | bazel-*
31 | 
32 | # Clangd cache directory
33 | .cache/*
34 | 
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/.gitmodules
--------------------------------------------------------------------------------
/bazel/BUILD:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/bazel/BUILD
--------------------------------------------------------------------------------
/bazel/dependencies.bzl:
--------------------------------------------------------------------------------
1 | PYTORCH_LOCAL_DIR = "../"
2 | 
--------------------------------------------------------------------------------
/bazel/nlohmann_json.BUILD:
--------------------------------------------------------------------------------
 1 | cc_library(
 2 |     name = "json",
 3 |     hdrs = [
 4 |         "single_include/nlohmann/json.hpp",
 5 |         "single_include/nlohmann/json_fwd.hpp",
 6 |     ],
 7 |     includes = ["single_include"],
 8 |     visibility = ["//visibility:public"],
 9 | )
10 | 
--------------------------------------------------------------------------------
/benchmarks/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/benchmarks/check_xla_device.py:
--------------------------------------------------------------------------------
 1 | import sys
 2 | 
 3 | assert len(sys.argv) in (2, 3)
 4 | devkind = sys.argv[1]
 5 | 
 6 | 
 7 | def use_torch_xla2():
 8 |   use_xla2 = False
 9 |   if len(sys.argv) == 3 and sys.argv[2].lower() == 'true':
10 |     use_xla2 = True
11 |   return use_xla2
12 | 
13 | 
14 | import os
15 | 
16 | os.environ["PJRT_DEVICE"] = devkind
17 | 
18 | if not use_torch_xla2():
19 |   import torch_xla.core.xla_model as xm
20 |   devlist = xm.get_xla_supported_devices(devkind=devkind)
21 | else:
22 |   # torch_xla2 needs jax to detect the device
23 |   os.environ["JAX_PLATFORMS"] = devkind.lower(
24 |   ) # JAX_PLATFORMS only accepts lower case
25 |   assert devkind.lower() in ('cpu', 'gpu', 'tpu')
26 |   import jax
27 |   devlist = jax.devices(devkind.lower())
28 | 
29 | if not devlist:
30 |   sys.exit(1)
31 | 
--------------------------------------------------------------------------------
/benchmarks/patches/mismatched_batch_size.patch:
--------------------------------------------------------------------------------
 1 | diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py
 2 | index 8593ba4c..57fef507 100644
 3 | --- a/torchbenchmark/util/model.py
 4 | +++ b/torchbenchmark/util/model.py
 5 | @@ -182,6 +182,8 @@ class BenchmarkModel(metaclass=PostInitProcessor):
 6 | 
 7 | # use the device suggestion on CUDA inference tests, key should be either eval_batch_size or train_batch_size
 8 | device_batch_size_key = f"{self.test}_batch_size"
 9 | + # A patch to making sure batch sizes are comparable. It's needed because xla device string is unrecognized.
10 | + current_device_name = 'NVIDIA A100-SXM4-40GB'
11 | if self.metadata and "devices" in self.metadata and current_device_name in self.metadata["devices"] \
12 | and device_batch_size_key in self.metadata["devices"][current_device_name]:
13 | batch_size = self.metadata["devices"][current_device_name][device_batch_size_key]
14 | 
--------------------------------------------------------------------------------
/benchmarks/requirements.txt:
--------------------------------------------------------------------------------
1 | tabulate
2 | scipy
3 | pandas
4 | 
--------------------------------------------------------------------------------
/benchmarks/run_single_graph_bm.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | 
 3 | set -ex
 4 | 
 5 | DATE=$(date +"%Y_%m_%d_%H_%M")
 6 | 
 7 | OUT_PATH=xla/benchmarks/bm_results/single_graph/$DATE
 8 | mkdir -p $OUT_PATH
 9 | 
10 | python new_xla/benchmarks/experiment_runner.py \
11 |   --dynamo=inductor --dynamo=openxla \
12 |   --xla=None --xla=PJRT \
13 |   --test=eval \
14 |   --filter-by-single-graph \
15 |   --pure-wall-time \
16 |   --suite-name=torchbench \
17 |   --accelerator=cuda \
18 |   --output-dirname=$OUT_PATH \
19 |   --repeat=5 \
20 |   --print-subprocess \
21 |   --no-resume \
22 |   > $OUT_PATH/stdout.txt 2> $OUT_PATH/stderr.txt
23 | 
24 | python3 xla/benchmarks/result_analyzer.py \
25 |   --output-dirname=$OUT_PATH \
26 |   --database=$OUT_PATH/$DATE.csv
27 | 
--------------------------------------------------------------------------------
/benchmarks/run_top_tier_bm.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | 
 3 | set -ex
 4 | 
 5 | DATE=$(date +"%Y_%m_%d_%H_%M")
 6 | 
 7 | OUT_PATH=xla/benchmarks/bm_results/$DATE
 8 | mkdir -p $OUT_PATH
 9 | 
10 | python xla/benchmarks/experiment_runner.py \
11 |   --dynamo=inductor --dynamo=openxla \
12 |   --xla=None --xla=PJRT \
13 |   --test=eval --test=train \
14 |   --filter-by-tier=1 --filter-by-tier=2 --filter-by-tier=3 \
15 |   --suite-name=torchbench \
16 |   --accelerator=cuda \
17 |   --output-dirname=$OUT_PATH \
18 |   --repeat=5 \
19 |   --print-subprocess \
20 |   --no-resume \
21 |   > $OUT_PATH/stdout.txt 2> $OUT_PATH/stderr.txt
22 | 
23 | python3 xla/benchmarks/result_analyzer.py \
24 |   --output-dirname=$OUT_PATH \
25 |   --database=$OUT_PATH/$DATE.csv
26 | 
--------------------------------------------------------------------------------
/codegen/BUILD:
--------------------------------------------------------------------------------
 1 | exports_files(["xla_native_functions.yaml"])
 2 | 
 3 | # Requires `torchgen` locally, via pip `torch` package or local builds.
 4 | py_binary(
 5 |     name = "lazy_tensor_generator",
 6 |     srcs = ["lazy_tensor_generator.py"],
 7 |     data = [
 8 |         "//torch_xla/csrc:aten_xla_type.cpp",
 9 |         "@torch//:torchgen_deps",
10 |     ],
11 |     deps = [
12 |         "@pypi_torch//:pkg",
13 |         "@pypi_pyyaml//:pkg",
14 |     ],
15 |     tags = [
16 |         "local",
17 |         "no-remote-exec",
18 |     ],
19 | )
20 | 
21 | sh_binary(
22 |     name = "fix_includes",
23 |     srcs = ["fix_includes.sh"],
24 | )
25 | 
--------------------------------------------------------------------------------
/codegen/fix_includes.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | # torchgen insists on including the tensor header as a system header
 4 | sed -i 's#<torch_xla/csrc/tensor.h>#"torch_xla/csrc/tensor.h"#' $@
 5 | 
 6 | # remove the runfiles-prefix used in codegen for pytorch
 7 | # `torchgen` generates relative includes path and does not support customizing the root,
 8 | # so we have to fix them up.
 9 | if [[ $(uname -m) == "x86_64" ]]; then
10 |   sed -i 's#bazel-out/k8-[^/]*/bin/codegen/lazy_tensor_generator.runfiles/torch/##' $@
11 | elif [[ $(uname -m) == "aarch64" ]]; then
12 |   sed -i 's#bazel-out/aarch64-[^/]*/bin/codegen/lazy_tensor_generator.runfiles/torch/##' $@
13 | fi
14 | 
15 | # use the generated files that are in the compilation unit
16 | sed -i 's##"bazel-out/\1"#' $@
17 | 
--------------------------------------------------------------------------------
/docker/debug_image_cleanup.sh:
--------------------------------------------------------------------------------
1 | IMAGE="gcr.io/tpu-pytorch/xla_debug"
2 | DATE=$(date --date='-90 days' +"%Y-%m-%dT%H:%M:%S")
3 | 
4 | for digest in $(gcloud container images list-tags ${IMAGE} --limit=999999 --sort-by=TIMESTAMP --filter="timestamp.datetime < '${DATE}'" --format='get(digest)'); do
5 |   echo $digest
6 |   gcloud container images delete -q --force-delete-tags "${IMAGE}@${digest}"
7 | done
8 | 
--------------------------------------------------------------------------------
/docker/docker-entrypoint.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | # Explicitly source bashrc even when running commands directly.
 4 | # Since commands run as a separate subshell, we need to source manually.
 5 | # ex. docker run -it gcr.io/tpu-pytorch/xla:nightly bash ...
 6 | # The above will not source bashrc without entrypoint.
 7 | source ~/.bashrc
 8 | 
 9 | # Activate pytorch conda env at entry by default.
10 | # TODO: This should not be needed as it is already sourced from the .bashrc above.
11 | conda activate pytorch
12 | 
13 | exec "$@"
14 | 
--------------------------------------------------------------------------------
/docker/gcb_pool.yaml:
--------------------------------------------------------------------------------
1 | privatePoolV1Config:
2 |   workerConfig:
3 |     diskSizeGb: '500'
4 |     machineType: e2-standard-32
5 | 
--------------------------------------------------------------------------------
/docs/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb diff=none merge=binary
2 | 
--------------------------------------------------------------------------------
/docs/_static/img/IRgraph_markstep.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/IRgraph_markstep.png
--------------------------------------------------------------------------------
/docs/_static/img/IRgraph_no_markstep.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/IRgraph_no_markstep.png
--------------------------------------------------------------------------------
/docs/_static/img/ci_test_dependency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/ci_test_dependency.png
--------------------------------------------------------------------------------
/docs/_static/img/ci_test_dependency_gpu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/ci_test_dependency_gpu.png
--------------------------------------------------------------------------------
/docs/_static/img/ddp_md_mnist_with_real_data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/ddp_md_mnist_with_real_data.png
--------------------------------------------------------------------------------
/docs/_static/img/dynamic_shape_mlp_perf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/dynamic_shape_mlp_perf.png
--------------------------------------------------------------------------------
/docs/_static/img/gpt2_2b_step_time_vs_batch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/gpt2_2b_step_time_vs_batch.png
--------------------------------------------------------------------------------
/docs/_static/img/gpt2_v4_8_mfu_batch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/gpt2_v4_8_mfu_batch.png
--------------------------------------------------------------------------------
/docs/_static/img/image-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/image-1.png -------------------------------------------------------------------------------- /docs/_static/img/image-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/image-2.png -------------------------------------------------------------------------------- /docs/_static/img/image-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/image-3.png -------------------------------------------------------------------------------- /docs/_static/img/image-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/image-4.png -------------------------------------------------------------------------------- /docs/_static/img/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/image.png -------------------------------------------------------------------------------- /docs/_static/img/llama2_2b_bsz128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/llama2_2b_bsz128.png -------------------------------------------------------------------------------- /docs/_static/img/mesh_spmd2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/mesh_spmd2.png -------------------------------------------------------------------------------- /docs/_static/img/perf_auto_vs_manual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/perf_auto_vs_manual.png -------------------------------------------------------------------------------- /docs/_static/img/spmd_debug_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/spmd_debug_1.png -------------------------------------------------------------------------------- /docs/_static/img/spmd_debug_1_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/spmd_debug_1_light.png -------------------------------------------------------------------------------- /docs/_static/img/spmd_debug_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/spmd_debug_2.png -------------------------------------------------------------------------------- /docs/_static/img/spmd_debug_2_light.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/spmd_debug_2_light.png -------------------------------------------------------------------------------- /docs/_static/img/spmd_mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/_static/img/spmd_mode.png -------------------------------------------------------------------------------- /docs/docs_build.sh: -------------------------------------------------------------------------------- 1 | # Installs requirements and builds HTML version of PyTorch/XLA docs. 2 | pip install -r requirements.txt 3 | sphinx-build -b html source build -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # This is a copy of the requirements.txt file from the PyTorch repository v2.7.0, 2 | # with some lines commented out. 3 | sphinx==5.0.0 4 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 5 | # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering 6 | # but it doesn't seem to work and hangs around idly. The initial thought is probably 7 | # something related to Docker setup. We can investigate this later 8 | sphinxcontrib.katex==0.8.6 9 | # matplotlib==3.6.0 10 | # tensorboard==2.10.0 11 | # required to build torch.distributed.elastic.rendezvous.etcd* docs 12 | # python-etcd==0.4.5 13 | sphinx-copybutton==0.5.0 14 | # sphinx-panels==0.4.1 15 | myst-parser==0.18.1 16 | 17 | # This is an additional requirement for the PyTorch XLA documentation. 
18 | myst-nb==0.16 -------------------------------------------------------------------------------- /docs/source/_static/img/IRgraph_markstep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/IRgraph_markstep.png -------------------------------------------------------------------------------- /docs/source/_static/img/IRgraph_no_markstep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/IRgraph_no_markstep.png -------------------------------------------------------------------------------- /docs/source/_static/img/ci_test_dependency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/ci_test_dependency.png -------------------------------------------------------------------------------- /docs/source/_static/img/ci_test_dependency_gpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/ci_test_dependency_gpu.png -------------------------------------------------------------------------------- /docs/source/_static/img/ddp_md_mnist_with_real_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/ddp_md_mnist_with_real_data.png -------------------------------------------------------------------------------- /docs/source/_static/img/debugger0_pack.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/debugger0_pack.png -------------------------------------------------------------------------------- /docs/source/_static/img/debugger1_file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/debugger1_file.png -------------------------------------------------------------------------------- /docs/source/_static/img/debugger2_breakpoint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/debugger2_breakpoint.png -------------------------------------------------------------------------------- /docs/source/_static/img/debugger3_session.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/debugger3_session.png -------------------------------------------------------------------------------- /docs/source/_static/img/debugger4_active.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/debugger4_active.png 
-------------------------------------------------------------------------------- /docs/source/_static/img/debugger5_break.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/debugger5_break.png -------------------------------------------------------------------------------- /docs/source/_static/img/dist_op_stack.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/dist_op_stack.png -------------------------------------------------------------------------------- /docs/source/_static/img/dynamic_shape_mlp_perf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/dynamic_shape_mlp_perf.png -------------------------------------------------------------------------------- /docs/source/_static/img/gpt2_2b_step_time_vs_batch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/gpt2_2b_step_time_vs_batch.png -------------------------------------------------------------------------------- /docs/source/_static/img/gpt2_v4_8_mfu_batch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/gpt2_v4_8_mfu_batch.png -------------------------------------------------------------------------------- /docs/source/_static/img/image-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/image-1.png -------------------------------------------------------------------------------- /docs/source/_static/img/image-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/image-2.png -------------------------------------------------------------------------------- /docs/source/_static/img/image-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/image-3.png -------------------------------------------------------------------------------- /docs/source/_static/img/image-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/image-4.png -------------------------------------------------------------------------------- /docs/source/_static/img/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/image.png -------------------------------------------------------------------------------- /docs/source/_static/img/llama2_2b_bsz128.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/llama2_2b_bsz128.png -------------------------------------------------------------------------------- /docs/source/_static/img/mesh_spmd2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/mesh_spmd2.png -------------------------------------------------------------------------------- /docs/source/_static/img/perf_auto_vs_manual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/perf_auto_vs_manual.png -------------------------------------------------------------------------------- /docs/source/_static/img/spmd_debug_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/spmd_debug_1.png -------------------------------------------------------------------------------- /docs/source/_static/img/spmd_debug_1_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/spmd_debug_1_light.png -------------------------------------------------------------------------------- /docs/source/_static/img/spmd_debug_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/spmd_debug_2.png -------------------------------------------------------------------------------- /docs/source/_static/img/spmd_debug_2_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/spmd_debug_2_light.png -------------------------------------------------------------------------------- /docs/source/_static/img/spmd_mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/docs/source/_static/img/spmd_mode.png -------------------------------------------------------------------------------- /docs/source/accelerators/gpu.md: -------------------------------------------------------------------------------- 1 | # Learn about GPUs 2 | 3 | For information on GPUs on Google Cloud, see: 4 | 5 | - [About GPUs on Google Cloud](https://cloud.google.com/compute/docs/gpus/overview) 6 | - [GPU machine types](https://cloud.google.com/compute/docs/gpus) 7 | -------------------------------------------------------------------------------- /examples/data_parallel/README.md: -------------------------------------------------------------------------------- 1 | ## Recommendation 2 | Please consider using `train_resnet_spmd_data_parallel.py` since it uses SPMD internally and is very likely to yield better performance.
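The recommended `train_resnet_spmd_data_parallel.py` is not reproduced in this listing, so the following is only a rough sketch of the SPMD data-parallel idea behind it, assuming the `torch_xla.distributed.spmd` API; the batch shape, mesh layout, and variable names are illustrative placeholders, not the script's actual code.

```python
# Illustrative sketch (not the actual train_resnet_spmd_data_parallel.py):
# shard the batch dimension of the input across all devices on a 1-D "data"
# mesh axis and let the XLA compiler partition the computation.
import numpy as np
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD execution mode

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ('data', 'model'))

batch = torch.randn(128, 3, 224, 224, device='xla')  # placeholder input
xs.mark_sharding(batch, mesh, ('data', None, None, None))  # shard only the batch dim
```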
3 | -------------------------------------------------------------------------------- /examples/data_parallel/train_resnet_xla_ddp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))) 4 | sys.path.append(example_folder) 5 | from train_resnet_base import TrainResNetBase 6 | 7 | import torch_xla 8 | import torch_xla.core.xla_model as xm 9 | import torch_xla.runtime as xr 10 | 11 | 12 | class TrainResNetXLADDP(TrainResNetBase): 13 | 14 | def run_optimizer(self): 15 | # optimizer_step will call `optimizer.step()` and all-reduce the gradients 16 | xm.optimizer_step(self.optimizer) 17 | 18 | 19 | def _mp_fn(index): 20 | # cache init needs to happen inside the mp_fn. 21 | xr.initialize_cache(f'/tmp/xla_cache_{index}', readonly=False) 22 | xla_ddp = TrainResNetXLADDP() 23 | xla_ddp.start_training() 24 | 25 | 26 | if __name__ == '__main__': 27 | print('consider using train_resnet_spmd_data_parallel.py instead to get better performance') 28 | torch_xla.launch(_mp_fn, args=()) 29 | -------------------------------------------------------------------------------- /examples/eager/train_decoder_only_eager.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))) 4 | sys.path.append(example_folder) 5 | from train_decoder_only_base import TrainDecoderOnlyBase 6 | 7 | import torch_xla 8 | 9 | 10 | class TrainDecoderOnlyEager(TrainDecoderOnlyBase): 11 | 12 | def __init__(self): 13 | super().__init__() 14 | # We want to run the step fn eagerly. 15 | self.compiled_step_fn = self.step_fn 16 | 17 | 18 | if __name__ == '__main__': 19 | torch_xla.experimental.eager_mode(True) 20 | base = TrainDecoderOnlyEager() 21 | base.start_training() 22 | -------------------------------------------------------------------------------- /examples/eager/train_decoder_only_eager_multi_process.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))) 4 | sys.path.append(example_folder) 5 | from train_decoder_only_base import TrainDecoderOnlyBase 6 | 7 | import torch_xla 8 | import torch_xla.core.xla_model as xm 9 | 10 | 11 | class TrainDecoderXLADDP(TrainDecoderOnlyBase): 12 | 13 | def __init__(self): 14 | super().__init__() 15 | # We want to run the step fn eagerly. 
16 | self.compiled_step_fn = self.step_fn 17 | 18 | def run_optimizer(self): 19 | # optimizer_step will call `optimizer.step()` and all-reduce the gradients 20 | xm.optimizer_step(self.optimizer) 21 | 22 | 23 | def _mp_fn(index): 24 | import torch_xla 25 | torch_xla.experimental.eager_mode(True) 26 | xla_ddp = TrainDecoderXLADDP() 27 | xla_ddp.start_training() 28 | 29 | 30 | if __name__ == '__main__': 31 | torch_xla.launch(_mp_fn, args=()) 32 | -------------------------------------------------------------------------------- /examples/eager/train_decoder_only_eager_with_compile.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))) 4 | sys.path.append(example_folder) 5 | from train_decoder_only_base import TrainDecoderOnlyBase 6 | 7 | import torch_xla 8 | 9 | if __name__ == '__main__': 10 | # The step fn will still be compiled; random input generation happens eagerly. 11 | torch_xla.experimental.eager_mode(True) 12 | trainer = TrainDecoderOnlyBase() 13 | trainer.start_training() 14 | -------------------------------------------------------------------------------- /examples/fsdp/README.md: -------------------------------------------------------------------------------- 1 | ## Recommendation 2 | Please consider using `train_decoder_only_fsdp_v2.py` since it uses SPMD internally and is very likely to yield better performance. 3 | -------------------------------------------------------------------------------- /examples/scan/README.md: -------------------------------------------------------------------------------- 1 | ../../docs/source/features/scan.md -------------------------------------------------------------------------------- /examples/scan/decoder_with_scan.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import override 2 | from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel 3 | 4 | 5 | class DecoderWithScan(DecoderOnlyModel): 6 | 7 | def __init__(self, config: DecoderOnlyConfig): 8 | super().__init__(config) 9 | 10 | @override 11 | def run_decoder_layers(self, hidden_states): 12 | from torch_xla.experimental.scan_layers import scan_layers 13 | return scan_layers(self.layers, hidden_states) 14 | -------------------------------------------------------------------------------- /experimental/reference_models/sdxl_inference/README.md: -------------------------------------------------------------------------------- 1 | # How to run: 2 | 3 | ``` 4 | python sdxl.py 5 | ``` -------------------------------------------------------------------------------- /experimental/reference_models/sdxl_inference/astronaut_rides_horse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/experimental/reference_models/sdxl_inference/astronaut_rides_horse.png -------------------------------------------------------------------------------- /experimental/reference_models/sdxl_inference/sdxl_beginning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import StableDiffusionPipeline 3 | 4 | import torch_xla2 5 | env = torch_xla2.default_env() 6 | 7 | # this now contains torch.Tensor 8 | pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base") 9 | 10 | with env: 11 | pipe.to('jax') 
12 | prompt = "a photograph of an astronaut riding a horse" 13 | image = pipe(prompt, num_inference_steps=10).images[0] 14 | image.save(f"astronaut_rides_horse_orig.png") 15 | -------------------------------------------------------------------------------- /external: -------------------------------------------------------------------------------- 1 | bazel-xla/external -------------------------------------------------------------------------------- /infra/ansible/.ansible-lint: -------------------------------------------------------------------------------- 1 | --- 2 | # .ansible-lint 3 | 4 | profile: moderate 5 | skip_list: 6 | - schema[tasks] -------------------------------------------------------------------------------- /infra/ansible/ansible.cfg: -------------------------------------------------------------------------------- 1 | # See https://docs.ansible.com/ansible/latest/reference_appendices/config.html 2 | # for various configuration options. 3 | 4 | [defaults] 5 | # Displays tasks execution duration. 6 | callbacks_enabled = profile_tasks 7 | # The playbooks is only run on the implicit localhost. 8 | # Silence warning about empty hosts inventory. 9 | localhost_warning = False 10 | # Make output human-readable. 11 | stdout_callback = yaml 12 | 13 | [inventory] 14 | # Silence warning about no inventory. 15 | # This option is available since Ansible 2.14 (available only with Python 3.9+). 16 | inventory_unparsed_warning = False -------------------------------------------------------------------------------- /infra/ansible/config/vars.yaml: -------------------------------------------------------------------------------- 1 | # Used for fetching cuda from the right repo, see apt.yaml. 2 | cuda_repo: debian11 3 | cuda_version: "11.8" 4 | # Determines supported GPUs. See https://developer.nvidia.com/cuda-gpus 5 | cuda_compute_capabilities: 5.2,7.0,7.5,8.0,9.0 6 | # Used for fetching clang from the right repo, see apt.yaml. 7 | llvm_debian_repo: bullseye 8 | clang_version: 17 9 | # PyTorch and PyTorch/XLA wheel versions. 10 | package_version: 2.8.0 11 | # If set to true, wheels will be renamed to $WHEEL_NAME-nightly-cp38-cp38-linux_x86_64.whl. 12 | nightly_release: false 13 | # Whether to preinstall libtpu in the PyTorch/XLA wheel. Ignored for GPU build. 14 | bundle_libtpu: 1 15 | # Suffix for bazel remote cache key 16 | cache_suffix: "" 17 | # Whether to build C++ tests with `torch_xla` wheel 18 | build_cpp_tests: 0 19 | # Whether to tag wheels with git hash, e.g. X.Y.Z+git123abc 20 | git_versioned_xla_build: false 21 | # Whether to use C++11 ABI when building torch and torch_xla. 22 | cxx11_abi: 1 23 | -------------------------------------------------------------------------------- /infra/ansible/development.Dockerfile: -------------------------------------------------------------------------------- 1 | # Dockerfile for building a development image. 2 | # The built image contains all required pip and apt packages for building and 3 | # running PyTorch and PyTorch/XLA. The image doesn't contain any source code. 4 | ARG python_version=3.8 5 | ARG debian_version=bullseye 6 | 7 | FROM python:${python_version}-${debian_version} 8 | 9 | RUN pip install ansible 10 | 11 | COPY . /ansible 12 | WORKDIR /ansible 13 | 14 | # List Asnible tasks to apply for the dev image. 
15 | ENV TAGS="bazel,configure_env,install_deps" 16 | 17 | ARG ansible_vars 18 | RUN ansible-playbook playbook.yaml -e "stage=build" -e "${ansible_vars}" --tags "${TAGS}" 19 | RUN ansible-playbook playbook.yaml -e "stage=release" -e "${ansible_vars}" --tags "${TAGS}" 20 | -------------------------------------------------------------------------------- /infra/ansible/roles/bazel/defaults/main.yaml: -------------------------------------------------------------------------------- 1 | bazelisk_version: 1.15.0 2 | -------------------------------------------------------------------------------- /infra/ansible/roles/bazel/tasks/main.yaml: -------------------------------------------------------------------------------- 1 | - name: "Download bazelisk v{{ bazelisk_version }}" 2 | ansible.builtin.get_url: 3 | url: "https://github.com/bazelbuild/bazelisk/releases/download/v{{ bazelisk_version }}/bazelisk-linux-amd64" 4 | dest: /usr/local/bin/bazel 5 | mode: 'u=rxw,g=rw,o=r' 6 | 7 | - name: "Tests" 8 | include_tasks: tests.yaml 9 | tags: 10 | - tests 11 | -------------------------------------------------------------------------------- /infra/ansible/roles/bazel/tasks/tests.yaml: -------------------------------------------------------------------------------- 1 | - name: "Bazel --version runs successfully" 2 | ansible.builtin.command: 3 | cmd: bazel --version 4 | -------------------------------------------------------------------------------- /infra/ansible/roles/build_srcs/tasks/tests.yaml: -------------------------------------------------------------------------------- 1 | - name: "Check that various import statements work" 2 | ansible.builtin.command: 3 | cmd: "{{ item }}" 4 | environment: "{{ env_vars | combine({'USE_CUDA': 0}) }}" 5 | loop: 6 | - python -c "import torchgen" 7 | - python -c "import torch" 8 | - python -c "import torch_xla" 9 | - python -c "import torch_xla.core.xla_model" 10 | -------------------------------------------------------------------------------- /infra/ansible/roles/configure_env/tasks/main.yaml: -------------------------------------------------------------------------------- 1 | - name: Append environment variables required during runtime to ~/.bashrc 2 | ansible.builtin.lineinfile: 3 | path: ~/.bashrc 4 | line: "export {{ item }}={{ env_vars[item] }}" 5 | create: true 6 | loop: "{{ env_vars.keys() | list }}" 7 | 8 | - name: Append environment variables required during runtime to ~/.zshrc 9 | ansible.builtin.lineinfile: 10 | path: ~/.zshrc 11 | line: "export {{ item }}={{ env_vars[item] }}" 12 | create: true 13 | loop: "{{ env_vars.keys() | list }}" 14 | -------------------------------------------------------------------------------- /infra/ansible/roles/fetch_srcs/defaults/main.yaml: -------------------------------------------------------------------------------- 1 | # See https://docs.ansible.com/ansible/latest/collections/ansible/builtin/git_module.html#parameter-version 2 | pytorch_git_rev: HEAD 3 | xla_git_rev: HEAD 4 | -------------------------------------------------------------------------------- /infra/ansible/roles/fetch_srcs/tasks/main.yaml: -------------------------------------------------------------------------------- 1 | - name: "Create source root directory at {{ src_root }}" 2 | ansible.builtin.file: 3 | path: "{{ src_root }}" 4 | state: directory 5 | mode: '0755' 6 | 7 | - name: "Clone PyTorch and XLA git repos" 8 | ansible.builtin.git: 9 | repo: "{{ item.repo }}" 10 | dest: "{{ item.dest }}" 11 | version: "{{ item.version }}" 12 | depth: 1 13 | 
force: true 14 | loop: 15 | - repo: https://github.com/pytorch/pytorch 16 | dest: "{{ (src_root, 'pytorch') | path_join }}" 17 | version: "{{ pytorch_git_rev }}" 18 | 19 | - repo: https://github.com/pytorch/xla 20 | dest: "{{ (src_root, 'pytorch/xla') | path_join }}" 21 | version: "{{ xla_git_rev }}" 22 | 23 | - name: "Tests" 24 | include_tasks: tests.yaml 25 | tags: 26 | - tests 27 | -------------------------------------------------------------------------------- /infra/ansible/roles/fetch_srcs/tasks/tests.yaml: -------------------------------------------------------------------------------- 1 | - name: Retrieve status of setup.py files in XLA and PyTorch repos 2 | ansible.builtin.stat: 3 | path: "{{ item }}" 4 | register: _res 5 | loop: 6 | - "{{ (src_root, 'pytorch/setup.py') | path_join }}" 7 | - "{{ (src_root, 'pytorch/xla/setup.py') | path_join }}" 8 | 9 | - name: Assert that setup.py files exist 10 | ansible.builtin.assert: 11 | that: "{{ item.stat.exists }}" 12 | fail_msg: "{{ item.item }} doesn't exist" 13 | loop: "{{ _res.results }}" 14 | -------------------------------------------------------------------------------- /infra/terraform_modules/arc_v4_container_cluster/README.md: -------------------------------------------------------------------------------- 1 | # Cluster creation for TPU CI for PyTorch/XLA 2 | 3 | This module configures: 4 | * A regional GKE cluster 5 | * A CPU node pool 6 | * An autoscaling v4 TPU node pool 7 | * The installation of Actions Runner Controller (ARC) on the GKE cluster 8 | -------------------------------------------------------------------------------- /infra/terraform_modules/arc_v4_container_cluster/arc-values.yaml: -------------------------------------------------------------------------------- 1 | githubConfigUrl: ${github_repo_url} 2 | githubConfigSecret: github-pat 3 | minRunners: ${min_tpu_nodes} 4 | maxRunners: ${max_tpu_nodes} 5 | template: 6 | spec: 7 | containers: 8 | - name: runner 9 | image: ${runner_image} 10 | command: ["/home/runner/run.sh"] 11 | resources: 12 | limits: 13 | google.com/tpu: 4 14 | requests: 15 | google.com/tpu: 4 16 | nodeSelector: 17 | cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice 18 | cloud.google.com/gke-tpu-topology: 2x2x1 19 | -------------------------------------------------------------------------------- /infra/terraform_modules/trigger_schedule_account/service_account.tf: -------------------------------------------------------------------------------- 1 | resource "google_service_account" "build_runner" { 2 | account_id = "build-triggers-scheduler" 3 | description = "Service account for Scheduled Jobs. Has permissions to trigger Cloud Builds." 4 | } 5 | 6 | resource "google_project_iam_custom_role" "build_runner" { 7 | role_id = "build_runner" 8 | title = "Build Runner" 9 | description = "Grants permissions to trigger Cloud Builds." 
10 | permissions = ["cloudbuild.builds.create"] 11 | } 12 | 13 | data "google_project" "project" {} 14 | 15 | resource "google_project_iam_member" "build_runner" { 16 | role = google_project_iam_custom_role.build_runner.name 17 | project = data.google_project.project.project_id 18 | member = "serviceAccount:${google_service_account.build_runner.email}" 19 | } 20 | 21 | output "email" { 22 | value = google_service_account.build_runner.email 23 | } 24 | -------------------------------------------------------------------------------- /infra/terraform_modules/worker_pool/worker_pool.tf: -------------------------------------------------------------------------------- 1 | variable "name" { 2 | default = "main" 3 | } 4 | 5 | variable "location" { 6 | default = "us-central1" 7 | } 8 | 9 | variable "machine_type" { 10 | default = "e2-standard-32" 11 | } 12 | 13 | variable "disk_size_gb" { 14 | default = 500 15 | } 16 | 17 | resource "google_cloudbuild_worker_pool" "worker_pool" { 18 | name = var.name 19 | location = var.location 20 | 21 | worker_config { 22 | disk_size_gb = var.disk_size_gb 23 | machine_type = var.machine_type 24 | no_external_ip = false 25 | } 26 | } 27 | 28 | output "id" { 29 | value = google_cloudbuild_worker_pool.worker_pool.id 30 | } 31 | -------------------------------------------------------------------------------- /infra/tpu-pytorch-releases/dev_images.auto.tfvars: -------------------------------------------------------------------------------- 1 | dev_images = [ 2 | { 3 | accelerator = "tpu" 4 | extra_tags = ["tpu"] 5 | python_version = "3.10" 6 | }, 7 | { 8 | accelerator = "cuda" 9 | cuda_version = "12.1" 10 | extra_tags = ["cuda"] 11 | python_version = "3.10" 12 | }, 13 | { 14 | accelerator = "cuda" 15 | cuda_version = "12.3" 16 | extra_tags = ["cuda"] 17 | python_version = "3.10" 18 | } 19 | ] 20 | -------------------------------------------------------------------------------- /infra/tpu-pytorch-releases/iam.auto.tfvars: -------------------------------------------------------------------------------- 1 | project_admins = [ 2 | "group:cloud-tpus-grpadm@twosync.google.com", 3 | "group:pytorchxla-dev@google.com", 4 | ] 5 | 6 | project_remote_build_writers = [ 7 | # "group:pytorchxla-announce@google.com", 8 | "serviceAccount:1001674285173@cloudbuild.gserviceaccount.com", 9 | "group:cloud-tpus-dev-team@twosync.google.com", 10 | ] 11 | 12 | cloudbuild_editors = [ 13 | ] 14 | -------------------------------------------------------------------------------- /infra/tpu-pytorch-releases/infra_triggers.tf: -------------------------------------------------------------------------------- 1 | module "terraform_apply" { 2 | source = "../terraform_modules/apply_terraform_trigger" 3 | 4 | included_files = ["infra/**"] 5 | branch = "master" 6 | config_directory = "infra/tpu-pytorch-releases" 7 | 8 | worker_pool_id = module.worker_pool.id 9 | } 10 | -------------------------------------------------------------------------------- /infra/tpu-pytorch-releases/provider.tf: -------------------------------------------------------------------------------- 1 | # Run `gcloud auth application-default login` before running Terraform. 2 | provider "google" { 3 | project = "tpu-pytorch-releases" 4 | region = "us-central1" 5 | } 6 | 7 | terraform { 8 | required_providers { 9 | google = { 10 | source = "hashicorp/google" 11 | version = ">= 4.52.0" 12 | } 13 | } 14 | 15 | backend "gcs" { 16 | # Make sure that bucket name matches the one specified in ./misc.tf. 
17 | bucket = "tpu-pytorch-releases-tfstate" 18 | prefix = "terraform/state" 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /infra/tpu-pytorch/README.md: -------------------------------------------------------------------------------- 1 | # Terraform setup for pytorch-xla GCP project 2 | 3 | This setup configures: 4 | * Private docker repository (`docker`). 5 | * CI test trigger. -------------------------------------------------------------------------------- /infra/tpu-pytorch/iam.auto.tfvars: -------------------------------------------------------------------------------- 1 | project_remote_build_writers = [ 2 | "group:cloud-tpus-dev-team@twosync.google.com", 3 | "user:mlewko@google.com", 4 | "user:goranpetrovic@google.com", 5 | # tpu-pytorch-releases project: default Service Account for running Cloud Build jobs. 6 | "serviceAccount:1001674285173@cloudbuild.gserviceaccount.com" 7 | ] 8 | -------------------------------------------------------------------------------- /infra/tpu-pytorch/infra_triggers.tf: -------------------------------------------------------------------------------- 1 | module "terraform_apply" { 2 | source = "../terraform_modules/apply_terraform_trigger" 3 | 4 | included_files = ["infra/**"] 5 | branch = "master" 6 | config_directory = "infra/tpu-pytorch" 7 | 8 | worker_pool_id = module.worker_pool.id 9 | location = "global" 10 | } 11 | -------------------------------------------------------------------------------- /infra/tpu-pytorch/misc.tf: -------------------------------------------------------------------------------- 1 | # Docker registry for private images. 2 | module "docker_registry" { 3 | source = "../terraform_modules/docker_registry" 4 | name = "docker" 5 | description = join(" ", [ 6 | "Private docker images for PyTorch/XLA.", 7 | "Managed by Terraform setup in docker/experimental/tpu-pytorch/misc.tf.", 8 | ]) 9 | } 10 | 11 | # Storage bucket for Terraform state of this project. 12 | module "tfstate_storage_bucket" { 13 | source = "../terraform_modules/storage_bucket" 14 | name = "tpu-pytorch-tfstate" 15 | } 16 | 17 | # Private worker pool for Cloud Builds. 18 | module "worker_pool" { 19 | source = "../terraform_modules/worker_pool" 20 | # See https://cloud.google.com/compute/docs/machine-resource#machine_type_comparison. 21 | machine_type = "e2-standard-32" 22 | } 23 | -------------------------------------------------------------------------------- /infra/tpu-pytorch/provider.tf: -------------------------------------------------------------------------------- 1 | # Run `gcloud auth application-default login` before running Terraform. 2 | provider "google" { 3 | project = "tpu-pytorch" 4 | region = "us-central1" 5 | } 6 | 7 | terraform { 8 | required_providers { 9 | google = { 10 | source = "hashicorp/google" 11 | version = ">= 4.52.0" 12 | } 13 | } 14 | 15 | backend "gcs" { 16 | # Make sure that bucket name matches the one specified in ./misc.tf. 
17 | bucket = "tpu-pytorch-tfstate" 18 | prefix = "terraform/state" 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /infra/tpu-pytorch/tpu_ci.tf: -------------------------------------------------------------------------------- 1 | module "v4_arc_cluster" { 2 | source = "../terraform_modules/arc_v4_container_cluster" 3 | project_id = "tpu-pytorch" 4 | cluster_name = "tpu-ci" 5 | cpu_nodepool_name = "cpu-nodepool" 6 | cpu_node_count = 32 7 | tpu_nodepool_name = "tpu-nodepool" 8 | min_tpu_nodes = 32 9 | max_tpu_nodes = 32 10 | github_repo_url = "https://github.com/pytorch/xla" 11 | # Dockerfile for this image can be found at test/tpu/Dockerfile 12 | runner_image = "gcr.io/tpu-pytorch/tpu-ci-runner:latest" 13 | } 14 | -------------------------------------------------------------------------------- /openxla_patches/BUILD: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/openxla_patches/BUILD -------------------------------------------------------------------------------- /openxla_patches/count_down.diff: -------------------------------------------------------------------------------- 1 | diff --git a/xla/backends/cpu/runtime/convolution_thunk_internal.h b/xla/backends/cpu/runtime/convolution_thunk_internal.h 2 | index 84fed6bb78..9835f12e4e 100644 3 | --- a/xla/backends/cpu/runtime/convolution_thunk_internal.h 4 | +++ b/xla/backends/cpu/runtime/convolution_thunk_internal.h 5 | @@ -342,7 +342,8 @@ void EigenGenericConv2D( 6 | Eigen::Index start = task_index * task_size; 7 | Eigen::Index end = std::min(start + task_size, feature_group_count); 8 | for (Eigen::Index i = start; i < end; ++i) { 9 | - auto on_done = [count_down]() mutable { count_down.CountDown(); }; 10 | + // auto on_done = [count_down]() mutable { count_down.CountDown(); }; 11 | + auto on_done = [count_down]() mutable { const_cast(&count_down)->CountDown(); }; 12 | auto [output, convolved] = convolve_group(i); 13 | output.device(device, std::move(on_done)) = convolved; 14 | } 15 | -------------------------------------------------------------------------------- /openxla_patches/gpu_race_condition.diff: -------------------------------------------------------------------------------- 1 | diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc 2 | index 9279bd877..fab926a7c 100644 3 | --- a/xla/service/gpu/gpu_executable.cc 4 | +++ b/xla/service/gpu/gpu_executable.cc 5 | @@ -669,8 +669,7 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( 6 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM 7 | 8 | // Force synchronous execution if the allocator requires it. 
9 | - const bool block_host_until_done = 10 | - !memory_allocator->AllowsAsynchronousDeallocation(); 11 | + const bool block_host_until_done = true; 12 | 13 | // Lock the GPU with a shared lock so that we don't interfere with autotuning 14 | // that may be running during JIT compilation while allowing multiple XLA -------------------------------------------------------------------------------- /plugins/cpu/BUILD: -------------------------------------------------------------------------------- 1 | load( 2 | "@xla//xla:xla.bzl", 3 | "xla_cc_binary", 4 | ) 5 | 6 | cc_library( 7 | name = "test_cpu_plugin", 8 | srcs = ["test_cpu_plugin.cpp"], 9 | hdrs = ["test_cpu_plugin.h"], 10 | visibility = ["//visibility:public"], 11 | deps = [ 12 | "@xla//xla/pjrt/c:pjrt_c_api_cpu_internal", 13 | "@xla//xla/pjrt/c:pjrt_c_api_hdrs", 14 | ], 15 | ) 16 | 17 | xla_cc_binary( 18 | name = "pjrt_c_api_cpu_plugin.so", 19 | linkopts = [ 20 | "-Wl,--version-script,$(location :pjrt_c_api_cpu_version_script.lds)", 21 | "-Wl,--no-undefined", 22 | ], 23 | linkshared = True, 24 | deps = [ 25 | ":pjrt_c_api_cpu_version_script.lds", 26 | ":test_cpu_plugin", 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /plugins/cpu/pjrt_c_api_cpu_version_script.lds: -------------------------------------------------------------------------------- 1 | # Only symbols in the global section are available to other frameworks. 2 | VERS_1.0 { 3 | global: 4 | extern "C" { 5 | GetPjrtApi; 6 | }; 7 | 8 | local: 9 | *; 10 | }; 11 | -------------------------------------------------------------------------------- /plugins/cpu/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "torch_xla_cpu_plugin" 7 | version = "0.0.1" 8 | authors = [ 9 | {name = "PyTorch/XLA Dev Team'", email = "pytorch-xla@googlegroups.com"}, 10 | ] 11 | description = "CPU PJRT Plugin for testing only" 12 | requires-python = ">=3.8" 13 | 14 | [tool.setuptools.package-data] 15 | torch_xla_cpu_plugin = ["lib/*.so"] 16 | 17 | [project.entry-points."torch_xla.plugins"] 18 | example = "torch_xla_cpu_plugin:CpuPlugin" 19 | -------------------------------------------------------------------------------- /plugins/cpu/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # add `build_util` to import path 5 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) 6 | 7 | import build_util 8 | import setuptools 9 | 10 | build_util.bazel_build('//plugins/cpu:pjrt_c_api_cpu_plugin.so', 11 | 'torch_xla_cpu_plugin/lib') 12 | 13 | setuptools.setup() 14 | -------------------------------------------------------------------------------- /plugins/cpu/test_cpu_plugin.cpp: -------------------------------------------------------------------------------- 1 | #include "plugins/cpu/test_cpu_plugin.h" 2 | 3 | #include 4 | 5 | #include "xla/pjrt/c/pjrt_c_api.h" 6 | #include "xla/pjrt/c/pjrt_c_api_cpu_internal.h" 7 | 8 | // Use `test` as the platform name instead of `cpu` so torch_xla treats this 9 | // as an unknown device. 
10 | PJRT_Error* test_platform_name(PJRT_Client_PlatformName_Args* args) { 11 | static const std::string platform_name = "test"; 12 | args->platform_name = platform_name.c_str(); 13 | args->platform_name_size = platform_name.size(); 14 | return nullptr; 15 | } 16 | 17 | const PJRT_Api* GetPjrtApi() { 18 | // HACK: The CPU client is created as a constexpr, so const-casting is 19 | // undefined behavior. Make a non-const copy of the struct so we can override 20 | // methods. Don't do this for a real plugin. 21 | static PJRT_Api pjrt_api = *pjrt::cpu_plugin::GetCpuPjrtApi(); 22 | pjrt_api.PJRT_Client_PlatformName = test_platform_name; 23 | 24 | return &pjrt_api; 25 | } 26 | -------------------------------------------------------------------------------- /plugins/cpu/test_cpu_plugin.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_PJRT_C_PJRT_C_API_CPU_H_ 2 | #define XLA_PJRT_C_PJRT_C_API_CPU_H_ 3 | 4 | #include "xla/pjrt/c/pjrt_c_api.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | const PJRT_Api* GetPjrtApi(); 11 | 12 | #ifdef __cplusplus 13 | } 14 | #endif 15 | 16 | #endif // XLA_PJRT_C_PJRT_C_API_CPU_H_ 17 | -------------------------------------------------------------------------------- /plugins/cpu/torch_xla_cpu_plugin/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch_xla.experimental import plugins 3 | from torch_xla._internal import tpu 4 | 5 | 6 | class CpuPlugin(plugins.DevicePlugin): 7 | 8 | def library_path(self) -> str: 9 | return os.path.join( 10 | os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so') 11 | 12 | def physical_chip_count(self) -> int: 13 | return 1 14 | -------------------------------------------------------------------------------- /plugins/cuda/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "numpy"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "torch_xla_cuda_plugin" 7 | authors = [ 8 | {name = "PyTorch/XLA Dev Team", email = "pytorch-xla@googlegroups.com"}, 9 | ] 10 | description = "PyTorch/XLA CUDA Plugin" 11 | requires-python = ">=3.8" 12 | dynamic = ["version"] 13 | 14 | [tool.setuptools.package-data] 15 | torch_xla_cuda_plugin = ["lib/*.so"] 16 | 17 | [project.entry-points."torch_xla.plugins"] 18 | cuda = "torch_xla_cuda_plugin:CudaPlugin" 19 | -------------------------------------------------------------------------------- /plugins/cuda/setup.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import sys 4 | 5 | # add `build_util` to import path 6 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) 7 | 8 | import build_util 9 | import setuptools 10 | 11 | build_util.bazel_build('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so', 12 | 'torch_xla_cuda_plugin/lib', ['--config=cuda']) 13 | 14 | setuptools.setup( 15 | # TODO: Use a common version file 16 | version=os.getenv('TORCH_XLA_VERSION', 17 | f'2.8.0.dev{datetime.date.today().strftime("%Y%m%d")}')) 18 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | filelock 2 | fsspec 3 | jinja2 4 | markupsafe 5 | mpmath 6 | networkx 7 | pyyaml 8 | sympy 9 | typing-extensions 10 | 
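The CPU and CUDA PJRT plugin packages above are normally discovered through the `torch_xla.plugins` entry point declared in their `pyproject.toml` files. As a rough sketch of the manual alternative (mirroring how `test/pjrt/test_dynamic_plugin_tpu.py` registers the TPU plugin later in this listing; the `'CPU'` device-type string here is an assumption, not something the plugin package prescribes):

```python
# Hedged sketch: register the test CPU plugin by hand instead of relying on
# the packaging entry point. The 'CPU' device-type string is illustrative.
import torch_xla.runtime as xr
from torch_xla.experimental import plugins
from torch_xla_cpu_plugin import CpuPlugin

plugins.register_plugin('CPU', CpuPlugin())
plugins.use_dynamic_plugins()
xr.set_device_type('CPU')
```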
-------------------------------------------------------------------------------- /scripts/dump_stacks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # The following command is needed (as root) in order to enable GDB to attach 4 | # existing user processes: 5 | # 6 | # echo 0 > /proc/sys/kernel/yama/ptrace_scope 7 | # 8 | 9 | import argparse 10 | import stack_trace_parse as stp 11 | import subprocess 12 | 13 | 14 | def get_stacks(pid): 15 | return subprocess.check_output([ 16 | 'gdb', '-p', 17 | str(pid), '-batch', '-ex', 'thread apply all bt', '-ex', 'quit' 18 | ]).decode('utf-8') 19 | 20 | 21 | def dump_stacks(args): 22 | stacks = get_stacks(args.pid) 23 | stp.process_stack_lines(stacks.splitlines(True), args) 24 | 25 | 26 | if __name__ == '__main__': 27 | arg_parser = argparse.ArgumentParser() 28 | arg_parser.add_argument( 29 | 'pid', 30 | type=int, 31 | metavar='PID', 32 | help='The process ID whose stacks need to be dumped') 33 | args, files = arg_parser.parse_known_args() 34 | dump_stacks(args) 35 | -------------------------------------------------------------------------------- /scripts/metrics_to_tensorboard.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Add metrics images to tensorboard summary for easy viewing/comparisons 3 | 4 | import argparse 5 | import os 6 | 7 | from PIL import Image 8 | from pathlib import Path 9 | 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torchvision.transforms import ToTensor 12 | 13 | 14 | def generate_tensorboard_img_summary(logdir, imgdir): 15 | writer = SummaryWriter(logdir) 16 | all_metrics_graphs = list(Path(imgdir).rglob("*.png")) 17 | 18 | for img in all_metrics_graphs: 19 | tag = os.path.basename(img) 20 | img_tensor = ToTensor()(Image.open(img)) 21 | writer.add_image(tag, img_tensor, 0) 22 | writer.close() 23 | 24 | 25 | if __name__ == '__main__': 26 | arg_parser = argparse.ArgumentParser() 27 | arg_parser.add_argument('--logdir', type=str) 28 | arg_parser.add_argument('--imgdir', type=str) 29 | args = arg_parser.parse_args() 30 | generate_tensorboard_img_summary(args.logdir, args.imgdir) 31 | -------------------------------------------------------------------------------- /scripts/normalize_graph_text.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import os 5 | import re 6 | import sys 7 | 8 | 9 | def normalize(args): 10 | fd = sys.stdin 11 | if args.input: 12 | fd = open(args.input) 13 | # %397 = f32[128]{0} xla::cross_replica_sum(%396), scale=0.125, groups=() 14 | for line in fd: 15 | line.rstrip('\n') 16 | m = re.match(r'(\s*)%\d+\s*=\s*(.*::[^(]+\()[^)]*(.*)', line) 17 | if m: 18 | line = m.group(1) + m.group(2) + m.group(3) 19 | print(line) 20 | 21 | 22 | if __name__ == '__main__': 23 | arg_parser = argparse.ArgumentParser() 24 | arg_parser.add_argument('--input', type=str) 25 | args = arg_parser.parse_args() 26 | normalize(args) 27 | -------------------------------------------------------------------------------- /scripts/run_bazel_coverage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0" 3 | export XRT_WORKERS="localservice:0;grpc://localhost:40934" 4 | export XLA_EXPERIMENTAL="nonzero:masked_select" 5 | 6 | 
BAZEL_REMOTE_CACHE_CONFIG="--config=remote_cache --remote_default_exec_properties=cache-silo-key=cache-silo-coverage" 7 | if [ ! -z "$GCLOUD_SERVICE_KEY_FILE" ]; then 8 | file_size=$(stat -c%s "$GCLOUD_SERVICE_KEY_FILE") 9 | if [ "$file_size" -le 1 ]; then 10 | BAZEL_REMOTE_CACHE_CONFIG="" 11 | fi 12 | fi 13 | 14 | bazel coverage $BAZEL_REMOTE_CACHE_CONFIG //... 15 | cp "$(bazel info output_path)/_coverage/_coverage_report.dat" /tmp/cov_xrt.dat 16 | 17 | export PJRT_DEVICE="CPU" 18 | bazel coverage $BAZEL_REMOTE_CACHE_CONFIG //test/... 19 | cp "$(bazel info output_path)/_coverage/_coverage_report.dat" /tmp/cov_pjrt.dat 20 | 21 | # requires `apt-get install lcov` 22 | lcov --add-tracefile /tmp/cov_xrt.dat -a /tmp/cov_pjrt.dat -o /tmp/merged.dat 23 | genhtml /tmp/merged.dat -o CodeCoverage -------------------------------------------------------------------------------- /scripts/tf_log_filter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import os 5 | import re 6 | import sys 7 | 8 | 9 | def normalize(args): 10 | fd = sys.stdin 11 | if args.input: 12 | fd = open(args.input) 13 | # 2019-04-06 02:51:26.397580: I torch_xla/csrc/forward.cpp:168] 14 | for line in fd: 15 | line.rstrip('\n') 16 | m = re.match(r'.*:\d+\] (.*)', line) 17 | if m: 18 | print(m.group(1)) 19 | else: 20 | print(line) 21 | 22 | 23 | if __name__ == '__main__': 24 | arg_parser = argparse.ArgumentParser() 25 | arg_parser.add_argument('--input', type=str) 26 | args = arg_parser.parse_args() 27 | normalize(args) 28 | -------------------------------------------------------------------------------- /scripts/update_nightly_torch_wheels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | # Activate torch-xla-nightly conda env if not already in it 6 | if [ "$CONDA_DEFAULT_ENV" != "torch-xla-nightly" ]; then 7 | conda activate torch-xla-nightly 8 | fi 9 | 10 | $(dirname ${BASH_SOURCE[0]})/update_torch_wheels.sh 11 | -------------------------------------------------------------------------------- /scripts/update_torch_wheels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -x 4 | 5 | DIST_BUCKET="gs://tpu-pytorch/wheels" 6 | TORCH_WHEEL="torch-nightly-cp36-cp36m-linux_x86_64.whl" 7 | TORCH_XLA_WHEEL="torch_xla-nightly-cp36-cp36m-linux_x86_64.whl" 8 | TORCHVISION_WHEEL="torchvision-nightly-cp36-cp36m-linux_x86_64.whl" 9 | 10 | [[ ! 
-z "$1" ]] && conda activate $1 11 | 12 | function update_wheels() { 13 | gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" /tmp/ 14 | gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" /tmp/ 15 | gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" /tmp/ 16 | pip install "/tmp/$TORCH_WHEEL" 17 | pip install "/tmp/$TORCH_XLA_WHEEL" 18 | pip install "/tmp/$TORCHVISION_WHEEL" 19 | 20 | rm -f "/tmp/$TORCH_WHEEL" "/tmp/$TORCH_XLA_WHEEL" "/tmp/$TORCHVISION_WHEEL" 21 | } 22 | 23 | update_wheels 24 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/test/__init__.py -------------------------------------------------------------------------------- /test/benchmarks/.gitignore: -------------------------------------------------------------------------------- 1 | *.tmp 2 | -------------------------------------------------------------------------------- /test/benchmarks/Makefile: -------------------------------------------------------------------------------- 1 | TEST_ARGS = $(shell echo $@ | perl -pe 's/([^.]*)\.([^.]*)\.([^.]*).*\.test/--accelerator=$$1 --test=$$2 --report=$$3/') 2 | EMBEDDED_TEST_ARGS = $(shell cat $@ | grep '^# ARGS: ' | perl -pe 's/^# ARGS: (.*)/$$1/') 3 | 4 | TESTS := $(wildcard *.test) 5 | all: $(TESTS) 6 | .PHONY: $(TESTS) all 7 | 8 | ifndef V 9 | QUIET_AGGREGATE = @echo ' ' AGGREGATE $(TEST_ARGS) $(EMBEDDED_TEST_ARGS); 10 | QUIET_DIFF = @echo ' ' DIFF $@; 11 | QUIET_RM = @echo ' ' RM $@.tmp; 12 | endif 13 | 14 | $(TESTS): 15 | $(QUIET_AGGREGATE)python3 ../../benchmarks/aggregate.py \ 16 | --format=csv \ 17 | $(TEST_ARGS) $(EMBEDDED_TEST_ARGS) \ 18 | $(wildcard *.jsonl) > $@.tmp 19 | $(QUIET_DIFF)git diff -I'^# ARGS: ' --no-index $@ $@.tmp 20 | $(QUIET_RM)$(RM) $@.tmp 21 | 22 | clean: 23 | $(RM) *.tmp 24 | .PHONY: clean 25 | -------------------------------------------------------------------------------- /test/benchmarks/a6000.inference.speedup.test: -------------------------------------------------------------------------------- 1 | # Datetime(UTC),Speedup(Inductor/Oldest Inductor),StdDev,Speedup(XLA+Dynamo/Oldest Inductor),StdDev 2 | 2023-11-11 04:43:56.070348,1.0,0.0,, 3 | -------------------------------------------------------------------------------- /test/benchmarks/a6000.training.latest.empty.test: -------------------------------------------------------------------------------- 1 | # ARGS: --backends openxla+lazytensor -- 2 | -------------------------------------------------------------------------------- /test/benchmarks/a6000.training.latest.test: -------------------------------------------------------------------------------- 1 | # Workload,Speedup(Inductor/Oldest Inductor),StdDev,ModelName(Inductor),Speedup(XLA+Dynamo/Oldest Inductor),StdDev,ModelName(XLA+Dynamo) 2 | 0,1.0,0.0,BERT_pytorch,2.84589073,0.0,BERT_pytorch 3 | 1,1.0,0.0,Background_Matting,,, 4 | -------------------------------------------------------------------------------- /test/benchmarks/test_benchmark_model.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from benchmark_model import BenchmarkModel 4 | 5 | 6 | class BenchmarkModelTest(unittest.TestCase): 7 | 8 | def test_to_dict(self): 9 | bm = BenchmarkModel("torchbench or other", "super_deep_model", 10 | "placeholder") 11 | actual = bm.to_dict() 12 | self.assertEqual(2, len(actual)) 13 | 
self.assertEqual("torchbench or other", actual["suite_name"]) 14 | self.assertEqual("super_deep_model", actual["model_name"]) 15 | 16 | 17 | if __name__ == '__main__': 18 | unittest.main() 19 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.histogram.lazytensor.test: -------------------------------------------------------------------------------- 1 | # ARGS: --backends inductor openxla+lazytensor -- 2 | # Datetime(UTC),Inductor p95,Inductor p50,Inductor p5,XLA+LazyTensor p95,XLA+LazyTensor p50,XLA+LazyTensor p5 3 | 2023-11-11 05:32:18.723407,1.0,1.0,1.0,,, 4 | 2023-11-12 05:32:18,1.50833479,1.40761418,1.30689358,0.41071322,0.41071322,0.41071322 5 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.histogram.tab.test: -------------------------------------------------------------------------------- 1 | # ARGS: --format=tab 2 | ╒════════════════════════════╤════════════╤════════════╤════════════╤══════════════╤══════════════╤══════════════╕ 3 | │ Datetime(UTC) │ Inductor │ Inductor │ Inductor │ XLA+Dynamo │ XLA+Dynamo │ XLA+Dynamo │ 4 | │ │ p95 │ p50 │ p5 │ p95 │ p50 │ p5 │ 5 | ╞════════════════════════════╪════════════╪════════════╪════════════╪══════════════╪══════════════╪══════════════╡ 6 | │ 2023-11-11 05:32:18.723407 │ 1.00 │ 1.00 │ 1.00 │ 0.98 │ 0.86 │ 0.74 │ 7 | ├────────────────────────────┼────────────┼────────────┼────────────┼──────────────┼──────────────┼──────────────┤ 8 | │ 2023-11-12 05:32:18 │ 1.51 │ 1.41 │ 1.31 │ 1.53 │ 1.17 │ 0.81 │ 9 | ╘════════════════════════════╧════════════╧════════════╧════════════╧══════════════╧══════════════╧══════════════╛ 10 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.histogram.test: -------------------------------------------------------------------------------- 1 | # Datetime(UTC),Inductor p95,Inductor p50,Inductor p5,XLA+Dynamo p95,XLA+Dynamo p50,XLA+Dynamo p5 2 | 2023-11-11 05:32:18.723407,1.0,1.0,1.0,0.97631327,0.85586259,0.7354119 3 | 2023-11-12 05:32:18,1.50833479,1.40761418,1.30689358,1.52901152,1.17088985,0.81276817 4 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.latest.openxla_baseline.test: -------------------------------------------------------------------------------- 1 | # ARGS: --backends openxla+dynamo inductor --baseline=latest --filter-by-tier=1 2 | # Workload,Speedup(XLA+Dynamo/Latest XLA+Dynamo),StdDev,ModelName(XLA+Dynamo),Speedup(Inductor/Latest XLA+Dynamo),StdDev,ModelName(Inductor) 3 | 0,1.0,0.0,BERT_pytorch,0.96858952,0.0,BERT_pytorch 4 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.latest.test: -------------------------------------------------------------------------------- 1 | # ARGS: --backends inductor openxla+dynamo openxla+lazytensor -- 2 | # Workload,Speedup(Inductor/Oldest Inductor),StdDev,ModelName(Inductor),Speedup(XLA+Dynamo/Oldest Inductor),StdDev,ModelName(XLA+Dynamo),Speedup(XLA+LazyTensor/Oldest Inductor),StdDev,ModelName(XLA+LazyTensor) 3 | 0,1.2957024,0.0,Background_Matting,0.77297688,0.0,Background_Matting,0.41071322,0.0,Background_Matting 4 | 1,1.51952596,0.06679279,BERT_pytorch,1.56880282,0.06895882,BERT_pytorch,,, 5 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.latest.tier1.test: 
-------------------------------------------------------------------------------- 1 | # ARGS: --filter-by-tier=1 2 | # Workload,Speedup(Inductor/Oldest Inductor),StdDev,ModelName(Inductor),Speedup(XLA+Dynamo/Oldest Inductor),StdDev,ModelName(XLA+Dynamo) 3 | 0,1.51952596,0.06679279,BERT_pytorch,1.56880282,0.06895882,BERT_pytorch 4 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.latest_grouped.test: -------------------------------------------------------------------------------- 1 | # ModelName,Speedup(Inductor/Oldest Inductor),StdDev,Speedup(XLA+Dynamo/Oldest Inductor),StdDev 2 | Background_Matting,1.2957024,0.0,0.77297688,0.0 3 | BERT_pytorch,1.51952596,0.06679279,1.56880282,0.06895882 4 | GEOMEAN,1.40315838,0.03083885,1.10120312,0.02420242 5 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.speedup.baseline_latest.test: -------------------------------------------------------------------------------- 1 | # ARGS: --baseline=latest 2 | # Datetime(UTC),Speedup(Inductor/Latest Inductor),StdDev,Speedup(XLA+Dynamo/Latest Inductor),StdDev 3 | 2023-11-11 05:32:18.723407,0.71267792,0.01566335,0.60245072,0.0 4 | 2023-11-12 05:32:18,1.0,0.0,0.78480315,0.0 5 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.speedup.lazytensor.test: -------------------------------------------------------------------------------- 1 | # ARGS: --backends inductor openxla+lazytensor -- 2 | # Datetime(UTC),Speedup(Inductor/Oldest Inductor),StdDev,Speedup(XLA+LazyTensor/Oldest Inductor),StdDev 3 | 2023-11-11 05:32:18.723407,1.0,0.03108182,, 4 | 2023-11-12 05:32:18,1.40315838,0.03083885,0.41071322,0.0 5 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.speedup.tab.test: -------------------------------------------------------------------------------- 1 | # ARGS: --format=tab 2 | ╒════════════════════════════╤════════════╤══════════╤══════════════╤══════════╕ 3 | │ Datetime(UTC) │ Speedup │ StdDev │ Speedup │ StdDev │ 4 | │ │ Inductor │ │ XLA+Dynamo │ │ 5 | │ │ over │ │ over │ │ 6 | │ │ Oldest │ │ Oldest │ │ 7 | │ │ Inductor │ │ Inductor │ │ 8 | ╞════════════════════════════╪════════════╪══════════╪══════════════╪══════════╡ 9 | │ 2023-11-11 05:32:18.723407 │ 1.00 │ 0.03 │ 0.85 │ 0.02 │ 10 | ├────────────────────────────┼────────────┼──────────┼──────────────┼──────────┤ 11 | │ 2023-11-12 05:32:18 │ 1.40 │ 0.03 │ 1.10 │ 0.02 │ 12 | ╘════════════════════════════╧════════════╧══════════╧══════════════╧══════════╛ 13 | -------------------------------------------------------------------------------- /test/benchmarks/v100.inference.speedup.test: -------------------------------------------------------------------------------- 1 | # Datetime(UTC),Speedup(Inductor/Oldest Inductor),StdDev,Speedup(XLA+Dynamo/Oldest Inductor),StdDev 2 | 2023-11-11 05:32:18.723407,1.0,0.03108182,0.84533378,0.01857889 3 | 2023-11-12 05:32:18,1.40315838,0.03083885,1.10120312,0.02420242 4 | -------------------------------------------------------------------------------- /test/cpp/get_coverage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "$PATH" 3 | exec llvm-cov gcov "$@" 4 | -------------------------------------------------------------------------------- /test/cpp/main.cpp: 
-------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | int main(int argc, char* argv[]) { 7 | ::testing::InitGoogleTest(&argc, argv); 8 | return RUN_ALL_TESTS(); 9 | } 10 | -------------------------------------------------------------------------------- /test/dynamo/test_dynamo_config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_xla 3 | import unittest 4 | from torch_xla._dynamo import config 5 | 6 | 7 | class DynamoconfigTest(unittest.TestCase): 8 | 9 | def dummy_test(self, a): 10 | return a.cos().sin() 11 | 12 | def test_config_skip_input_data_check(self): 13 | device = torch_xla.device() 14 | print(config.skip_input_data_check) 15 | config.skip_input_data_check = True 16 | compiled_dummy = torch.compile(self.dummy_test, backend="openxla") 17 | t1 = torch.randn(3, 4, device=device) 18 | compiled_dummy(t1) 19 | t2 = torch.randn(3, 4, device=device) 20 | t2 += 5 21 | with self.assertRaisesRegex( 22 | RuntimeError, r'input data to dynamo graph can not be a pending ir'): 23 | compiled_dummy(t2) 24 | 25 | 26 | if __name__ == '__main__': 27 | test = unittest.main() 28 | -------------------------------------------------------------------------------- /test/pjrt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/test/pjrt/__init__.py -------------------------------------------------------------------------------- /test/pjrt/args_parse.py: -------------------------------------------------------------------------------- 1 | ../args_parse.py -------------------------------------------------------------------------------- /test/pjrt/test_dynamic_plugin_tpu.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | 3 | from absl.testing import absltest 4 | import torch_xla 5 | import torch_xla.core.xla_model as xm 6 | from torch_xla.experimental import plugins 7 | import torch_xla.runtime as xr 8 | from torch_xla._internal import tpu 9 | 10 | plugins.register_plugin('TPU', tpu.TpuPlugin()) 11 | plugins.use_dynamic_plugins() 12 | 13 | 14 | class TestDynamicTpuPlugin(absltest.TestCase): 15 | 16 | @classmethod 17 | def setUpClass(cls): 18 | xr.set_device_type('TPU') 19 | 20 | @staticmethod 21 | def _assert_tpus_exist(index=0): 22 | del index 23 | assert len(xm.get_xla_supported_devices('TPU')) > 0 24 | 25 | def test_single_process(self): 26 | with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: 27 | executor.submit(self._assert_tpus_exist).result() 28 | 29 | def test_spawn(self): 30 | torch_xla.launch(self._assert_tpus_exist) 31 | 32 | 33 | if __name__ == '__main__': 34 | absltest.main() 35 | -------------------------------------------------------------------------------- /test/spmd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/test/spmd/__init__.py -------------------------------------------------------------------------------- /test/stablehlo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/test/stablehlo/__init__.py 
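The dynamic-plugin test above registers the bundled TPU plugin before spawning workers. For reference, here is a minimal sketch of how a custom PJRT device plugin could be registered the same way; the MyPlugin class and the MY_PJRT_LIBRARY_PATH variable are hypothetical, while DevicePlugin, register_plugin, use_dynamic_plugins and set_device_type are the same APIs used in test_dynamic_plugin_tpu.py above and in torch_xla/_internal/xpu.py further below:

import os

import torch_xla.runtime as xr
from torch_xla.experimental import plugins


class MyPlugin(plugins.DevicePlugin):
  """Hypothetical plugin: it only needs to report where its PJRT library lives."""

  def library_path(self):
    # The environment variable name is an assumption made for this sketch.
    return os.environ.get('MY_PJRT_LIBRARY_PATH', 'libmy_pjrt.so')


# Same registration sequence as in test_dynamic_plugin_tpu.py.
plugins.register_plugin('MY_DEVICE', MyPlugin())
plugins.use_dynamic_plugins()
xr.set_device_type('MY_DEVICE')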
-------------------------------------------------------------------------------- /test/test_mp_mesh_reduce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_xla 3 | from torch_xla import runtime as xr 4 | import torch_xla.core.xla_model as xm 5 | 6 | 7 | def _test_scalar(): 8 | 9 | def reduce_add(vlist): 10 | return sum(vlist) 11 | 12 | svalue = 1.25 13 | rvalue = xm.mesh_reduce('test_mp_mesh_reduce._test_scalar', svalue, 14 | reduce_add) 15 | assert rvalue == svalue * xr.world_size() 16 | 17 | 18 | def _test_tensor(): 19 | 20 | def reduce_add(vlist): 21 | return torch.stack(vlist).sum(dim=0) 22 | 23 | tvalue = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32) 24 | rvalue = xm.mesh_reduce('test_mp_mesh_reduce._test_tensor', tvalue, 25 | reduce_add) 26 | assert rvalue.allclose(tvalue * xr.world_size()) 27 | 28 | 29 | def _mp_fn(index): 30 | _test_scalar() 31 | _test_tensor() 32 | 33 | 34 | if __name__ == '__main__': 35 | torch_xla.launch(_mp_fn, args=()) 36 | -------------------------------------------------------------------------------- /test/test_torch_distributed_fsdp_frozen_weight.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | import torch_xla 5 | import torch_xla.core.xla_model as xm 6 | from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP 7 | 8 | 9 | def _mp_fn(index): 10 | dev = torch_xla.device() 11 | if xm.xla_device_hw(dev) not in ('TPU', 'CUDA'): 12 | print( 13 | 'Default device {} is not a TPU or CUDA device'.format(dev), 14 | file=sys.stderr) 15 | return 16 | 17 | model = nn.Linear(1024, 1024) 18 | model.weight.requires_grad = False # the weight param is frozen 19 | 20 | model = FSDP(model) # wrapping the linear module with FSDP 21 | 22 | input = torch.rand((2, 1024), device='xla') 23 | 24 | output = model(input) 25 | loss = torch.sum(output) 26 | loss.backward() 27 | assert not any(p._has_full_param for p in model.full_params), \ 28 | 'Expecting all the full params to be freed at this moment.' 29 | 30 | 31 | if __name__ == "__main__": 32 | torch_xla.launch(_mp_fn, args=()) 33 | -------------------------------------------------------------------------------- /test/tpu/run_expensive_test_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -xue 3 | CDIR="$(cd "$(dirname "$0")" ; pwd -P)" 4 | TEST_CDIR="$(dirname "$CDIR")" 5 | 6 | # This test takes ~1350 seconds to run 7 | python3 "$TEST_CDIR/test_multi_queries_paged_attention_kernel.py" 8 | -------------------------------------------------------------------------------- /test/tpu/run_expensive_test_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -xue 3 | CDIR="$(cd "$(dirname "$0")" ; pwd -P)" 4 | TEST_CDIR="$(dirname "$CDIR")" 5 | 6 | # This test takes ~1000 seconds to run 7 | python3 "$TEST_CDIR/test_ragged_paged_attention_kernel.py" 8 | -------------------------------------------------------------------------------- /test/tpu/run_pallas_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -xue 3 | 4 | # Absolute path to the directory of this script. 5 | _TPU_DIR="$( 6 | cd "$(dirname "$0")" 7 | pwd -P 8 | )" 9 | 10 | # Absolute path to the test/ directory. 
11 | _TEST_DIR="$(dirname "$_TPU_DIR")" 12 | 13 | # This test takes ~370 seconds to run 14 | python3 "$_TEST_DIR/test_pallas.py" -v 15 | 16 | # This test takes ~15 seconds to run 17 | python3 "$_TEST_DIR/test_pallas_spmd.py" 18 | 19 | # This test takes ~15 seconds to run 20 | XLA_DISABLE_FUNCTIONALIZATION=1 python3 "$_TEST_DIR/test_pallas_spmd.py" 21 | -------------------------------------------------------------------------------- /test/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/test/utils/__init__.py -------------------------------------------------------------------------------- /torch_xla/_dynamo/__init__.py: -------------------------------------------------------------------------------- 1 | import torch_xla._dynamo.config as config 2 | -------------------------------------------------------------------------------- /torch_xla/_dynamo/config.py: -------------------------------------------------------------------------------- 1 | import torch_xla 2 | 3 | # Whether to skip checking input is a device data or not in the optim_mod. 4 | # Enabling it will reduce the overhead a bit but will throw a runtime error 5 | # if input is a pending IR. 6 | skip_input_data_check = False -------------------------------------------------------------------------------- /torch_xla/_internal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torch_xla/_internal/__init__.py -------------------------------------------------------------------------------- /torch_xla/_internal/decomp_registration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.library.register_kernel("aten::upsample_trilinear3d", "xla", 4 | torch._decomp.decompositions.upsample_trilinear3d) 5 | -------------------------------------------------------------------------------- /torch_xla/_internal/gpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch_xla.core.xla_env_vars as xenv 3 | 4 | 5 | def num_local_processes() -> int: 6 | """Returns number of processes to create on this host. 
7 | 8 | Raises: 9 | AssertionError: if GPU_NUM_DEVICES environment variable 10 | is not configured 11 | """ 12 | assert xenv.GPU_NUM_DEVICES in os.environ, \ 13 | "Must set `GPU_NUM_DEVICES` environment variable to use the PjRt GPU client" 14 | os.environ[xenv.LOCAL_WORLD_SIZE] = os.environ[xenv.GPU_NUM_DEVICES] 15 | return int(os.environ[xenv.LOCAL_WORLD_SIZE]) 16 | -------------------------------------------------------------------------------- /torch_xla/_internal/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def parse_xla_device(device: str): 5 | m = re.match(r'([A-Z]+):(\d+)$', device) 6 | if m: 7 | return (m.group(1), int(m.group(2))) 8 | -------------------------------------------------------------------------------- /torch_xla/_internal/xpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch_xla.experimental import plugins 4 | 5 | 6 | class XpuPlugin(plugins.DevicePlugin): 7 | 8 | def library_path(self): 9 | return os.environ.get('XPU_LIBRARY_PATH', 'libxpu.so') 10 | -------------------------------------------------------------------------------- /torch_xla/amp/__init__.py: -------------------------------------------------------------------------------- 1 | from .autocast_mode import autocast # noqa: F401 2 | from .grad_scaler import GradScaler # noqa: F401 3 | -------------------------------------------------------------------------------- /torch_xla/amp/syncfree/__init__.py: -------------------------------------------------------------------------------- 1 | from .adam import Adam 2 | from .adamw import AdamW 3 | from .sgd import SGD 4 | -------------------------------------------------------------------------------- /torch_xla/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torch_xla/core/__init__.py -------------------------------------------------------------------------------- /torch_xla/core/dynamo_bridge.py: -------------------------------------------------------------------------------- 1 | # TODO(JackCaoG): remove after updated upstream 2 | from torch_xla._dynamo.dynamo_bridge import * -------------------------------------------------------------------------------- /torch_xla/csrc/aten_cuda_functions.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_ATEN_CUDA_FUNCTIONS_H_ 2 | #define XLA_TORCH_XLA_CSRC_ATEN_CUDA_FUNCTIONS_H_ 3 | 4 | #include 5 | 6 | // Forward declaration of PyTorch CUDA functions. 7 | // Source: c10/cuda/CUDAFunctions.h 8 | // 9 | // These are needed in order to synchronize the CUDA device after running 10 | // the operation in PyTorch eager mode. 11 | // 12 | // It would be better to include the actual header. However, if we build 13 | // PyTorch/XLA in an environment where PyTorch wasn't compiled with CUDA 14 | // (i.e. our CI), the build would fail. 15 | 16 | namespace c10 { 17 | 18 | // Type alias used inside PyTorch. 
19 | using DeviceIndex = int8_t; 20 | 21 | namespace cuda { 22 | 23 | DeviceIndex device_count() noexcept; 24 | 25 | c10::DeviceIndex current_device(); 26 | 27 | void set_device(c10::DeviceIndex); 28 | 29 | void device_synchronize(); 30 | 31 | } // namespace cuda 32 | } // namespace c10 33 | 34 | #endif // XLA_TORCH_XLA_CSRC_ATEN_CUDA_FUNCTIONS_H_ 35 | -------------------------------------------------------------------------------- /torch_xla/csrc/aten_fallback.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_ATEN_CPU_FALLBACK_H_ 2 | #define XLA_TORCH_XLA_CSRC_ATEN_CPU_FALLBACK_H_ 3 | 4 | #include 5 | 6 | namespace torch_xla { 7 | 8 | void xla_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack); 9 | 10 | std::vector GetFallbackOperations(); 11 | 12 | } // namespace torch_xla 13 | 14 | #endif // XLA_TORCH_XLA_CSRC_ATEN_CPU_FALLBACK_H_ 15 | -------------------------------------------------------------------------------- /torch_xla/csrc/dl_convertor.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_ 2 | #define XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_ 3 | 4 | #include 5 | #include 6 | 7 | namespace torch_xla { 8 | 9 | DLManagedTensor* toDLPack(const at::Tensor& src); 10 | at::Tensor fromDLPack(DLManagedTensor* src); 11 | 12 | } // namespace torch_xla 13 | 14 | #endif // XLA_TORCH_XLA_CSRC_DL_CONVERTOR_H_ 15 | -------------------------------------------------------------------------------- /torch_xla/csrc/dtype.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_DTYPE_H_ 2 | #define XLA_TORCH_XLA_CSRC_DTYPE_H_ 3 | 4 | #include "torch_xla/csrc/device.h" 5 | #include "xla/shape.h" 6 | 7 | namespace torch_xla { 8 | 9 | at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type); 10 | 11 | xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type); 12 | 13 | // Downcast type to be compatible with device if necessary. 14 | xla::PrimitiveType MaybeDowncastToXlaDeviceType( 15 | xla::PrimitiveType type, const torch::lazy::BackendDevice& device); 16 | 17 | xla::PrimitiveType MaybeDowncastToXlaDeviceType( 18 | at::ScalarType scalar_type, const torch::lazy::BackendDevice& device); 19 | 20 | // Upcast type to original PyTorch type. 
21 | at::ScalarType MaybeUpcastToHostTorchType(xla::PrimitiveType xla_type); 22 | 23 | } // namespace torch_xla 24 | 25 | #endif // XLA_TORCH_XLA_CSRC_DTYPE_H_ 26 | -------------------------------------------------------------------------------- /torch_xla/csrc/function_call_tracker.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_FUNCTION_CALL_TRACKER_H_ 2 | #define XLA_TORCH_XLA_CSRC_FUNCTION_CALL_TRACKER_H_ 3 | 4 | namespace torch_xla { 5 | namespace fn_tracker { 6 | 7 | #define XLA_FN_TRACK(level) \ 8 | torch_xla::fn_tracker::TrackFunction(__FUNCTION__, level) 9 | 10 | void TrackFunction(const char* tag, int level); 11 | 12 | } // namespace fn_tracker 13 | } // namespace torch_xla 14 | 15 | #endif // XLA_TORCH_XLA_CSRC_FUNCTION_CALL_TRACKER_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/generated_file_include.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_GENERATED_FILE_INCLUDE_H_ 2 | #define XLA_TORCH_XLA_CSRC_GENERATED_FILE_INCLUDE_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/aten_fallback.h" 7 | #include "torch_xla/csrc/aten_xla_bridge.h" 8 | #include "torch_xla/csrc/ir.h" 9 | #include "torch_xla/csrc/ops/ops_xla_shape_fn.h" 10 | #include "torch_xla/csrc/runtime/debug_macros.h" 11 | #include "torch_xla/csrc/runtime/metrics.h" 12 | 13 | #endif // XLA_TORCH_XLA_CSRC_GENERATED_FILE_INCLUDE_H_ 14 | -------------------------------------------------------------------------------- /torch_xla/csrc/matrix.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_MATRIX_H_ 2 | #define XLA_TORCH_XLA_CSRC_MATRIX_H_ 3 | 4 | #include "xla/hlo/builder/xla_builder.h" 5 | 6 | namespace torch_xla { 7 | 8 | xla::XlaOp BuildTriu(xla::XlaOp input, int64_t diagonal); 9 | 10 | xla::XlaOp BuildTril(xla::XlaOp input, int64_t diagonal); 11 | 12 | xla::XlaOp BuildDiagonal(xla::XlaOp input, int64_t offset, int64_t dim1, 13 | int64_t dim2); 14 | 15 | xla::XlaOp BuildDiagonalViewUpdate(xla::XlaOp target, xla::XlaOp input, 16 | int64_t offset, int64_t dim1, int64_t dim2); 17 | 18 | xla::XlaOp BuildInverse(xla::XlaOp input); 19 | 20 | } // namespace torch_xla 21 | 22 | #endif // XLA_TORCH_XLA_CSRC_MATRIX_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/nll_loss.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_NLL_LOSS_H_ 2 | #define XLA_TORCH_XLA_CSRC_NLL_LOSS_H_ 3 | 4 | #include "absl/types/optional.h" 5 | #include "torch_xla/csrc/reduction.h" 6 | #include "xla/hlo/builder/xla_builder.h" 7 | 8 | namespace torch_xla { 9 | 10 | // Builds the NLLLoss for log-probabilities "logits" and class indices "labels". 11 | xla::XlaOp BuildNllLoss(xla::XlaOp logits, xla::XlaOp labels, xla::XlaOp weight, 12 | int ignore_index, ReductionMode reduction_mode); 13 | 14 | // Builds the NLLLoss gradient for log-probabilities "logits" and class indices 15 | // "labels". 
16 | xla::XlaOp BuildNllLossBackward(xla::XlaOp grad_output, xla::XlaOp logits, 17 | xla::XlaOp labels, xla::XlaOp weight, 18 | xla::XlaOp total_weight, int ignore_index, 19 | ReductionMode reduction_mode); 20 | 21 | } // namespace torch_xla 22 | 23 | #endif // XLA_TORCH_XLA_CSRC_NLL_LOSS_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/adaptive_max_pool2d.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_ADAPTIVE_MAX_POOL2D_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_ADAPTIVE_MAX_POOL2D_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class AdaptiveMaxPool2d : public XlaNode { 11 | public: 12 | AdaptiveMaxPool2d(const torch::lazy::Value& input, 13 | std::vector output_size); 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | std::string ToString() const override; 20 | 21 | const std::vector& output_size() const { return output_size_; } 22 | 23 | private: 24 | std::vector output_size_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_ADAPTIVE_MAX_POOL2D_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_AMP_FOREACH_NON_FINITE_CHECK_AND_UNSCALE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_AMP_FOREACH_NON_FINITE_CHECK_AND_UNSCALE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class AmpForachNonFiniteCheckAndUnscale : public XlaNode { 9 | public: 10 | AmpForachNonFiniteCheckAndUnscale(const torch::lazy::OpList& inputs, 11 | const torch::lazy::Value& found_inf, 12 | const torch::lazy::Value& inv_scale); 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | }; 18 | 19 | } // namespace torch_xla 20 | 21 | #endif // XLA_TORCH_XLA_CSRC_OPS_AMP_FOREACH_NON_FINITE_CHECK_AND_UNSCALE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/amp_update_scale.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_AMP_UPDATE_SCALE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_AMP_UPDATE_SCALE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class AmpUpdateScale : public XlaNode { 9 | public: 10 | AmpUpdateScale(const torch::lazy::Value& current_scale, 11 | const torch::lazy::Value& growth_tracker, 12 | const torch::lazy::Value& found_inf, 13 | double scale_growth_factor, double scale_backoff_factor, 14 | int growth_interval); 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | private: 21 | double scale_growth_factor_; 22 | double scale_backoff_factor_; 23 | int growth_interval_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_AMP_UPDATE_SCALE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/arithmetic_ir_ops.h: -------------------------------------------------------------------------------- 1 | #ifndef 
XLA_TORCH_XLA_CSRC_OPS_ARITHMETIC_IR_OPS_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_ARITHMETIC_IR_OPS_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | torch::lazy::NodePtr operator+(const torch::lazy::Value& node1, 9 | const torch::lazy::Value& node2); 10 | torch::lazy::NodePtr operator-(const torch::lazy::Value& node1, 11 | const torch::lazy::Value& node2); 12 | torch::lazy::NodePtr operator*(const torch::lazy::Value& node1, 13 | const torch::lazy::Value& node2); 14 | torch::lazy::NodePtr operator/(const torch::lazy::Value& node1, 15 | const torch::lazy::Value& node2); 16 | 17 | } // namespace torch_xla 18 | 19 | #endif // XLA_TORCH_XLA_CSRC_OPS_ARITHMETIC_IR_OPS_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/bernoulli.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_BERNOULLI_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_BERNOULLI_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Bernoulli : public XlaNode { 9 | public: 10 | Bernoulli(const torch::lazy::Value& probability, 11 | const torch::lazy::Value& seed, xla::Shape shape); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | }; 17 | 18 | } // namespace torch_xla 19 | 20 | #endif // XLA_TORCH_XLA_CSRC_OPS_BERNOULLI_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/cast_int4.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CAST_INT4 2 | #define XLA_TORCH_XLA_CSRC_OPS_CAST_INT4 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class CastInt4 : public XlaNode { 9 | public: 10 | CastInt4(const torch::lazy::Value& weight, 11 | const std::vector& int4_weight_values); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | private: 20 | std::vector int4_vals_; 21 | }; 22 | 23 | } // namespace torch_xla 24 | 25 | #endif // XLA_TORCH_XLA_CSRC_OPS_CAST_INT4 26 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/cat.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CAT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_CAT_H_ 3 | 4 | #include 5 | 6 | #include "absl/types/span.h" 7 | #include "torch_xla/csrc/ir.h" 8 | 9 | namespace torch_xla { 10 | 11 | class Cat : public XlaNode { 12 | public: 13 | Cat(c10::ArrayRef values, int64_t dim, 14 | at::ScalarType dtype); 15 | 16 | std::string ToString() const override; 17 | 18 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 19 | 20 | XlaOpVector Lower(LoweringContext* loctx) const override; 21 | 22 | int64_t dim() const { return dim_; }; 23 | 24 | at::ScalarType dtype() const { return dtype_; }; 25 | 26 | private: 27 | int64_t dim_; 28 | at::ScalarType dtype_; 29 | }; 30 | 31 | } // namespace torch_xla 32 | 33 | #endif // XLA_TORCH_XLA_CSRC_OPS_CAT_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/cdist.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CDIST_H_ 2 | 
#define XLA_TORCH_XLA_CSRC_OPS_CDIST_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class CdistForward : public XlaNode { 9 | public: 10 | CdistForward(const torch::lazy::Value& x1, const torch::lazy::Value& x2, 11 | const torch::lazy::Value& p, bool use_hamming, 12 | bool use_chebyshev); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | private: 21 | bool use_hamming_; // handle p == 0 22 | bool use_chebyshev_; // handle p == +inf 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_CDIST_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/collective_permute.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_COLLECTIVE_PERMUTE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_COLLECTIVE_PERMUTE_H_ 3 | 4 | #include "torch_xla/csrc/cross_replica_reduces.h" 5 | #include "torch_xla/csrc/ir.h" 6 | 7 | namespace torch_xla { 8 | 9 | class CollectivePermute : public XlaNode { 10 | public: 11 | CollectivePermute( 12 | const torch::lazy::Value& input, const torch::lazy::Value& token, 13 | std::vector> source_target_pairs); 14 | 15 | std::string ToString() const override; 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | const std::vector>& source_target_pairs() const { 22 | return source_target_pairs_; 23 | } 24 | 25 | private: 26 | std::vector> source_target_pairs_; 27 | }; 28 | 29 | } // namespace torch_xla 30 | 31 | #endif // XLA_TORCH_XLA_CSRC_OPS_COLLECTIVE_PERMUTE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/constant.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CONSTANT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_CONSTANT_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Constant : public XlaNode { 9 | public: 10 | Constant(xla::Literal value); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | const xla::Literal& value() const { return value_; } 19 | 20 | private: 21 | xla::Literal value_; 22 | }; 23 | 24 | torch::lazy::hash_t LiteralHash(const xla::Literal& l); 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_CONSTANT_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/constant_pad_nd.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CONSTANT_PAD_ND_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_CONSTANT_PAD_ND_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class ConstantPadNd : public XlaNode { 11 | public: 12 | ConstantPadNd(const torch::lazy::Value& input, std::vector pad, 13 | const at::Scalar& value); 14 | 15 | std::string ToString() const override; 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | const at::Scalar& value() const { 
return value_; } 22 | 23 | const std::vector& pad() const { return pad_; } 24 | 25 | private: 26 | std::vector pad_; 27 | at::Scalar value_; 28 | }; 29 | 30 | } // namespace torch_xla 31 | 32 | #endif // XLA_TORCH_XLA_CSRC_OPS_CONSTANT_PAD_ND_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/count_nonzero.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_COUNT_NONZERO_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_COUNT_NONZERO_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class CountNonzero : public XlaNode { 9 | public: 10 | CountNonzero(const torch::lazy::Value& input, std::vector dims); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | std::optional> dims() const { return dims_; } 19 | 20 | private: 21 | std::vector dims_; 22 | }; 23 | 24 | } // namespace torch_xla 25 | 26 | #endif // XLA_TORCH_XLA_CSRC_OPS_COUNT_NONZERO_H_ 27 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/cummax.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class CumMax : public XlaNode { 11 | public: 12 | CumMax(const torch::lazy::Value& input, int64_t dim); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | int64_t dim() const { return dim_; } 21 | 22 | private: 23 | int64_t dim_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_ 29 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/cumprod.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CUMPROD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_CUMPROD_H_ 3 | 4 | #include 5 | 6 | #include 7 | 8 | #include "torch_xla/csrc/ir.h" 9 | 10 | namespace torch_xla { 11 | 12 | class CumProd : public XlaNode { 13 | public: 14 | CumProd(const torch::lazy::Value& input, int64_t dim, 15 | std::optional dtype); 16 | 17 | std::string ToString() const override; 18 | 19 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 20 | 21 | XlaOpVector Lower(LoweringContext* loctx) const override; 22 | 23 | int64_t dim() const { return dim_; } 24 | 25 | const std::optional& dtype() const { return dtype_; } 26 | 27 | private: 28 | int64_t dim_; 29 | std::optional dtype_; 30 | }; 31 | 32 | } // namespace torch_xla 33 | 34 | #endif // XLA_TORCH_XLA_CSRC_OPS_CUMPROD_H_ 35 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/cumsum.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_CUMSUM_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_CUMSUM_H_ 3 | 4 | #include 5 | 6 | #include 7 | 8 | #include "torch_xla/csrc/ir.h" 9 | 10 | namespace torch_xla { 11 | 12 | class CumSum : public XlaNode { 13 | public: 14 | CumSum(const torch::lazy::Value& input, int64_t dim, 15 | std::optional dtype); 
16 | 17 | std::string ToString() const override; 18 | 19 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 20 | 21 | XlaOpVector Lower(LoweringContext* loctx) const override; 22 | 23 | int64_t dim() const { return dim_; } 24 | 25 | const std::optional& dtype() const { return dtype_; } 26 | 27 | private: 28 | int64_t dim_; 29 | std::optional dtype_; 30 | }; 31 | 32 | } // namespace torch_xla 33 | 34 | #endif // XLA_TORCH_XLA_CSRC_OPS_CUMSUM_H_ 35 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/dequant_tensor.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_DEQUANT_TENSOR_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_DEQUANT_TENSOR_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class DequantizeTensor : public XlaNode { 9 | public: 10 | DequantizeTensor(const torch::lazy::Value& input, 11 | const std::vector& scale, 12 | const std::vector& zero_point, int quant_min, 13 | int quant_max, const std::string& dtype, int axis); 14 | 15 | std::string ToString() const override; 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | private: 22 | int quant_min_; 23 | int quant_max_; 24 | int axis_; 25 | std::string dtype_; 26 | std::vector scale_; 27 | std::vector zero_point_; 28 | }; 29 | 30 | } // namespace torch_xla 31 | 32 | #endif // XLA_TORCH_XLA_CSRC_OPS_QUANT_TENSOR_H_ 33 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/diagonal.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_DIAGONAL_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_DIAGONAL_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Diagonal : public XlaNode { 9 | public: 10 | Diagonal(const torch::lazy::Value& input, int64_t offset, int64_t dim1, 11 | int64_t dim2); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | int64_t offset() const { return offset_; } 20 | 21 | int64_t dim1() const { return dim1_; } 22 | 23 | int64_t dim2() const { return dim2_; } 24 | 25 | static xla::Shape MakeDiagonalShape(const xla::Shape& shape, int64_t offset, 26 | int64_t dim1, int64_t dim2); 27 | 28 | private: 29 | int64_t offset_; 30 | int64_t dim1_; 31 | int64_t dim2_; 32 | }; 33 | 34 | } // namespace torch_xla 35 | 36 | #endif // XLA_TORCH_XLA_CSRC_OPS_DIAGONAL_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/diagonal_view_update.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_DIAGONAL_VIEW_UPDATE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_DIAGONAL_VIEW_UPDATE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class DiagonalViewUpdate : public XlaNode { 9 | public: 10 | DiagonalViewUpdate(const torch::lazy::Value& target, 11 | const torch::lazy::Value& input, int64_t offset, 12 | int64_t dim1, int64_t dim2); 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | std::string ToString() const override; 19 | 20 | int64_t 
offset() const { return offset_; } 21 | 22 | int64_t dim1() const { return dim1_; } 23 | 24 | int64_t dim2() const { return dim2_; } 25 | 26 | private: 27 | int64_t offset_; 28 | int64_t dim1_; 29 | int64_t dim2_; 30 | }; 31 | 32 | } // namespace torch_xla 33 | 34 | #endif // XLA_TORCH_XLA_CSRC_OPS_DIAGONAL_VIEW_UPDATE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/discrete_uniform.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_DISCRETE_UNIFORM_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_DISCRETE_UNIFORM_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class DiscreteUniform : public XlaNode { 9 | public: 10 | DiscreteUniform(const torch::lazy::Value& from, const torch::lazy::Value& to, 11 | const torch::lazy::Value& seed, const xla::Shape& rng_shape); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | }; 17 | 18 | } // namespace torch_xla 19 | 20 | #endif // XLA_TORCH_XLA_CSRC_OPS_DISCRETE_UNIFORM_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/dot_general.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_DOT_GENERAL_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_DOT_GENERAL_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class DotGeneral : public XlaNode { 9 | public: 10 | DotGeneral(const torch::lazy::Value& lhs, const torch::lazy::Value& rhs, 11 | const std::vector>& dim_vectors, 12 | std::optional preferred_element_type); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | private: 21 | std::vector> dim_vectors_; 22 | std::optional preferred_element_type_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_DOT_GENERAL_H_ 28 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/dynamic_expand.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_DYNAMIC_EXPAND_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_DYNAMIC_EXPAND_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class DynamicExpand : public XlaNode { 9 | public: 10 | DynamicExpand(const torch::lazy::Value& input, 11 | const std::vector& size, 12 | const torch::lazy::Value& src_tensor, int src_index, 13 | int target_index); 14 | 15 | std::string ToString() const override; 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | private: 22 | std::vector size_; 23 | int src_index_; 24 | int target_index_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_DYNAMIC_EXPAND_H_ 30 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/dynamic_view.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_DYNAMIC_VIEW_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_DYNAMIC_VIEW_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class DynamicView : public 
XlaNode { 9 | public: 10 | DynamicView(const torch::lazy::Value& input, const std::vector& size, 11 | const torch::lazy::Value& src_tensor, int src_index, 12 | int target_index, float mul_scaler); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | private: 21 | std::vector size_; 22 | int src_index_; 23 | int target_index_; 24 | float mul_scaler_; 25 | xla::Shape complete_output_shape_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_DYNAMIC_VIEW_H_ 31 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/eigh.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_EIGH_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_EIGH_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | #include "xla/types.h" 8 | 9 | namespace torch_xla { 10 | 11 | class Eigh : public XlaNode { 12 | public: 13 | Eigh(const torch::lazy::Value& input, std::string_view uplo); 14 | 15 | std::string ToString() const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | private: 20 | char uplo_; 21 | }; 22 | 23 | } // namespace torch_xla 24 | 25 | #endif // XLA_TORCH_XLA_CSRC_OPS_EIGH_H_ 26 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/einsum.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_EINSUM_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_EINSUM_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Einsum : public XlaNode { 9 | public: 10 | Einsum(const torch::lazy::OpList& operands, const std::string equation); 11 | 12 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 13 | 14 | XlaOpVector Lower(LoweringContext* loctx) const override; 15 | 16 | std::string ToString() const override; 17 | 18 | const std::string& equation() const { return equation_; } 19 | 20 | private: 21 | const std::string equation_; 22 | }; 23 | 24 | } // namespace torch_xla 25 | 26 | #endif // XLA_TORCH_XLA_CSRC_OPS_EINSUM_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/einsum_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_EINSUM_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_EINSUM_BACKWARD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class EinsumBackward : public XlaNode { 9 | public: 10 | EinsumBackward(const torch::lazy::Value& grad_output, 11 | const torch::lazy::OpList& inputs, const std::string equation); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | const std::string& equation() const { return equation_; } 20 | 21 | private: 22 | const std::string equation_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_EINSUM_BACKWARD_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/embedding_bag.h: -------------------------------------------------------------------------------- 1 | #ifndef 
XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class EmbeddingBag : public XlaNode { 11 | public: 12 | EmbeddingBag(const torch::lazy::Value& weight, 13 | const torch::lazy::Value& indices, 14 | const torch::lazy::Value& offsets, int64_t mode, 15 | const torch::lazy::Value& per_sample_weights, 16 | bool include_last_offset); 17 | 18 | std::string ToString() const override; 19 | 20 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 21 | 22 | XlaOpVector Lower(LoweringContext* loctx) const override; 23 | 24 | private: 25 | int64_t mode_; 26 | bool include_last_offset_; 27 | }; 28 | 29 | } // namespace torch_xla 30 | 31 | #endif // XLA_TORCH_XLA_CSRC_OPS_EMBEDDING_BAG_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/expand.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_EXPAND_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_EXPAND_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class Expand : public XlaNode { 11 | public: 12 | Expand(const torch::lazy::Value& input, std::vector size); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | const std::vector& size() const { return size_; }; 21 | 22 | private: 23 | std::vector size_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_EXPAND_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/expand_symint.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_EXPAND_SYMINT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_EXPAND_SYMINT_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | #include "torch_xla/csrc/torch_util.h" 8 | 9 | namespace torch_xla { 10 | 11 | class ExpandSymInt : public XlaNode { 12 | public: 13 | ExpandSymInt(const torch::lazy::Value& input, 14 | const SymIntElements& size_elements); 15 | 16 | std::string ToString() const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | const std::vector& size() const { return upper_bounds_; }; 21 | 22 | const bool IsDynamic(int index) const { return dynamic_dims_[index]; }; 23 | 24 | private: 25 | std::vector upper_bounds_; 26 | std::vector dynamic_dims_; 27 | }; 28 | 29 | } // namespace torch_xla 30 | 31 | #endif // XLA_TORCH_XLA_CSRC_OPS_EXPAND_SYMINT_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/exponential.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_EXPONENTIAL_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_EXPONENTIAL_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Exponential : public XlaNode { 9 | public: 10 | Exponential(const torch::lazy::Value& lambda, const torch::lazy::Value& seed, 11 | xla::Shape shape); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | }; 17 | 18 | } // namespace torch_xla 19 | 20 | #endif // XLA_TORCH_XLA_CSRC_OPS_EXPONENTIAL_H_ 
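The headers in this directory declare the lazy IR nodes that back ordinary tensor operations, and flip.cpp below shows the matching lowering for one of them; none of these classes are called directly from Python. A minimal sketch of how such a node is reached in practice (which nodes end up in the traced graph is an implementation detail and is assumed here):

import torch
import torch_xla

device = torch_xla.device()
x = torch.randn(3, 4, device=device)
# torch.flip on an XLA tensor is traced lazily; the Flip node lowered in
# flip.cpp below emits xla::Rev for it.
y = torch.flip(x, dims=[1])
# Moving the result to CPU forces the pending graph to compile and execute.
print(y.cpu())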
--------------------------------------------------------------------------------
/torch_xla/csrc/ops/flip.cpp:
--------------------------------------------------------------------------------
 1 | #include "torch_xla/csrc/ops/flip.h"
 2 | 
 3 | #include "torch_xla/csrc/lowering_context.h"
 4 | #include "xla/hlo/builder/xla_builder.h"
 5 | 
 6 | namespace torch_xla {
 7 | 
 8 | Flip::Flip(const torch::lazy::Value& input, std::vector<int64_t> dims)
 9 |     : XlaNode(torch::lazy::OpKind(at::aten::flip), {input}, GetXlaShape(input),
10 |               /*num_outputs=*/1, torch::lazy::MHash(dims)),
11 |       dims_(std::move(dims)) {}
12 | 
13 | torch::lazy::NodePtr Flip::Clone(torch::lazy::OpList operands) const {
14 |   return torch_xla::MakeNode<Flip>(operands.at(0), dims_);
15 | }
16 | 
17 | XlaOpVector Flip::Lower(LoweringContext* loctx) const {
18 |   xla::XlaOp input = loctx->GetOutputOp(operand(0));
19 |   xla::XlaOp output = xla::Rev(input, dims_);
20 |   return ReturnOp(output, loctx);
21 | }
22 | 
23 | std::string Flip::ToString() const {
24 |   std::stringstream ss;
25 |   ss << XlaNode::ToString() << ", dims=(" << absl::StrJoin(dims_, ", ") << ")";
26 |   return ss.str();
27 | }
28 | 
29 | }  // namespace torch_xla
30 | 
--------------------------------------------------------------------------------
/torch_xla/csrc/ops/flip.h:
--------------------------------------------------------------------------------
 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_FLIP_H_
 2 | #define XLA_TORCH_XLA_CSRC_OPS_FLIP_H_
 3 | 
 4 | #include "absl/types/span.h"
 5 | #include "torch_xla/csrc/ir.h"
 6 | 
 7 | namespace torch_xla {
 8 | 
 9 | class Flip : public XlaNode {
10 |  public:
11 |   Flip(const torch::lazy::Value& input, std::vector<int64_t> dims);
12 | 
13 |   torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
14 | 
15 |   XlaOpVector Lower(LoweringContext* loctx) const override;
16 | 
17 |   std::string ToString() const override;
18 | 
19 |   const std::vector<int64_t>& dims() const { return dims_; }
20 | 
21 |  private:
22 |   // The dimensions which are flipped.
23 |   std::vector<int64_t> dims_;
24 | };
25 | 
26 | }  // namespace torch_xla
27 | 
28 | #endif  // XLA_TORCH_XLA_CSRC_OPS_FLIP_H_
--------------------------------------------------------------------------------
/torch_xla/csrc/ops/gather.h:
--------------------------------------------------------------------------------
 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_GATHER_H_
 2 | #define XLA_TORCH_XLA_CSRC_OPS_GATHER_H_
 3 | 
 4 | #include "torch_xla/csrc/ir.h"
 5 | 
 6 | namespace torch_xla {
 7 | 
 8 | class Gather : public XlaNode {
 9 |  public:
10 |   Gather(const torch::lazy::Value& input, int64_t dim,
11 |          const torch::lazy::Value& index);
12 | 
13 |   std::string ToString() const override;
14 | 
15 |   torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
16 | 
17 |   XlaOpVector Lower(LoweringContext* loctx) const override;
18 | 
19 |   int64_t dim() const { return dim_; };
20 | 
21 |  private:
22 |   int64_t dim_;
23 | };
24 | 
25 | }  // namespace torch_xla
26 | 
27 | #endif  // XLA_TORCH_XLA_CSRC_OPS_GATHER_H_
--------------------------------------------------------------------------------
/torch_xla/csrc/ops/generic_slice.h:
--------------------------------------------------------------------------------
 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_GENERIC_SLICE_H_
 2 | #define XLA_TORCH_XLA_CSRC_OPS_GENERIC_SLICE_H_
 3 | 
 4 | #include "absl/types/span.h"
 5 | #include "torch_xla/csrc/ir.h"
 6 | 
 7 | namespace torch_xla {
 8 | 
 9 | class GenericSlice : public XlaNode {
10 |  public:
11 |   GenericSlice(const torch::lazy::Value& input,
12 |                absl::Span<const int64_t> base_indices,
13 |                absl::Span<const int64_t> sizes);
14 | 
15 |   torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
16 | 
17 |   XlaOpVector Lower(LoweringContext* loctx) const override;
18 | 
19 |   std::string ToString() const override;
20 | 
21 |   const std::vector<int64_t>& base_indices() const { return base_indices_; }
22 | 
23 |   const std::vector<int64_t>& sizes() const { return sizes_; }
24 | 
25 |  private:
26 |   std::vector<int64_t> base_indices_;
27 |   std::vector<int64_t> sizes_;
28 | };
29 | 
30 | }  // namespace torch_xla
31 | 
32 | #endif  // XLA_TORCH_XLA_CSRC_OPS_GENERIC_SLICE_H_
--------------------------------------------------------------------------------
/torch_xla/csrc/ops/get_dimensions_size.h:
--------------------------------------------------------------------------------
 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_GET_DIMENSIONS_SIZE_H_
 2 | #define XLA_TORCH_XLA_CSRC_OPS_GET_DIMENSIONS_SIZE_H_
 3 | 
 4 | #include <vector>
 5 | 
 6 | #include "torch_xla/csrc/ir.h"
 7 | 
 8 | namespace torch_xla {
 9 | 
10 | class GetDimensionsSize : public XlaNode {
11 |  public:
12 |   GetDimensionsSize(const torch::lazy::Value& input,
13 |                     std::vector<int64_t> dimensions);
14 | 
15 |   torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
16 | 
17 |   XlaOpVector Lower(LoweringContext* loctx) const override;
18 | 
19 |   std::string ToString() const override;
20 | 
21 |   const std::vector<int64_t>& dimensions() const { return dimensions_; }
22 | 
23 |  private:
24 |   std::vector<int64_t> dimensions_;
25 | };
26 | 
27 | }  // namespace torch_xla
28 | 
29 | #endif  // XLA_TORCH_XLA_CSRC_OPS_GET_DIMENSIONS_SIZE_H_
--------------------------------------------------------------------------------
/torch_xla/csrc/ops/gpu_custom_call.h:
--------------------------------------------------------------------------------
 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_
 2 | #define XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_
 3 | 
 4 | #include "torch_xla/csrc/ir.h"
 5 | 
 6 | namespace torch_xla {
 7 | class GpuCustomCall : public XlaNode {
 8 |  public:
 9 |   // Make a GPU custom call with payload,
e.g., Triton. 10 | GpuCustomCall(torch::lazy::OpList inputs, xla::Shape output_shape, 11 | const std::string& payload); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | private: 20 | std::string payload_; 21 | }; 22 | 23 | } // namespace torch_xla 24 | 25 | #endif // XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_ 26 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/hardtanh_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_HARDTANH_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_HARDTANH_BACKWARD_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class HardtanhBackward : public XlaNode { 11 | public: 12 | HardtanhBackward(const torch::lazy::Value& grad_output, 13 | const torch::lazy::Value& input, const at::Scalar& min_val, 14 | const at::Scalar& max_val); 15 | 16 | std::string ToString() const override; 17 | 18 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 19 | 20 | XlaOpVector Lower(LoweringContext* loctx) const override; 21 | 22 | at::Scalar min_val() const { return min_val_; } 23 | 24 | at::Scalar max_val() const { return max_val_; } 25 | 26 | private: 27 | at::Scalar min_val_; 28 | at::Scalar max_val_; 29 | }; 30 | 31 | } // namespace torch_xla 32 | 33 | #endif // XLA_TORCH_XLA_CSRC_OPS_HARDTANH_BACKWARD_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/index_get.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_INDEX_GET_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_INDEX_GET_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class IndexGet : public XlaNode { 9 | public: 10 | IndexGet(const torch::lazy::Value& base, const torch::lazy::Value& indices, 11 | int64_t start_dim); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t start_dim() const { return start_dim_; } 20 | 21 | private: 22 | // The dimension number at which indexing starts. 23 | int64_t start_dim_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_INDEX_GET_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/index_put.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_INDEX_PUT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_INDEX_PUT_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class IndexPut : public XlaNode { 9 | public: 10 | IndexPut(const torch::lazy::Value& base, const torch::lazy::Value& indices, 11 | int64_t start_dim, const torch::lazy::Value& values, 12 | bool accumulate); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | int64_t start_dim() const { return start_dim_; } 21 | 22 | bool accumulate() const { return accumulate_; } 23 | 24 | private: 25 | // The dimension number at which indexing starts. 
26 | int64_t start_dim_; 27 | // Whether to accumulate instead of set. 28 | bool accumulate_; 29 | }; 30 | 31 | } // namespace torch_xla 32 | 33 | #endif // XLA_TORCH_XLA_CSRC_OPS_INDEX_PUT_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/index_select.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_INDEX_SELECT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_INDEX_SELECT_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class IndexSelect : public XlaNode { 9 | public: 10 | IndexSelect(const torch::lazy::Value& input, int64_t dim, 11 | const torch::lazy::Value& index); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t dim() const { return dim_; }; 20 | 21 | private: 22 | int64_t dim_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_INDEX_SELECT_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/infer_output_shape.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_INFER_OUTPUT_SHAPE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_INFER_OUTPUT_SHAPE_H_ 3 | 4 | #include "absl/types/span.h" 5 | #include "xla/hlo/builder/xla_builder.h" 6 | 7 | namespace torch_xla { 8 | 9 | using LowerForShapeFn = 10 | std::function operands)>; 11 | using LowerForShapesFn = std::function( 12 | absl::Span operands)>; 13 | 14 | // Compute the output shape for the given input shapes and lowering. 15 | xla::Shape InferOutputShape(absl::Span input_shapes, 16 | const LowerForShapeFn& core_lowering_fn); 17 | 18 | xla::Shape InferOutputShapes(absl::Span input_shapes, 19 | const LowerForShapesFn& core_lowering_fn); 20 | 21 | } // namespace torch_xla 22 | 23 | #endif // XLA_TORCH_XLA_CSRC_OPS_INFER_OUTPUT_SHAPE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/kth_value.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_KTH_VALUE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_KTH_VALUE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class KthValue : public XlaNode { 9 | public: 10 | KthValue(const torch::lazy::Value& input, int64_t k, int64_t dim, 11 | bool keepdim); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t k() const { return k_; }; 20 | 21 | int64_t dim() const { return dim_; }; 22 | 23 | bool keepdim() const { return keepdim_; } 24 | 25 | private: 26 | int64_t k_; 27 | int64_t dim_; 28 | bool keepdim_; 29 | }; 30 | 31 | } // namespace torch_xla 32 | 33 | #endif // XLA_TORCH_XLA_CSRC_OPS_KTH_VALUE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/linear_interpolation.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_LINEAR_INTERPOLATION_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_LINEAR_INTERPOLATION_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class LinearInterpolation : public XlaNode { 9 | public: 10 
| LinearInterpolation(const torch::lazy::Value& value, 11 | const torch::lazy::Value& new_value, double alpha); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | double alpha() const { return alpha_; } 20 | 21 | private: 22 | double alpha_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_LINEAR_INTERPOLATION_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/linspace.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_LINSPACE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_LINSPACE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Linspace : public XlaNode { 9 | public: 10 | Linspace(const torch::lazy::Value& start, const torch::lazy::Value& end, 11 | const int64_t steps); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t steps() const { return steps_; }; 20 | 21 | private: 22 | int64_t steps_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_LINSPACE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/log_softmax.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_LOG_SOFTMAX_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_LOG_SOFTMAX_H_ 3 | 4 | #include 5 | 6 | #include 7 | 8 | #include "torch_xla/csrc/ir.h" 9 | 10 | namespace torch_xla { 11 | 12 | // IR node for log(softmax) operation. 13 | class LogSoftmax : public XlaNode { 14 | public: 15 | LogSoftmax(const torch::lazy::Value& input, int64_t dim, 16 | std::optional dtype, 17 | std::vector&& shapes); 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | std::string ToString() const override; 22 | 23 | int64_t dim() const { return dim_; } 24 | 25 | const std::optional& dtype() const { return dtype_; } 26 | 27 | private: 28 | // The dimension along which the result is computed. 29 | int64_t dim_; 30 | std::optional dtype_; 31 | }; 32 | 33 | } // namespace torch_xla 34 | 35 | #endif // XLA_TORCH_XLA_CSRC_OPS_LOG_SOFTMAX_H_ 36 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/log_softmax_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_LOG_SOFTMAX_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_LOG_SOFTMAX_BACKWARD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class LogSoftmaxBackward : public XlaNode { 9 | public: 10 | LogSoftmaxBackward(const torch::lazy::Value& grad_output, 11 | const torch::lazy::Value& output, int64_t dim); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | int64_t dim() const { return dim_; } 20 | 21 | private: 22 | // The dimension along which the result is computed. 
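  // Illustrative note (not from this file): for the common case of a
  // log_softmax taken over the class axis of a [batch, num_classes] tensor,
  // both the forward node and this backward node carry dim == 1, e.g.:
  //
  //   // grad_output, output: torch::lazy::Value handles for a [N, C] tensor
  //   auto node = torch_xla::MakeNode<LogSoftmaxBackward>(grad_output, output,
  //                                                       /*dim=*/1);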
23 | int64_t dim_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_LOG_SOFTMAX_BACKWARD_H_#pragma once -------------------------------------------------------------------------------- /torch_xla/csrc/ops/logsumexp.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_LOGSUMEXP_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_LOGSUMEXP_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class Logsumexp : public XlaNode { 11 | public: 12 | Logsumexp(const torch::lazy::Value& input, std::vector dimensions, 13 | bool keep_reduced_dimensions); 14 | 15 | std::string ToString() const override; 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | const std::vector& dimensions() const { return dimensions_; } 22 | 23 | bool keep_reduced_dimensions() const { return keep_reduced_dimensions_; } 24 | 25 | private: 26 | std::vector dimensions_; 27 | bool keep_reduced_dimensions_; 28 | }; 29 | 30 | } // namespace torch_xla 31 | 32 | #endif // XLA_TORCH_XLA_CSRC_OPS_LOGSUMEXP_H_ 33 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/mark_tensor.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MARK_TENSOR_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MARK_TENSOR_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class MarkTensor : public XlaNode { 9 | public: 10 | MarkTensor(const torch::lazy::Value& input, const std::string& info); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | private: 19 | std::string info_; 20 | }; 21 | 22 | } // namespace torch_xla 23 | 24 | #endif // XLA_TORCH_XLA_CSRC_OPS_MARK_TENSOR_H_ 25 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/masked_scatter.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MASKED_SCATTER_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MASKED_SCATTER_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | // This node has no metadata, so it could have been implemented as generic-op in 9 | // ops.cpp, but since this might require special handling from upper IR layers, 10 | // it gets its own IR node class. 
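// Illustrative sketch (not repository code): dedicated IR nodes like the one
// declared below are created through the MakeNode<T> helper used throughout
// this codebase (see NotSupported::Clone later in this dump). The `input`,
// `mask` and `source` handles are assumed to be torch::lazy::Value operands
// already present on the lazy graph:
//
//   torch::lazy::NodePtr node =
//       torch_xla::MakeNode<MaskedScatter>(input, mask, source);
//   // The node's Lower(loctx) is invoked later by the lowering context to
//   // emit the corresponding XLA ops.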
11 | class MaskedScatter : public XlaNode { 12 | public: 13 | MaskedScatter(const torch::lazy::Value& input, const torch::lazy::Value& mask, 14 | const torch::lazy::Value& source); 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | }; 20 | 21 | } // namespace torch_xla 22 | 23 | #endif // XLA_TORCH_XLA_CSRC_OPS_MASKED_SCATTER_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/masked_select.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MASKED_SELECT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MASKED_SELECT_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | // This node has no metadata, so it could have been implemented as generic-op in 9 | // ops.cpp, but since this might require special handling from upper IR layers, 10 | // it gets its own IR node class. 11 | class MaskedSelect : public XlaNode { 12 | public: 13 | MaskedSelect(const torch::lazy::Value& input, const torch::lazy::Value& mask); 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | }; 19 | 20 | } // namespace torch_xla 21 | 22 | #endif // XLA_TORCH_XLA_CSRC_OPS_MASKED_SELECT_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/max_in_dim.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MAX_IN_DIM_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MAX_IN_DIM_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class MaxInDim : public XlaNode { 9 | public: 10 | MaxInDim(const torch::lazy::Value& input, int64_t dim, bool keepdim); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | int64_t dim() const { return dim_; }; 19 | 20 | bool keepdim() const { return keepdim_; } 21 | 22 | private: 23 | int64_t dim_; 24 | bool keepdim_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_MAX_IN_DIM_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/max_unpool_nd.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MAX_UNPOOL_ND_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MAX_UNPOOL_ND_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class MaxUnpoolNd : public XlaNode { 9 | public: 10 | MaxUnpoolNd(const torch::lazy::Value& input, 11 | const torch::lazy::Value& indices, 12 | std::vector output_size); 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | std::string ToString() const override; 19 | 20 | const std::vector& output_size() const { return output_size_; } 21 | 22 | private: 23 | std::vector output_size_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_MAX_UNPOOL_ND_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/min_in_dim.h: -------------------------------------------------------------------------------- 1 | #ifndef 
XLA_TORCH_XLA_CSRC_OPS_MIN_IN_DIM_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MIN_IN_DIM_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class MinInDim : public XlaNode { 9 | public: 10 | MinInDim(const torch::lazy::Value& input, int64_t dim, bool keepdim); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | int64_t dim() const { return dim_; }; 19 | 20 | bool keepdim() const { return keepdim_; } 21 | 22 | private: 23 | int64_t dim_; 24 | bool keepdim_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_MIN_IN_DIM_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/mse_loss.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MSE_LOSS_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MSE_LOSS_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | #include "torch_xla/csrc/reduction.h" 6 | #include "xla/types.h" 7 | 8 | namespace torch_xla { 9 | 10 | class MseLoss : public XlaNode { 11 | public: 12 | MseLoss(const torch::lazy::Value& input, const torch::lazy::Value& target, 13 | ReductionMode reduction); 14 | 15 | std::string ToString() const override; 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | ReductionMode reduction() const { return reduction_; } 22 | 23 | private: 24 | ReductionMode reduction_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_MSE_LOSS_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/mse_loss_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_MSE_LOSS_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_MSE_LOSS_BACKWARD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | #include "torch_xla/csrc/reduction.h" 6 | #include "xla/types.h" 7 | 8 | namespace torch_xla { 9 | 10 | class MseLossBackward : public XlaNode { 11 | public: 12 | MseLossBackward(const torch::lazy::Value& grad_output, 13 | const torch::lazy::Value& input, 14 | const torch::lazy::Value& target, ReductionMode reduction); 15 | 16 | std::string ToString() const override; 17 | 18 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 19 | 20 | XlaOpVector Lower(LoweringContext* loctx) const override; 21 | 22 | ReductionMode reduction() const { return reduction_; } 23 | 24 | private: 25 | ReductionMode reduction_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_MSE_LOSS_BACKWARD_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/multinomial.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "torch_xla/csrc/ir.h" 4 | 5 | namespace torch_xla { 6 | 7 | class Multinomial : public XlaNode { 8 | public: 9 | Multinomial(const torch::lazy::Value& input, const torch::lazy::Value& seed, 10 | int64_t num_samples, bool replacement); 11 | 12 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 13 | 14 | XlaOpVector Lower(LoweringContext* loctx) const override; 15 | 16 | private: 17 | int64_t num_samples_; 18 | bool replacement_; 19 | }; 20 | 21 | } 
// namespace torch_xla 22 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/native_dropout.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_NATIVE_DROPOUT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_NATIVE_DROPOUT_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | // This node has no metadata, so it could have been implemented as generic-op in 9 | // ops.cpp, but since this might require special handling from upper IR layers, 10 | // it gets its own IR node class. 11 | class NativeDropout : public XlaNode { 12 | public: 13 | NativeDropout(const torch::lazy::Value& input, const torch::lazy::Value& seed, 14 | float p, std::optional train); 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | private: 21 | float p_; 22 | std::optional train_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_NATIVE_DROPOUT_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/nll_loss.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_NLL_LOSS_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_NLL_LOSS_H_ 3 | 4 | #include "absl/types/optional.h" 5 | #include "torch_xla/csrc/ir.h" 6 | #include "torch_xla/csrc/reduction.h" 7 | 8 | namespace torch_xla { 9 | 10 | class NllLoss : public XlaNode { 11 | public: 12 | NllLoss(const torch::lazy::Value& logits, const torch::lazy::Value& labels, 13 | const absl::optional& weight, 14 | ReductionMode reduction, int ignore_index); 15 | 16 | std::string ToString() const override; 17 | 18 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 19 | 20 | XlaOpVector Lower(LoweringContext* loctx) const override; 21 | 22 | ReductionMode reduction() const { return reduction_; } 23 | 24 | int ignore_index() const { return ignore_index_; } 25 | 26 | private: 27 | ReductionMode reduction_; 28 | int ignore_index_; 29 | }; 30 | 31 | } // namespace torch_xla 32 | 33 | #endif // XLA_TORCH_XLA_CSRC_OPS_NLL_LOSS_H_ 34 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/nms.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_NMS_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_NMS_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Nms : public XlaNode { 9 | public: 10 | Nms(const torch::lazy::Value& boxes, const torch::lazy::Value& scores, 11 | const torch::lazy::Value& iou_threshold); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | }; 17 | 18 | } // namespace torch_xla 19 | 20 | #endif // XLA_TORCH_XLA_CSRC_OPS_NMS_H_ 21 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/nonzero.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_NONZERO_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_NONZERO_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | // This node has no metadata, so it could have been implemented as generic-op in 9 | // ops.cpp, but since this might require special handling from upper IR 
layers, 10 | // it gets its own IR node class. 11 | class NonZero : public XlaNode { 12 | public: 13 | NonZero(const torch::lazy::Value& input); 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | }; 19 | 20 | } // namespace torch_xla 21 | 22 | #endif // XLA_TORCH_XLA_CSRC_OPS_NONZERO_H_ 23 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/normal.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_NORMAL_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_NORMAL_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Normal : public XlaNode { 9 | public: 10 | Normal(const torch::lazy::Value& mean, const torch::lazy::Value& std, 11 | const torch::lazy::Value& seed); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | }; 17 | 18 | } // namespace torch_xla 19 | 20 | #endif // XLA_TORCH_XLA_CSRC_OPS_NORMAL_H_ 21 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/not_supported.cpp: -------------------------------------------------------------------------------- 1 | #include "torch_xla/csrc/ops/not_supported.h" 2 | 3 | #include "torch_xla/csrc/lowering_context.h" 4 | #include "torch_xla/csrc/ops/xla_ops.h" 5 | #include "torch_xla/csrc/runtime/debug_macros.h" 6 | 7 | namespace torch_xla { 8 | 9 | NotSupported::NotSupported(std::string description, xla::Shape shape) 10 | : XlaNode(xla_not_supported, std::move(shape), /*num_outputs=*/1, 11 | torch::lazy::MHash(description)), 12 | description_(std::move(description)) {} 13 | 14 | torch::lazy::NodePtr NotSupported::Clone(torch::lazy::OpList operands) const { 15 | return torch_xla::MakeNode(description_, xla_shape()); 16 | } 17 | 18 | XlaOpVector NotSupported::Lower(LoweringContext* /* loctx */) const { 19 | XLA_ERROR() << "Node not supported: " << ToString(); 20 | } 21 | 22 | std::string NotSupported::ToString() const { 23 | std::stringstream ss; 24 | ss << XlaNode::ToString() << ", description=" << description_; 25 | return ss.str(); 26 | } 27 | 28 | } // namespace torch_xla 29 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/not_supported.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_NOT_SUPPORTED_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_NOT_SUPPORTED_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class NotSupported : public XlaNode { 11 | public: 12 | NotSupported(std::string description, xla::Shape shape); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | const std::string& description() const { return description_; } 21 | 22 | private: 23 | std::string description_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_NOT_SUPPORTED_H_ 29 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/optimization_barrier.h: -------------------------------------------------------------------------------- 1 | #ifndef 
XLA_TORCH_XLA_CSRC_OPS_OPTIMIZATION_BARRIER_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_OPTIMIZATION_BARRIER_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class OptimizationBarrier : public XlaNode { 9 | public: 10 | OptimizationBarrier(const torch::lazy::OpList& inputs); 11 | 12 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 13 | 14 | XlaOpVector Lower(LoweringContext* loctx) const override; 15 | }; 16 | 17 | } // namespace torch_xla 18 | 19 | #endif // XLA_TORCH_XLA_CSRC_OPS_OPTIMIZATION_BARRIER_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/permute.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_PERMUTE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_PERMUTE_H_ 3 | 4 | #include "absl/types/span.h" 5 | #include "torch_xla/csrc/ir.h" 6 | 7 | namespace torch_xla { 8 | 9 | class Permute : public XlaNode { 10 | public: 11 | Permute(const torch::lazy::Value& input, std::vector dims); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | const std::vector& dims() const { return dims_; } 20 | 21 | static xla::Shape MakePermuteShape(const xla::Shape& source_shape, 22 | absl::Span permutation); 23 | 24 | private: 25 | // The permutation of dimensions. 26 | std::vector dims_; 27 | }; 28 | 29 | } // namespace torch_xla 30 | 31 | #endif // XLA_TORCH_XLA_CSRC_OPS_PERMUTE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/put.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_PUT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_PUT_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Put : public XlaNode { 9 | public: 10 | Put(const torch::lazy::Value& input, const torch::lazy::Value& index, 11 | const torch::lazy::Value& source, bool accumulate); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | bool accumulate() const { return accumulate_; } 20 | 21 | private: 22 | bool accumulate_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_PUT_H_ 28 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/qr.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_QR_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_QR_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class QR : public XlaNode { 9 | public: 10 | QR(const torch::lazy::Value& input, bool some); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | bool some() const { return some_; } 19 | 20 | private: 21 | bool some_; 22 | }; 23 | 24 | } // namespace torch_xla 25 | 26 | #endif // XLA_TORCH_XLA_CSRC_OPS_QR_H_ 27 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/quant_tensor.h: -------------------------------------------------------------------------------- 1 
| #ifndef XLA_TORCH_XLA_CSRC_OPS_QUANT_TENSOR_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_QUANT_TENSOR_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class QuantizeTensor : public XlaNode { 9 | public: 10 | QuantizeTensor(const torch::lazy::Value& input, 11 | const std::vector& scale, 12 | const std::vector& zero_point, int quant_min, 13 | int quant_max, const std::string& dtype, int axis); 14 | 15 | std::string ToString() const override; 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | private: 22 | int quant_min_; 23 | int quant_max_; 24 | int axis_; 25 | std::string dtype_; 26 | std::vector scale_; 27 | std::vector zero_point_; 28 | }; 29 | 30 | } // namespace torch_xla 31 | 32 | #endif // XLA_TORCH_XLA_CSRC_OPS_QUANT_TENSOR_H_ 33 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/randperm.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class RandPerm : public XlaNode { 11 | public: 12 | RandPerm(int64_t n, const at::ScalarType dtype, const at::Layout layout, 13 | const at::Device device, bool pin_memory); 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | std::string ToString() const override; 17 | 18 | private: 19 | int64_t n_; 20 | }; 21 | 22 | } // namespace torch_xla 23 | 24 | #endif // XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_ 25 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/recv.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_RECV_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_RECV_H_ 3 | 4 | #include "torch_xla/csrc/cross_replica_reduces.h" 5 | #include "torch_xla/csrc/ir.h" 6 | 7 | namespace torch_xla { 8 | namespace ir { 9 | namespace ops { 10 | 11 | class Recv : public XlaNode { 12 | public: 13 | Recv(const torch::lazy::Value& token, const xla::Shape& recv_shape, 14 | int64_t channel_id); 15 | 16 | std::string ToString() const override; 17 | 18 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 19 | 20 | XlaOpVector Lower(LoweringContext* loctx) const override; 21 | 22 | int64_t channel_id() const { return channel_id_; } 23 | 24 | private: 25 | xla::Shape recv_shape_; 26 | int64_t channel_id_; 27 | }; 28 | 29 | } // namespace ops 30 | } // namespace ir 31 | } // namespace torch_xla 32 | 33 | #endif // XLA_TORCH_XLA_CSRC_OPS_RECV_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/reflection_pad2d.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_REFLECTION_PAD2D_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_REFLECTION_PAD2D_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class ReflectionPad2d : public XlaNode { 11 | public: 12 | ReflectionPad2d(const torch::lazy::Value& input, 13 | std::vector padding); 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | std::string ToString() const override; 20 | 21 | const std::vector& padding() const { return 
padding_; } 22 | 23 | private: 24 | std::vector padding_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_REFLECTION_PAD2D_H_ 30 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/reflection_pad2d_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_REFLECTION_PAD2D_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_REFLECTION_PAD2D_BACKWARD_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class ReflectionPad2dBackward : public XlaNode { 11 | public: 12 | ReflectionPad2dBackward(const torch::lazy::Value& gard_output, 13 | const torch::lazy::Value& input, 14 | std::vector padding); 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | std::string ToString() const override; 21 | 22 | const std::vector& padding() const { return padding_; } 23 | 24 | private: 25 | std::vector padding_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_REFLECTION_PAD2D_BACKWARD_H_ 31 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/replication_pad.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_REPLICATION_PAD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_REPLICATION_PAD_H_ 3 | 4 | #include "absl/types/span.h" 5 | #include "torch_xla/csrc/ir.h" 6 | 7 | namespace torch_xla { 8 | 9 | class ReplicationPad : public XlaNode { 10 | public: 11 | ReplicationPad(const torch::lazy::Value& input, std::vector padding); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | const std::vector& padding() const { return padding_; } 20 | 21 | private: 22 | std::vector padding_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_REPLICATION_PAD_H_ 28 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/replication_pad_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_REPLICATION_PAD_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_REPLICATION_PAD_BACKWARD_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class ReplicationPadBackward : public XlaNode { 11 | public: 12 | ReplicationPadBackward(const torch::lazy::Value& gard_output, 13 | const torch::lazy::Value& input, 14 | std::vector padding); 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | std::string ToString() const override; 21 | 22 | const std::vector& padding() const { return padding_; } 23 | 24 | private: 25 | std::vector padding_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_REPLICATION_PAD_BACKWARD_H_ 31 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/resize.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_RESIZE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_RESIZE_H_ 3 | 4 | #include 
"torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Resize : public XlaNode { 9 | public: 10 | Resize(const torch::lazy::Value& input, std::vector size); 11 | 12 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 13 | 14 | XlaOpVector Lower(LoweringContext* loctx) const override; 15 | 16 | std::string ToString() const override; 17 | 18 | const std::vector& size() const { return size_; } 19 | 20 | private: 21 | std::vector size_; 22 | }; 23 | 24 | } // namespace torch_xla 25 | 26 | #endif // XLA_TORCH_XLA_CSRC_OPS_RESIZE_H_ 27 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/roll.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_ROLL_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_ROLL_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Roll : public XlaNode { 9 | public: 10 | Roll(const torch::lazy::Value& input, std::vector shifts, 11 | std::vector dims); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | const std::vector& shifts() const { return shifts_; } 20 | 21 | const std::vector& dims() const { return dims_; } 22 | 23 | private: 24 | std::vector shifts_; 25 | std::vector dims_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_ROLL_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/scatter.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SCATTER_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SCATTER_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Scatter : public XlaNode { 9 | public: 10 | Scatter(const torch::lazy::Value& input, const torch::lazy::Value& index, 11 | const torch::lazy::Value& src, int64_t dim); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t dim() const { return dim_; }; 20 | 21 | private: 22 | int64_t dim_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_SCATTER_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/scatter_add.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SCATTER_ADD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SCATTER_ADD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class ScatterAdd : public XlaNode { 9 | public: 10 | ScatterAdd(const torch::lazy::Value& input, const torch::lazy::Value& index, 11 | const torch::lazy::Value& src, int64_t dim); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t dim() const { return dim_; }; 20 | 21 | private: 22 | int64_t dim_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_SCATTER_ADD_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/scatter_reduce.h: 
-------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SCATTER_REDUCE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SCATTER_REDUCE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class ScatterReduce : public XlaNode { 9 | public: 10 | ScatterReduce(const torch::lazy::Value& input, 11 | const torch::lazy::Value& index, const torch::lazy::Value& src, 12 | std::string_view reduce, bool include_self, int64_t dim); 13 | 14 | std::string ToString() const override; 15 | 16 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 17 | 18 | XlaOpVector Lower(LoweringContext* loctx) const override; 19 | 20 | int64_t dim() const { return dim_; }; 21 | 22 | private: 23 | std::string reduce_; 24 | bool include_self_; 25 | int64_t dim_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_SCATTER_REDUCE_H_ 31 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/send.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SEND_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SEND_H_ 3 | 4 | #include "torch_xla/csrc/cross_replica_reduces.h" 5 | #include "torch_xla/csrc/ir.h" 6 | 7 | namespace torch_xla { 8 | namespace ir { 9 | namespace ops { 10 | 11 | class Send : public XlaNode { 12 | public: 13 | Send(const torch::lazy::Value& input, const torch::lazy::Value& token, 14 | int64_t channel_id); 15 | 16 | std::string ToString() const override; 17 | 18 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 19 | 20 | XlaOpVector Lower(LoweringContext* loctx) const override; 21 | 22 | int64_t channel_id() const { return channel_id_; } 23 | 24 | private: 25 | int64_t channel_id_; 26 | }; 27 | 28 | } // namespace ops 29 | } // namespace ir 30 | } // namespace torch_xla 31 | 32 | #endif // XLA_TORCH_XLA_CSRC_OPS_SEND_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/softmax.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SOFTMAX_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SOFTMAX_H_ 3 | 4 | #include 5 | 6 | #include 7 | 8 | #include "torch_xla/csrc/ir.h" 9 | 10 | namespace torch_xla { 11 | 12 | class Softmax : public XlaNode { 13 | public: 14 | Softmax(const torch::lazy::Value& input, int64_t dim, 15 | std::optional dtype); 16 | 17 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 18 | 19 | XlaOpVector Lower(LoweringContext* loctx) const override; 20 | 21 | std::string ToString() const override; 22 | 23 | int64_t dim() const { return dim_; } 24 | 25 | const std::optional& dtype() const { return dtype_; } 26 | 27 | private: 28 | int64_t dim_; 29 | std::optional dtype_; 30 | }; 31 | 32 | } // namespace torch_xla 33 | 34 | #endif // XLA_TORCH_XLA_CSRC_OPS_SOFTMAX_H_ 35 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/softmax_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SOFTMAX_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SOFTMAX_BACKWARD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class SoftmaxBackward : public XlaNode { 9 | public: 10 | SoftmaxBackward(const torch::lazy::Value& grad_output, 11 | const torch::lazy::Value& output, int64_t dim); 12 | 
13 |   torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
14 | 
15 |   XlaOpVector Lower(LoweringContext* loctx) const override;
16 | 
17 |   std::string ToString() const override;
18 | 
19 |   int64_t dim() const { return dim_; }
20 | 
21 |  private:
22 |   // The dimension along which the result is computed.
23 |   int64_t dim_;
24 | };
25 | 
26 | }  // namespace torch_xla
27 | 
28 | #endif  // XLA_TORCH_XLA_CSRC_OPS_SOFTMAX_BACKWARD_H_
--------------------------------------------------------------------------------
/torch_xla/csrc/ops/split.h:
--------------------------------------------------------------------------------
 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SPLIT_H_
 2 | #define XLA_TORCH_XLA_CSRC_OPS_SPLIT_H_
 3 | 
 4 | #include <vector>
 5 | 
 6 | #include "torch_xla/csrc/ir.h"
 7 | 
 8 | namespace torch_xla {
 9 | 
10 | // Split the tensor into chunks along a given dimension.
11 | class Split : public XlaNode {
12 |  public:
13 |   Split(const torch::lazy::Value& input, std::vector<int64_t> split_sizes,
14 |         int64_t dim);
15 | 
16 |   torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
17 | 
18 |   XlaOpVector Lower(LoweringContext* loctx) const override;
19 | 
20 |   std::string ToString() const override;
21 | 
22 |   const std::vector<int64_t>& split_sizes() const { return split_sizes_; }
23 | 
24 |   int64_t dim() const { return dim_; }
25 | 
26 |  private:
27 |   std::vector<int64_t> split_sizes_;
28 |   int64_t dim_;
29 | };
30 | 
31 | }  // namespace torch_xla
32 | 
33 | #endif  // XLA_TORCH_XLA_CSRC_OPS_SPLIT_H_
--------------------------------------------------------------------------------
/torch_xla/csrc/ops/squeeze.h:
--------------------------------------------------------------------------------
 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SQUEEZE_H_
 2 | #define XLA_TORCH_XLA_CSRC_OPS_SQUEEZE_H_
 3 | 
 4 | #include "torch_xla/csrc/ir.h"
 5 | 
 6 | namespace torch_xla {
 7 | 
 8 | class Squeeze : public XlaNode {
 9 |  public:
10 |   // Squeeze out the specified dimension index, -1 for all trivial dimensions.
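  // Worked example (illustrative): for an input of shape [1, 4, 1, 3],
  //   dim == -1  squeezes every size-1 dimension  -> [4, 3]
  //   dim ==  0  squeezes only the leading axis   -> [4, 1, 3]
  //   dim ==  1  is a no-op, since that axis has size 4.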
11 | Squeeze(const torch::lazy::Value& input, int dim); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | int dim() const { return dim_; } 20 | 21 | private: 22 | int dim_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_SQUEEZE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/stack.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_STACK_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_STACK_H_ 3 | 4 | #include "absl/types/span.h" 5 | #include "torch_xla/csrc/ir.h" 6 | 7 | namespace torch_xla { 8 | 9 | class Stack : public XlaNode { 10 | public: 11 | Stack(c10::ArrayRef values, int64_t dim); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t dim() const { return dim_; }; 20 | 21 | private: 22 | int64_t dim_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_STACK_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/svd.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SVD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SVD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class SVD : public XlaNode { 9 | public: 10 | SVD(const torch::lazy::Value& input, bool some, bool compute_uv); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | bool some() const { return some_; } 19 | 20 | bool compute_uv() const { return compute_uv_; } 21 | 22 | private: 23 | bool some_; 24 | bool compute_uv_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_SVD_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/symeig.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_SYMEIG_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_SYMEIG_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class SymEig : public XlaNode { 9 | public: 10 | SymEig(const torch::lazy::Value& input, bool eigenvectors, bool lower); 11 | 12 | std::string ToString() const override; 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | bool eigenvectors() const { return eigenvectors_; } 19 | 20 | bool lower() const { return lower_; } 21 | 22 | private: 23 | bool eigenvectors_; 24 | bool lower_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_SYMEIG_H_ 30 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/threshold.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_THRESHOLD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_THRESHOLD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | // IR node for the 
threshold operation. 9 | class Threshold : public XlaNode { 10 | public: 11 | Threshold(const torch::lazy::Value& input, float threshold, float value); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | float threshold() const { return threshold_; } 20 | 21 | float value() const { return value_; } 22 | 23 | private: 24 | float threshold_; 25 | float value_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_THRESHOLD_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/threshold_backward.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_THRESHOLD_BACKWARD_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_THRESHOLD_BACKWARD_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class ThresholdBackward : public XlaNode { 9 | public: 10 | ThresholdBackward(const torch::lazy::Value& grad_output, 11 | const torch::lazy::Value& input, float threshold); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | float threshold() const { return threshold_; } 20 | 21 | private: 22 | float threshold_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_THRESHOLD_BACKWARD_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/topk.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_TOPK_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_TOPK_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class TopK : public XlaNode { 9 | public: 10 | TopK(const torch::lazy::Value& input, int64_t k, int64_t dim, bool largest, 11 | bool sorted, bool stable); 12 | 13 | std::string ToString() const override; 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | int64_t k() const { return k_; }; 20 | 21 | int64_t dim() const { return dim_; }; 22 | 23 | bool largest() const { return largest_; } 24 | 25 | bool sorted() const { return sorted_; } 26 | 27 | bool stable() const { return stable_; } 28 | 29 | private: 30 | int64_t k_; 31 | int64_t dim_; 32 | bool largest_; 33 | bool sorted_; 34 | bool stable_; 35 | }; 36 | 37 | } // namespace torch_xla 38 | 39 | #endif // XLA_TORCH_XLA_CSRC_OPS_TOPK_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/tpu_custom_call.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_TPU_CUSTOM_CALL_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_TPU_CUSTOM_CALL_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class TpuCustomCall : public XlaNode { 9 | public: 10 | // Make a TPU custom call with payload, e.g., Mosaic. 
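  // Illustrative sketch (assumptions flagged): the payload is an opaque string
  // produced by an external toolchain such as Pallas/Mosaic; this node only
  // forwards it to XLA as a custom call. A hypothetical construction, where
  // q, k, v are torch::lazy::Value operands and `payload` is a serialized
  // Mosaic kernel obtained elsewhere:
  //
  //   std::vector<torch::lazy::Value> inputs = {q, k, v};
  //   auto node = torch_xla::MakeNode<TpuCustomCall>(
  //       torch::lazy::OpList(inputs), output_shape, payload);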
11 | TpuCustomCall(torch::lazy::OpList inputs, xla::Shape output_shape, 12 | const std::string& payload); 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | std::string ToString() const override; 19 | 20 | private: 21 | std::string payload_; 22 | }; 23 | 24 | } // namespace torch_xla 25 | 26 | #endif // XLA_TORCH_XLA_CSRC_OPS_TPU_CUSTOM_CALL_H_ 27 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/uniform.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_UNIFORM_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_UNIFORM_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Uniform : public XlaNode { 9 | public: 10 | Uniform(const torch::lazy::Value& from, const torch::lazy::Value& to, 11 | const torch::lazy::Value& seed, const xla::Shape& rng_shape); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | }; 17 | 18 | } // namespace torch_xla 19 | 20 | #endif // XLA_TORCH_XLA_CSRC_OPS_UNIFORM_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/unselect.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_UNSELECT_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_UNSELECT_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Unselect : public XlaNode { 9 | public: 10 | Unselect(const torch::lazy::Value& target, const torch::lazy::Value& source, 11 | int64_t dim, int64_t start, int64_t end, int64_t stride); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | int64_t dim() const { return dim_; } 20 | 21 | int64_t start() const { return start_; } 22 | 23 | int64_t end() const { return end_; } 24 | 25 | int64_t stride() const { return stride_; } 26 | 27 | private: 28 | int64_t dim_; 29 | int64_t start_; 30 | int64_t end_; 31 | int64_t stride_; 32 | }; 33 | 34 | } // namespace torch_xla 35 | 36 | #endif // XLA_TORCH_XLA_CSRC_OPS_UNSELECT_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/unsqueeze.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_UNSQUEEZE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_UNSQUEEZE_H_ 3 | 4 | #include "torch_xla/csrc/ir.h" 5 | 6 | namespace torch_xla { 7 | 8 | class Unsqueeze : public XlaNode { 9 | public: 10 | // Insert a dimension of size one at the specified position. 11 | Unsqueeze(const torch::lazy::Value& input, int dim); 12 | 13 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | int dim() const { return dim_; } 20 | 21 | private: 22 | // Position to unsqueeze. 
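  // Worked example (illustrative): unsqueezing a [4, 3] input with dim == 0
  // yields shape [1, 4, 3]; with dim == 2 it yields [4, 3, 1].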
23 | int dim_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_UNSQUEEZE_H_#pragma once -------------------------------------------------------------------------------- /torch_xla/csrc/ops/update_slice.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_UPDATE_SLICE_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_UPDATE_SLICE_H_ 3 | 4 | #include "absl/types/span.h" 5 | #include "torch_xla/csrc/ir.h" 6 | 7 | namespace torch_xla { 8 | 9 | class UpdateSlice : public XlaNode { 10 | public: 11 | UpdateSlice(const torch::lazy::Value& input, const torch::lazy::Value& source, 12 | absl::Span base_indices); 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | std::string ToString() const override; 19 | 20 | const std::vector& base_indices() const { return base_indices_; } 21 | 22 | private: 23 | std::vector base_indices_; 24 | }; 25 | 26 | } // namespace torch_xla 27 | 28 | #endif // XLA_TORCH_XLA_CSRC_OPS_UPDATE_SLICE_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/upsample_bilinear2d.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_UPSAMPLE_BILINEAR2D_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_UPSAMPLE_BILINEAR2D_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class UpsampleBilinear : public XlaNode { 11 | public: 12 | UpsampleBilinear(const torch::lazy::Value& input, 13 | std::vector output_size, bool align_corners); 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | std::string ToString() const override; 20 | 21 | const std::vector& output_size() const { return output_size_; } 22 | 23 | bool align_corners() const { return align_corners_; } 24 | 25 | private: 26 | std::vector output_size_; 27 | bool align_corners_; 28 | }; 29 | 30 | } // namespace torch_xla 31 | 32 | #endif // XLA_TORCH_XLA_CSRC_OPS_UPSAMPLE_BILINEAR2D_H_ 33 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/upsample_nearest2d.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_UPSAMPLE_NEAREST2D_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_UPSAMPLE_NEAREST2D_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class UpsampleNearest : public XlaNode { 11 | public: 12 | UpsampleNearest(const torch::lazy::Value& input, 13 | std::vector output_size); 14 | 15 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 16 | 17 | XlaOpVector Lower(LoweringContext* loctx) const override; 18 | 19 | std::string ToString() const override; 20 | 21 | const std::vector& output_size() const { return output_size_; } 22 | 23 | private: 24 | std::vector output_size_; 25 | }; 26 | 27 | } // namespace torch_xla 28 | 29 | #endif // XLA_TORCH_XLA_CSRC_OPS_UPSAMPLE_NEAREST2D_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/ops/user_computation.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_USER_COMPUTATION_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_USER_COMPUTATION_H_ 3 | 4 | #include 
"torch_xla/csrc/ir.h" 5 | #include "torch_xla/csrc/runtime/computation_client.h" 6 | 7 | namespace torch_xla { 8 | 9 | class UserComputation : public XlaNode { 10 | public: 11 | UserComputation(torch::lazy::OpKind op, torch::lazy::OpList operands, 12 | runtime::ComputationClient::ComputationPtr computation); 13 | 14 | torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; 15 | 16 | XlaOpVector Lower(LoweringContext* loctx) const override; 17 | 18 | std::string ToString() const override; 19 | 20 | const runtime::ComputationClient::ComputationPtr& computation() const { 21 | return computation_; 22 | } 23 | 24 | private: 25 | runtime::ComputationClient::ComputationPtr computation_; 26 | }; 27 | 28 | } // namespace torch_xla 29 | 30 | #endif // XLA_TORCH_XLA_CSRC_OPS_USER_COMPUTATION_H_ 31 | -------------------------------------------------------------------------------- /torch_xla/csrc/ops/view.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_OPS_VIEW_H_ 2 | #define XLA_TORCH_XLA_CSRC_OPS_VIEW_H_ 3 | 4 | #include 5 | 6 | #include "torch_xla/csrc/ir.h" 7 | 8 | namespace torch_xla { 9 | 10 | class ViewOp : public XlaNode { 11 | public: 12 | ViewOp(const torch::lazy::Value& input, std::vector output_size); 13 | ViewOp(const torch::lazy::Value& input, xla::Shape output_shape); 14 | 15 | XlaOpVector Lower(LoweringContext* loctx) const override; 16 | 17 | std::string ToString() const override; 18 | 19 | const std::vector& output_size() const { return output_size_; } 20 | 21 | private: 22 | std::vector output_size_; 23 | }; 24 | 25 | } // namespace torch_xla 26 | 27 | #endif // XLA_TORCH_XLA_CSRC_OPS_VIEW_H_ 28 | -------------------------------------------------------------------------------- /torch_xla/csrc/random.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_RANDOM_H_ 2 | #define XLA_TORCH_XLA_CSRC_RANDOM_H_ 3 | 4 | #include "xla/hlo/builder/xla_builder.h" 5 | 6 | namespace torch_xla { 7 | 8 | // Set downcast to true if the caller knows the |maxval - minval| is appropriate 9 | // for f16 dtype. We avoid computing the range on-the-fly since it incurs an XLA 10 | // computation. 11 | xla::XlaOp RngUniform(xla::XlaOp seed, const xla::Shape& shape, 12 | xla::XlaOp minval, xla::XlaOp maxval, 13 | bool downcast = false); 14 | 15 | xla::XlaOp RngDiscreteUniform(xla::XlaOp seed, const xla::Shape& shape, 16 | xla::XlaOp minval, xla::XlaOp maxval); 17 | 18 | xla::XlaOp RngNormal(xla::XlaOp seed, const xla::Shape& shape, xla::XlaOp mean, 19 | xla::XlaOp std); 20 | 21 | } // namespace torch_xla 22 | 23 | #endif // XLA_TORCH_XLA_CSRC_RANDOM_H_ 24 | -------------------------------------------------------------------------------- /torch_xla/csrc/runtime/env_hash.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace torch_xla { 4 | namespace runtime { 5 | namespace hash { 6 | 7 | // Take a hash of XLA flags which impact the compilation result. 8 | // TODO(jonbolin): We should move away from manually hashing the env vars and 9 | // instead hash the compilation environment directly when the functionality is 10 | // supported in the runtime. 
11 | torch::lazy::hash_t HashXlaEnvVars();
12 | 
13 | }  // namespace hash
14 | }  // namespace runtime
15 | }  // namespace torch_xla
16 | 
--------------------------------------------------------------------------------
/torch_xla/csrc/runtime/metrics_reader.h:
--------------------------------------------------------------------------------
 1 | #ifndef XLA_CLIENT_METRICS_READER_H_
 2 | #define XLA_CLIENT_METRICS_READER_H_
 3 | 
 4 | #include <map>
 5 | #include <string>
 6 | #include <vector>
 7 | 
 8 | #include "torch_xla/csrc/runtime/metrics.h"
 9 | #include "torch_xla/csrc/runtime/types.h"
10 | 
11 | namespace torch_xla {
12 | namespace runtime {
13 | namespace metrics_reader {
14 | 
15 | // Creates a report with the current metrics statistics.
16 | std::string CreateMetricReport(
17 |     const std::map<std::string, Metric>& xrt_metrics);
18 | 
19 | // Creates a report with the selected metrics statistics.
20 | std::string CreateMetricReport(const std::vector<std::string>& counter_names,
21 |                                const std::vector<std::string>& metric_names);
22 | 
23 | }  // namespace metrics_reader
24 | }  // namespace runtime
25 | }  // namespace torch_xla
26 | 
27 | #endif  // XLA_CLIENT_METRICS_READER_H_
28 | 
--------------------------------------------------------------------------------
/torch_xla/csrc/runtime/pjrt_registry.h:
--------------------------------------------------------------------------------
 1 | #ifndef XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_
 2 | #define XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_
 3 | 
 4 | #include "torch_xla/csrc/runtime/xla_coordinator.h"
 5 | #include "xla/pjrt/pjrt_client.h"
 6 | #include "xla/pjrt/pjrt_common.h"
 7 | 
 8 | namespace torch_xla {
 9 | namespace runtime {
10 | 
11 | class PjRtPlugin {
12 |  public:
13 |   virtual std::string library_path() const = 0;
14 | 
15 |   virtual const std::unordered_map<std::string, xla::PjRtValueType>
16 |   client_create_options() const = 0;
17 | 
18 |   virtual bool requires_xla_coordinator() const = 0;
19 | };
20 | 
21 | void RegisterPjRtPlugin(std::string name,
22 |                         std::shared_ptr<const PjRtPlugin> plugin);
23 | 
24 | std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
25 | InitializePjRt(const std::string& device_type);
26 | 
27 | }  // namespace runtime
28 | }  // namespace torch_xla
29 | 
30 | #endif  // XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_
31 | 
--------------------------------------------------------------------------------
/torch_xla/csrc/runtime/runtime.h:
--------------------------------------------------------------------------------
 1 | #ifndef XLA_CLIENT_RUNTIME_H_
 2 | #define XLA_CLIENT_RUNTIME_H_
 3 | 
 4 | #include "torch_xla/csrc/runtime/computation_client.h"
 5 | 
 6 | namespace torch_xla {
 7 | namespace runtime {
 8 | 
 9 | // Returns the ComputationClient singleton.
10 | ComputationClient* GetComputationClient();
11 | 
12 | ComputationClient* GetComputationClientIfInitialized();
13 | 
14 | // Run the XRT local service. This will block the caller until the server
15 | // is stopped.
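// Illustrative use (not from this file): because the call blocks until the
// service is shut down, it is normally issued from a dedicated process or
// thread; the port number below is an arbitrary example value.
//
//   std::thread server([] { torch_xla::runtime::RunLocalService(51011); });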
16 | void RunLocalService(uint64_t service_port); 17 | 18 | } // namespace runtime 19 | } // namespace torch_xla 20 | 21 | #endif 22 | -------------------------------------------------------------------------------- /torch_xla/csrc/runtime/stablehlo_composite_helper.h: -------------------------------------------------------------------------------- 1 | #ifndef STABLEHLO_COMPOSITE_HELPER_H_ 2 | #define STABLEHLO_COMPOSITE_HELPER_H_ 3 | 4 | #include "mlir/Dialect/Func/IR/FuncOps.h" 5 | #include "mlir/IR/BuiltinOps.h" 6 | #include "mlir/Pass/Pass.h" 7 | 8 | namespace torch_xla { 9 | namespace runtime { 10 | 11 | std::unique_ptr> 12 | CreateBuildStableHLOCompositePass(); 13 | 14 | std::unique_ptr> 15 | CreateRemoveXlaMarkTensorOpsPass(); 16 | 17 | } // namespace runtime 18 | } // namespace torch_xla 19 | 20 | #endif 21 | -------------------------------------------------------------------------------- /torch_xla/csrc/runtime/tf_logging.cpp: -------------------------------------------------------------------------------- 1 | #include "torch_xla/csrc/runtime/tf_logging.h" 2 | 3 | #include 4 | 5 | namespace torch_xla { 6 | namespace runtime { 7 | namespace internal { 8 | 9 | void ErrorGenerator::operator&(const std::basic_ostream<char>& oss) const { 10 | const ErrorSink& sink = dynamic_cast<const ErrorSink&>(oss); 11 | auto sink_str = sink.str(); 12 | TF_VLOG(1) << sink_str; 13 | std::stringstream ess; 14 | ess << file_ << ":" << line_ << " : " << sink_str; 15 | // We cannot use AT_ERROR() here due to layering issues. 16 | throw std::runtime_error(ess.str()); 17 | } 18 | 19 | } // namespace internal 20 | } // namespace runtime 21 | } // namespace torch_xla 22 | -------------------------------------------------------------------------------- /torch_xla/csrc/runtime/xla_mlir_debuginfo_helper.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_MLIR_DEBUGINFO_HELPER_H_ 2 | #define XLA_MLIR_DEBUGINFO_HELPER_H_ 3 | 4 | #include "mlir/Dialect/Func/IR/FuncOps.h" 5 | #include "mlir/IR/BuiltinOps.h" 6 | #include "mlir/Pass/Pass.h" 7 | 8 | namespace torch_xla { 9 | namespace runtime { 10 | 11 | std::unique_ptr> 12 | CreatePrepareXlaMlirDebuginfoPass(); 13 | 14 | } // namespace runtime 15 | } // namespace torch_xla 16 | 17 | #endif 18 | -------------------------------------------------------------------------------- /torch_xla/csrc/shape_builder.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_SHAPE_BUILDER_H_ 2 | #define XLA_TORCH_XLA_CSRC_SHAPE_BUILDER_H_ 3 | 4 | #include <vector> 5 | 6 | #include "absl/types/span.h" 7 | #include "xla/shape.h" 8 | #include "xla/types.h" 9 | 10 | namespace torch_xla { 11 | 12 | class ShapeBuilder { 13 | public: 14 | explicit ShapeBuilder(xla::PrimitiveType type) : type_(type) {} 15 | 16 | ShapeBuilder& Add(const xla::Shape& shape, int64_t dim); 17 | 18 | ShapeBuilder& Add(const xla::Shape& shape, 19 | absl::Span<const int64_t> dimensions); 20 | 21 | ShapeBuilder& Add(int64_t size); 22 | 23 | xla::Shape Build() const; 24 | 25 | private: 26 | struct ShapeDim { 27 | const xla::Shape* shape = nullptr; 28 | int64_t dim_or_size = -1; 29 | }; 30 | 31 | xla::PrimitiveType type_; 32 | std::vector<ShapeDim> dims_; 33 | }; 34 | 35 | } // namespace torch_xla 36 | 37 | #endif // XLA_TORCH_XLA_CSRC_SHAPE_BUILDER_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/shape_helper.cpp: -------------------------------------------------------------------------------- 1 | #include
"torch_xla/csrc/shape_helper.h" 2 | 3 | #include "torch_xla/csrc/runtime/debug_macros.h" 4 | #include "xla/hlo/builder/xla_builder.h" 5 | 6 | namespace torch_xla { 7 | 8 | const xla::Shape& ShapeHelper::ShapeOfXlaOp(xla::XlaOp op) { 9 | const xla::Shape* shape = ConsumeValue(op.builder()->GetShapePtr(op)); 10 | return *shape; 11 | } 12 | 13 | } // namespace torch_xla 14 | -------------------------------------------------------------------------------- /torch_xla/csrc/shape_helper.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_SHAPE_HELPER_H_ 2 | #define XLA_TORCH_XLA_SHAPE_HELPER_H_ 3 | 4 | #include "xla/hlo/builder/xla_builder.h" 5 | 6 | namespace torch_xla { 7 | 8 | class ShapeHelper { 9 | public: 10 | // Returns the shape of the given XLA operation. 11 | static const xla::Shape& ShapeOfXlaOp(xla::XlaOp op); 12 | }; 13 | 14 | } // namespace torch_xla 15 | 16 | #endif // XLA_TORCH_XLA_SHAPE_HELPER_H_ 17 | -------------------------------------------------------------------------------- /torch_xla/csrc/softmax_builder.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_SOFTMAX_BUILDER_H_ 2 | #define XLA_TORCH_XLA_CSRC_SOFTMAX_BUILDER_H_ 3 | 4 | #include "xla/hlo/builder/xla_builder.h" 5 | 6 | namespace torch_xla { 7 | 8 | // Computes log(softmax(logits)) along the dimension specified by "dim". 9 | xla::XlaOp BuildLogSoftmax(xla::XlaOp logits, int64_t dim); 10 | 11 | // Computes the gradient of the input of the LogSoftmax function. 12 | xla::XlaOp BuildLogSoftmaxGrad(xla::XlaOp grad_output, xla::XlaOp output, 13 | int64_t dim); 14 | 15 | xla::XlaOp BuildSoftmax(xla::XlaOp logits, int64_t dim); 16 | 17 | xla::XlaOp BuildSoftmaxGrad(xla::XlaOp grad_output, xla::XlaOp output, 18 | int64_t dim); 19 | 20 | } // namespace torch_xla 21 | 22 | #endif // XLA_TORCH_XLA_CSRC_SOFTMAX_BUILDER_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/tensor_common.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_TENSOR_COMMON_H_ 2 | #define XLA_TORCH_XLA_CSRC_TENSOR_COMMON_H_ 3 | 4 | #include 5 | 6 | #include "xla/hlo/builder/xla_builder.h" 7 | 8 | namespace torch_xla { 9 | 10 | // XLA SPMD sharding spec annotation. The XLA tensor uses this to create 11 | // HloSharding for replication, manual and tile shardings. 12 | struct ShardingSpec { 13 | ShardingSpec(const xla::OpSharding& sharding) : sharding(sharding) {} 14 | ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape) 15 | : sharding(sharding), shape(shape) {} 16 | 17 | xla::OpSharding sharding; 18 | // Optional unpartitioned source tensor shape.
19 | std::optional shape; 20 | }; 21 | 22 | using ShardingSpecPtr = std::shared_ptr; 23 | } // namespace torch_xla 24 | 25 | #endif // XLA_TORCH_XLA_CSRC_TENSOR_COMMON_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/thread_pool.cpp: -------------------------------------------------------------------------------- 1 | #include "torch_xla/csrc/thread_pool.h" 2 | 3 | #include 4 | 5 | #include "torch_xla/csrc/runtime/sys_util.h" 6 | #include "tsl/platform/env.h" 7 | #include "tsl/platform/threadpool.h" 8 | 9 | namespace torch_xla { 10 | namespace thread { 11 | 12 | void Schedule(std::function fn) { 13 | static size_t num_threads = torch_xla::runtime::sys_util::GetEnvInt( 14 | "XLA_THREAD_POOL_SIZE", std::thread::hardware_concurrency()); 15 | static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla", 16 | num_threads); 17 | pool.Schedule(std::move(fn)); 18 | } 19 | 20 | } // namespace thread 21 | } // namespace torch_xla 22 | -------------------------------------------------------------------------------- /torch_xla/csrc/thread_pool.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_CLIENT_THREAD_POOL_H_ 2 | #define XLA_CLIENT_THREAD_POOL_H_ 3 | 4 | #include 5 | 6 | namespace torch_xla { 7 | namespace thread { 8 | 9 | // Schedules a closure to be run. The closure should not block waiting for other 10 | // events. 11 | void Schedule(std::function fn); 12 | 13 | } // namespace thread 14 | } // namespace torch_xla 15 | 16 | #endif // XLA_CLIENT_THREAD_POOL_H_ 17 | -------------------------------------------------------------------------------- /torch_xla/csrc/token_handler.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_TOKEN_HANDLER_H_ 2 | #define XLA_TORCH_XLA_CSRC_TOKEN_HANDLER_H_ 3 | 4 | #include "xla/hlo/builder/xla_builder.h" 5 | 6 | namespace torch_xla { 7 | 8 | class TokenHandler { 9 | public: 10 | explicit TokenHandler(xla::XlaOp token) : token_(token) {} 11 | 12 | xla::XlaOp GetInput(xla::XlaOp input, const xla::Shape* input_shape); 13 | 14 | xla::XlaOp GetNewToken(xla::XlaOp result); 15 | 16 | private: 17 | xla::XlaOp token_; 18 | }; 19 | 20 | } // namespace torch_xla 21 | 22 | #endif // XLA_TORCH_XLA_CSRC_TOKEN_HANDLER_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/unwrap_data.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_UNWRAP_DATA_H_ 2 | #define XLA_TORCH_XLA_CSRC_UNWRAP_DATA_H_ 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | #include "absl/types/span.h" 10 | #include "torch_xla/csrc/runtime/computation_client.h" 11 | 12 | namespace torch_xla { 13 | 14 | runtime::ComputationClient::DataPtr UnwrapXlaData( 15 | const torch::lazy::BackendDataPtr& data); 16 | 17 | std::vector UnwrapXlaData( 18 | absl::Span datas); 19 | 20 | std::vector WrapXlaData( 21 | absl::Span xla_datas); 22 | 23 | } // namespace torch_xla 24 | 25 | #endif // XLA_TORCH_XLA_CSRC_UNWRAP_DATA_H 26 | -------------------------------------------------------------------------------- /torch_xla/csrc/version.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_VERSION_H_ 2 | #define XLA_TORCH_XLA_CSRC_VERSION_H_ 3 | 4 | namespace torch_xla { 5 | 6 | extern const char XLA_GITREV[]; 7 | extern const char TORCH_GITREV[]; 8 | 9 | } // namespace torch_xla 10 | 11 | 
#endif // XLA_TORCH_XLA_CSRC_VERSION_H_ -------------------------------------------------------------------------------- /torch_xla/csrc/xla_backend_impl.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_XLA_BACKEND_IMPL_H_ 2 | #define XLA_TORCH_XLA_CSRC_XLA_BACKEND_IMPL_H_ 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "torch_xla/csrc/device.h" 10 | #include "torch_xla/csrc/runtime/computation_client.h" 11 | 12 | namespace torch_xla { 13 | 14 | torch::lazy::BackendImplInterface* GetXlaBackendImpl(); 15 | 16 | bool InitXlaBackend(); 17 | 18 | } // namespace torch_xla 19 | 20 | #endif // XLA_TORCH_XLA_CSRC_XLA_BACKEND_IMPL_H_ 21 | -------------------------------------------------------------------------------- /torch_xla/csrc/xla_op_builder.h: -------------------------------------------------------------------------------- 1 | #ifndef XLA_TORCH_XLA_CSRC_XLA_OP_BUILDER_H_ 2 | #define XLA_TORCH_XLA_CSRC_XLA_OP_BUILDER_H_ 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "xla/hlo/builder/xla_builder.h" 11 | 12 | namespace torch_xla { 13 | namespace op_builder { 14 | 15 | using BuilderPtr = std::shared_ptr; 16 | 17 | struct Op { 18 | Op(BuilderPtr builder, xla::XlaOp op) 19 | : builder(std::move(builder)), op(std::move(op)) {} 20 | 21 | BuilderPtr builder; 22 | xla::XlaOp op; 23 | }; 24 | 25 | using OpPtr = std::shared_ptr; 26 | 27 | py::object ShapeToPyShape(const xla::Shape& shape); 28 | 29 | xla::Shape PyShapeToShape(py::object shape); 30 | 31 | OpPtr CreateOp(BuilderPtr builder, const std::string& opname, 32 | const std::vector& operands, py::dict args); 33 | 34 | } // namespace op_builder 35 | } // namespace torch_xla 36 | 37 | #endif // XLA_TORCH_XLA_CSRC_XLA_OP_BUILDER_H_ -------------------------------------------------------------------------------- /torch_xla/debug/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torch_xla/debug/__init__.py -------------------------------------------------------------------------------- /torch_xla/debug/graph_saver.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import threading 4 | import torch_xla 5 | 6 | _SAVE_GRAPH_LOCK = threading.Lock() 7 | _SAVE_GRAPH_IDS = collections.defaultdict(dict) 8 | 9 | 10 | def save_tensors_graph(save_dir, name, tensors): 11 | fmt = os.environ.get('SAVE_GRAPH_FMT', 'text') 12 | if fmt == 'text': 13 | graph = torch_xla._XLAC._get_xla_tensors_text(tensors) 14 | elif fmt == 'dot': 15 | graph = torch_xla._XLAC._get_xla_tensors_dot(tensors) 16 | elif fmt == 'hlo': 17 | graph = torch_xla._XLAC._get_xla_tensors_hlo(tensors) 18 | else: 19 | raise RuntimeError('Invalid save graph format: {}'.format(fmt)) 20 | tid = threading.current_thread().ident 21 | with _SAVE_GRAPH_LOCK: 22 | tdict = _SAVE_GRAPH_IDS[tid] 23 | id = tdict.get(name, 0) 24 | tdict[name] = id + 1 25 | fname = '{}-{}-{}.{}'.format(name, tid, id, fmt) 26 | with open(os.path.join(save_dir, fname), 'w') as fd: 27 | fd.write(graph) 28 | -------------------------------------------------------------------------------- /torch_xla/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torch_xla/distributed/__init__.py 
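A minimal usage sketch for the graph_saver helper above (not part of the repository): it saves the lazy IR graph of an XLA tensor in 'dot' format. The output directory, tensor shapes, graph name, and the use of torch_xla.device() are illustrative assumptions.

import os
import torch
import torch_xla
from torch_xla.debug.graph_saver import save_tensors_graph

os.makedirs('/tmp/xla_graphs', exist_ok=True)
os.environ['SAVE_GRAPH_FMT'] = 'dot'  # 'text' (default), 'dot' or 'hlo'

t = torch.randn(4, 4, device=torch_xla.device())
u = torch.relu(t @ t)
# Writes e.g. relu_matmul-<thread id>-0.dot under /tmp/xla_graphs.
save_tensors_graph('/tmp/xla_graphs', 'relu_matmul', [u])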
-------------------------------------------------------------------------------- /torch_xla/distributed/fsdp/__init__.py: -------------------------------------------------------------------------------- 1 | from .xla_fully_sharded_data_parallel import XlaFullyShardedDataParallel 2 | from .state_dict_utils import (consolidate_sharded_state_dicts, 3 | consolidate_sharded_model_checkpoints) 4 | from .utils import checkpoint_module 5 | 6 | __all__ = [ 7 | "XlaFullyShardedDataParallel", 8 | "consolidate_sharded_state_dicts", 9 | "consolidate_sharded_model_checkpoints", 10 | "checkpoint_module", 11 | ] 12 | -------------------------------------------------------------------------------- /torch_xla/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | from .eager import eager_mode, is_eager_mode, eager_mode_context 2 | 3 | __all__ = [ 4 | "eager_mode", 5 | "is_eager_mode", 6 | "eager_mode_context", 7 | ] 8 | -------------------------------------------------------------------------------- /torch_xla/experimental/distributed_checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | from .manager import CheckpointManager 2 | from .planners import SPMDSavePlanner, SPMDLoadPlanner 3 | from .util import prime_optimizer 4 | 5 | __all__ = [ 6 | "CheckpointManager", 7 | "SPMDSavePlanner", 8 | "SPMDLoadPlanner", 9 | "prime_optimizer", 10 | ] 11 | -------------------------------------------------------------------------------- /torch_xla/experimental/dynamo_mark_sharding.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.warn(
 4 | "dynamo_mark_sharding will be auto registered starting 2.5 release, please remove this import" 5 | ) 6 | -------------------------------------------------------------------------------- /torch_xla/experimental/dynamo_set_buffer_donor.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.warn( 4 | "dynamo_set_buffer_donor_ will be auto registered starting 2.5 release, please remove this import" 5 | ) 6 | -------------------------------------------------------------------------------- /torch_xla/experimental/eager.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from contextlib import contextmanager 3 | 4 | import torch_xla 5 | import logging 6 | 7 | 8 | def eager_mode(enable: bool): 9 | """Configure torch_xla's default execution mode. 10 | 11 | Under eager mode, only functions that were `torch_xla.compile`d will be 12 | traced and compiled. Other torch ops will be executed eagerly. 13 | """ 14 | torch_xla._XLAC._set_use_eager_mode(enable) 15 | 16 | 17 | def is_eager_mode() -> bool: 18 | """Return True if torch_xla is currently under eager mode. 19 | """ 20 | return torch_xla._XLAC._get_use_eager_mode() 21 | 22 | 23 | @contextmanager 24 | def eager_mode_context(enable: bool): 25 | """Context manager to enable/disable the eager mode.
26 | """ 27 | saved_eager_mode = is_eager_mode() 28 | eager_mode(enable) 29 | try: 30 | yield saved_eager_mode 31 | finally: 32 | eager_mode(saved_eager_mode) 33 | -------------------------------------------------------------------------------- /torch_xla/experimental/pallas_kernels/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torch_xla/experimental/pallas_kernels/__init__.py -------------------------------------------------------------------------------- /torch_xla/experimental/pjrt_backend.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.distributed as dist 4 | from torch_xla.distributed import xla_backend 5 | from torch_xla._internal import rendezvous 6 | from torch_xla._internal import tpu 7 | 8 | if tpu.num_available_chips() > 0 and tpu.version() <= 3: 9 | from torch.testing._internal.distributed import multi_threaded_pg 10 | logging.warning('Patching torch.distributed state to support multithreading.') 11 | logging.warning('torch.distributed support on TPU v2 and v3 is experimental ' 12 | 'and does not support torchrun.') 13 | multi_threaded_pg._install_threaded_pg() 14 | 15 | dist.register_rendezvous_handler('pjrt', rendezvous.pjrt_rendezvous_handler) 16 | -------------------------------------------------------------------------------- /torch_xla/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torch_xla/test/__init__.py -------------------------------------------------------------------------------- /torch_xla/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torch_xla/utils/__init__.py -------------------------------------------------------------------------------- /torch_xla/utils/buffer_donor_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | import torch_xla 4 | 5 | 6 | @contextmanager 7 | def alias_with_buffer_donor_config(should_alias: bool = True): 8 | saved_config = torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config( 9 | ) 10 | torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(should_alias) 11 | try: 12 | yield saved_config 13 | finally: 14 | torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(saved_config) 15 | -------------------------------------------------------------------------------- /torchax/build_nightly.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | 4 | NIGHTLY_VERSION=$(date '+%Y%m%d%H%M') 5 | 6 | # Update the version to .devYYYYMMDDHHMM in __init__.py 7 | VERSION_UPDATE_PATTERN="s/^__version__\s*=\s*\"([^\"]+)\"/__version__ = \"\1.dev$NIGHTLY_VERSION\"/g;" 8 | sed -r "$VERSION_UPDATE_PATTERN" torchax/__init__.py --in-place 9 | 10 | hatch build -t wheel 11 | -------------------------------------------------------------------------------- /torchax/dev-requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://download.pytorch.org/whl/torch 2 | torch==2.6.0 ; sys_platform == 'darwin' # macOS 3 | torch==2.6.0+cpu; sys_platform != 
'darwin' # Non-macOS (CPU-only), like on TPU 4 | yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml` 5 | flax==0.10.6 6 | -------------------------------------------------------------------------------- /torchax/docs/dispatch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torchax/docs/dispatch.png -------------------------------------------------------------------------------- /torchax/docs/torch_dispatch/run_env.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchax 3 | 4 | env = torchax.default_env() 5 | env.config.debug_print_each_op = True 6 | env.config.debug_accuracy_for_each_op = True 7 | 8 | with env: 9 | y = torch.tensor([1, 5, 10]) 10 | print(torch.trapezoid(y)) 11 | print(torch.trapz(y, y)) 12 | -------------------------------------------------------------------------------- /torchax/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torchax/examples/__init__.py -------------------------------------------------------------------------------- /torchax/examples/eager_mode.py: -------------------------------------------------------------------------------- 1 | import torchax 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import torch 5 | 6 | xla_env = torchax.enable_globally() 7 | 8 | 9 | class MyModel(nn.Module): 10 | 11 | def __init__(self): 12 | super().__init__() 13 | self.fc1 = nn.Linear(28 * 28, 120) 14 | self.fc2 = nn.Linear(120, 84) 15 | self.fc3 = nn.Linear(84, 10) 16 | 17 | def forward(self, x): 18 | x = x.view(-1, 28 * 28) 19 | x = F.relu(self.fc1(x)) 20 | x = F.relu(self.fc2(x)) 21 | x = self.fc3(x) 22 | return x 23 | 24 | 25 | m = MyModel() 26 | m = m.to('jax') 27 | 28 | # Execute this model using torch 29 | inputs = torch.randn(3, 3, 28, 28, device='jax') 30 | 31 | print(m(inputs)) 32 | print('---=====') 33 | 34 | m_compiled = torchax.compile(m) 35 | 36 | print(m_compiled(inputs)) 37 | 38 | print('---') 39 | -------------------------------------------------------------------------------- /torchax/examples/requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | matplotlib 3 | optax -------------------------------------------------------------------------------- /torchax/examples/train_gpt/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | git+https://github.com/karpathy/minGPT.git@master 3 | datasets 4 | tiktoken 5 | -------------------------------------------------------------------------------- /torchax/examples/train_llama/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torchax/examples/train_llama/__init__.py -------------------------------------------------------------------------------- /torchax/examples/train_llama_torchtitan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torchax/examples/train_llama_torchtitan/__init__.py 
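For contrast with the torchax run_env.py and eager_mode.py examples above, here is a small sketch (not part of the repository) that uses the per-process default environment as a context manager rather than enable_globally(). The tensor shapes and the flag value are arbitrary, and it assumes that tensors created with device='jax' inside the env context are handled by torchax.

import torch
import torchax

env = torchax.default_env()
env.config.debug_print_each_op = False  # Configuration flag; see torchax/config.py later in this listing

with env:  # torch ops inside this block are dispatched through torchax to JAX
  a = torch.randn(2, 3, device='jax')
  b = torch.randn(3, 2, device='jax')
  print(a @ b)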
-------------------------------------------------------------------------------- /torchax/format.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | 4 | yapf --recursive -i *.py test torchax -------------------------------------------------------------------------------- /torchax/test-requirements.txt: -------------------------------------------------------------------------------- 1 | -r dev-requirements.txt 2 | absl-py==2.2.2 3 | immutabledict==4.2.1 4 | pytest==8.3.5 5 | pytest-xdist==3.6.1 6 | pytest-forked==1.6.0 7 | sentencepiece==0.2.0 8 | expecttest==0.3.0 9 | optax==0.2.4 10 | tensorflow==2.19.0 11 | -------------------------------------------------------------------------------- /torchax/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torchax/test/__init__.py -------------------------------------------------------------------------------- /torchax/test/gemma/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torchax/test/gemma/__init__.py -------------------------------------------------------------------------------- /torchax/test/llama/BUILD: -------------------------------------------------------------------------------- 1 | # TODO(hanq): describe this package. 2 | load( 3 | "//third_party/py/torch/google/bazel_rules/rules_python/python:defs.bzl", 4 | "py_test", 5 | ) 6 | 7 | package( 8 | default_applicable_licenses = ["//devtools/compliance/licenses:no_external_contributions"], 9 | default_visibility = ["//visibility:public"], 10 | licenses = ["notice"], 11 | ) 12 | 13 | py_test( 14 | name = "test_llama", 15 | srcs = [ 16 | "llama_model.py", 17 | "test_llama.py", 18 | ], 19 | deps = [ 20 | "//third_party/py/jax", 21 | "//third_party/py/torch:pytorch", 22 | "//third_party/py/torch/google/_torx", 23 | "//third_party/py/torch/google/_torx/test:test_base", 24 | ], 25 | ) 26 | -------------------------------------------------------------------------------- /torchax/test/llama/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torchax/test/llama/__init__.py -------------------------------------------------------------------------------- /torchax/test/moe/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torchax/test/moe/__init__.py -------------------------------------------------------------------------------- /torchax/test_dist/README.md: -------------------------------------------------------------------------------- 1 | This directory contains multi-accelerator tests that cannot be distributed with 2 | `pytest-xdist`. 
3 | 4 | TODO: merge these into `tests/` 5 | -------------------------------------------------------------------------------- /torchax/test_dist/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/xla/d4c1be3776f88b74cb0b5e693afeb6a75534ee36/torchax/test_dist/__init__.py -------------------------------------------------------------------------------- /torchax/torchax/config.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | 4 | @dataclasses.dataclass 5 | class Configuration: 6 | debug_print_each_op: bool = False 7 | debug_accuracy_for_each_op: bool = False 8 | debug_mixed_tensor: bool = False 9 | debug_print_each_op_operands: bool = False 10 | use_int32_for_index: bool = False 11 | 12 | # Flash attention 13 | use_tpu_flash_attention: bool = False 14 | shmap_flash_attention: bool = False 15 | 16 | # device 17 | treat_cuda_as_jax_device: bool = True 18 | use_torch_native_for_cpu_tensor: bool = True 19 | internal_respect_torch_return_dtypes: bool = False 20 | -------------------------------------------------------------------------------- /torchax/torchax/device_module.py: -------------------------------------------------------------------------------- 1 | def _is_in_bad_fork(): 2 | return False 3 | 4 | 5 | def manual_seed_all(seed): 6 | pass 7 | 8 | 9 | def device_count(): 10 | return 1 11 | 12 | 13 | def get_rng_state(): 14 | return [] 15 | 16 | 17 | def set_rng_state(new_state, device): 18 | pass 19 | 20 | 21 | def is_available(): 22 | return True 23 | 24 | 25 | def current_device(): 26 | return 0 27 | -------------------------------------------------------------------------------- /torchax/torchax/ops/__init__.py: -------------------------------------------------------------------------------- 1 | def all_aten_jax_ops(): 2 | # to load the ops 3 | import torchax.ops.jaten # type: ignore 4 | import torchax.ops.ops_registry # type: ignore 5 | 6 | return { 7 | key: val.func 8 | for key, val in torchax.ops.ops_registry.all_aten_ops.items() 9 | if val.is_jax_function 10 | } 11 | -------------------------------------------------------------------------------- /torchax/torchax/types.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Any, Union, ParamSpec, TypeAlias 2 | import torch 3 | import jax 4 | import jax.numpy as jnp 5 | import sys 6 | 7 | P = ParamSpec('P') 8 | 9 | TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any] 10 | TorchCallable: TypeAlias = Callable[P, TorchValue] 11 | JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, 'JaxCallable', Any] 12 | JaxCallable: TypeAlias = Callable[P, JaxValue] --------------------------------------------------------------------------------
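To close the listing, a brief hedged sketch (not part of the repository) that ties together the torchax pieces shown above: it prints the Configuration defaults from the default environment and samples a few entries from all_aten_jax_ops(). The key type of the returned mapping is treated as opaque here and only printed.

import torchax
from torchax.ops import all_aten_jax_ops

env = torchax.default_env()
print(env.config)  # Configuration dataclass defaults (torchax/config.py)

jax_ops = all_aten_jax_ops()  # mapping of ATen ops to JAX-backed callables
print(f'{len(jax_ops)} ATen ops have JAX implementations')
for op in list(jax_ops)[:5]:  # peek at a few registry keys
  print(op)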