├── .clang-format ├── .flake8 ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── question.md ├── actions │ ├── atorch-pre-commit │ │ └── action.yml │ └── atorch-python-test │ │ └── action.yml ├── pull_request_template.md └── workflows │ └── main.yml ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── LEGAL.md ├── LICENSE ├── README.md ├── atorch ├── __init__.py ├── amp │ ├── __init__.py │ ├── amp.py │ ├── hook.py │ └── pipe_amp.py ├── auto │ ├── __init__.py │ ├── accelerate.py │ ├── analyser │ │ ├── __init__.py │ │ └── analyser.py │ ├── auto_accelerate_context.py │ ├── clip_grad_norm.py │ ├── device_context.py │ ├── dry_runner │ │ ├── __init__.py │ │ └── dry_runner.py │ ├── engine │ │ ├── __init__.py │ │ ├── acceleration_engine.py │ │ ├── analyser_result.py │ │ ├── client.py │ │ ├── executor.py │ │ ├── optimization_method.py │ │ ├── planner.py │ │ ├── servicer.py │ │ ├── sg_algo │ │ │ ├── __init__.py │ │ │ ├── bayes_opt_sg.py │ │ │ ├── combination_sg.py │ │ │ ├── hebo │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── acq_optimizers │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── evolution_optimizer.py │ │ │ │ ├── acquisitions │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── acq.py │ │ │ │ ├── design_space │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── categorical_param.py │ │ │ │ │ ├── design_space.py │ │ │ │ │ └── param.py │ │ │ │ ├── models │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── base_model.py │ │ │ │ │ ├── gauss_process │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── gpy_wgp.py │ │ │ │ │ ├── layers.py │ │ │ │ │ ├── model_factory.py │ │ │ │ │ ├── random_forest │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── rf.py │ │ │ │ │ └── util.py │ │ │ │ └── optimizers │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── abstract_optimizer.py │ │ │ │ │ ├── hebo.py │ │ │ │ │ └── util.py │ │ │ ├── sg_algo_lib.py │ │ │ ├── sg_algorithm.py │ │ │ └── utils.py │ │ ├── strategy.py │ │ └── task.py │ ├── engine_client.py │ ├── model_context.py │ ├── opt_lib │ │ ├── __init__.py │ │ ├── amp_optimization.py │ │ ├── checkpoint_optimization.py │ │ ├── ds_3d_parallel_optimization.py │ │ ├── dynamo_backends.py │ │ ├── dynamo_optimization.py │ │ ├── half_optimization.py │ │ ├── mixed_parallel_optimization.py │ │ ├── module_replace_optimization.py │ │ ├── optimization.py │ │ ├── optimization_library.py │ │ ├── parallel_mode_optimization.py │ │ ├── pipeline_parallel_optimization.py │ │ ├── selective_offloading_checkpoint.py │ │ ├── sequence_parallel_optimization.py │ │ ├── shard_planners │ │ │ ├── __init__.py │ │ │ ├── base_stage_planner.py │ │ │ ├── base_tp_planner.py │ │ │ ├── dim_planner.py │ │ │ ├── mip_tp_planner.py │ │ │ ├── topology.py │ │ │ └── utils.py │ │ ├── tensor_parallel_optimization.py │ │ ├── utils.py │ │ └── zero_optimization.py │ ├── strategy.py │ └── task.py ├── checkpoint │ ├── __init__.py │ └── torch_checkpoint.py ├── common │ ├── __init__.py │ ├── constants.py │ ├── env.py │ ├── log_utils.py │ ├── singleton.py │ └── util_func.py ├── communication │ ├── __init__.py │ ├── functions.py │ └── pipe_communicator.py ├── data │ ├── __init__.py │ ├── coworker_dataset.py │ ├── data_utils.py │ ├── elastic_dataloader.py │ ├── elastic_dataset.py │ ├── preloader.py │ ├── shm_context.py │ ├── shm_dataloader.py │ ├── unordered_dataloader.py │ └── unshuffled_batch_dataloader.py ├── data_parallel │ ├── __init__.py │ ├── adp.py │ ├── auto_wrap.py │ ├── wrapper.py │ └── zero_ddp_mix_112.py ├── distributed │ ├── __init__.py │ ├── distributed.py │ ├── elastic_controller.py │ ├── elastic_trainer.py │ 
├── hooks.py │ ├── launch.py │ ├── mesh.py │ └── run.py ├── fault_tolerance │ ├── __init__.py │ ├── api.py │ ├── custom_agent.py │ └── hanging_detector.py ├── kernels │ ├── __init__.py │ ├── extensions │ │ ├── __init__.py │ │ ├── abstract_extension.py │ │ ├── flash_atten_1_extension.py │ │ ├── flash_atten_extension.py │ │ ├── flash_attention │ │ │ ├── __init__.py │ │ │ ├── dropout_add_layer_norm.py │ │ │ ├── flash_attn_3_func_ext.py │ │ │ ├── flash_attn_cross_entropy.py │ │ │ └── flash_attn_func_ext.py │ │ ├── flash_attention_1 │ │ │ ├── __init__.py │ │ │ ├── dropout_add_layer_norm_1.py │ │ │ └── flash_attn_func_ext_1.py │ │ ├── grouped_gemm_exts │ │ │ ├── __init__.py │ │ │ └── grouped_gemm_gmm.py │ │ ├── npu │ │ │ ├── __init__.py │ │ │ ├── adamw_npu.py │ │ │ ├── flash_attention_npu.py │ │ │ ├── fused_cross_entropy_npu.py │ │ │ ├── fused_permute_npu.py │ │ │ ├── fused_unpermute_npu.py │ │ │ └── rms_norm_npu.py │ │ ├── npu_extension.py │ │ ├── te │ │ │ ├── __init__.py │ │ │ └── moe_func.py │ │ ├── torch_xla_extension.py │ │ └── xla │ │ │ ├── __init__.py │ │ │ └── flash_attention_xla.py │ ├── patches │ │ ├── __init__.py │ │ └── patch_llama3_fa3.py │ └── triton_jit │ │ ├── __init__.py │ │ ├── atorch_layer_norm.py │ │ ├── bias_gather_add.py │ │ ├── cross_entropy.py │ │ ├── rmsnorm_kernel.py │ │ ├── rope.py │ │ ├── swiglu.py │ │ └── triton_import_lib.py ├── local_sgd │ ├── DDP │ │ ├── __init__.py │ │ ├── outer_optim_model_averager.py │ │ └── stateful_post_localSGD_optimizer.py │ ├── FSDP │ │ ├── __init__.py │ │ ├── core │ │ │ ├── __init__.py │ │ │ ├── _init_utils.py │ │ │ ├── _runtime_utils.py │ │ │ └── _state_dict_utils.py │ │ ├── torch_2_1_0 │ │ │ ├── __init__.py │ │ │ ├── _init_utils.py │ │ │ └── _runtime_utils.py │ │ └── torch_2_4_0 │ │ │ ├── __init__.py │ │ │ ├── _init_utils.py │ │ │ └── _runtime_utils.py │ ├── README.md │ ├── __init__.py │ ├── configs │ │ ├── __init__.py │ │ └── configs.py │ ├── megatron │ │ ├── __init__.py │ │ ├── arguments.py │ │ ├── distributed_data_parallel.py │ │ ├── optimizer.py │ │ ├── parallel_state.py │ │ ├── param_and_grad_buffer.py │ │ ├── timers.py │ │ └── training.py │ └── utils │ │ ├── __init__.py │ │ ├── anomaly_detection.py │ │ └── reduce_methods │ │ ├── __init__.py │ │ ├── base.py │ │ ├── generalized_task_arithmetic.py │ │ ├── linear.py │ │ └── sparsify.py ├── modules │ ├── __init__.py │ ├── distributed_modules │ │ ├── __init__.py │ │ ├── activation_checkpointing.py │ │ ├── compilers │ │ │ ├── __init__.py │ │ │ ├── pipe_compiler │ │ │ │ ├── PipelineStage.py │ │ │ │ ├── StageInterleaver.py │ │ │ │ ├── __init__.py │ │ │ │ ├── distributed_pippy_compiler.py │ │ │ │ └── utils.py │ │ │ └── tp_compiler │ │ │ │ ├── __init__.py │ │ │ │ ├── dtensor_compiler.py │ │ │ │ └── tp_compiler.py │ │ ├── cross_entropy.py │ │ ├── layers.py │ │ ├── mappings.py │ │ ├── mappings_registry.py │ │ ├── materialize_modules.py │ │ ├── modules_registry.py │ │ ├── randomizer.py │ │ ├── transformer.py │ │ └── utils.py │ ├── distributed_transformer │ │ ├── __init__.py │ │ ├── commu_utils.py │ │ └── distributed_attention.py │ ├── fp8 │ │ ├── __init__.py │ │ ├── cuda_kernel.py │ │ ├── precision_switchable_linear.py │ │ ├── quantize.py │ │ ├── quantized_grouped_linear.py │ │ ├── scaled_linear.py │ │ ├── triton_kernel.py │ │ └── utils.py │ ├── moe │ │ ├── __init__.py │ │ ├── ddp.py │ │ ├── grouped_gemm_moe.py │ │ ├── inject.py │ │ ├── moe_layer.py │ │ ├── switch_gating.py │ │ ├── token_dispatcher.py │ │ └── topk_gating.py │ ├── normalization │ │ ├── __init__.py │ │ └── layernorm.py │ └── 
transformer │ │ ├── __init__.py │ │ ├── _fa_api_compat_patch.py │ │ ├── cross_entropy.py │ │ ├── inject.py │ │ ├── layers.py │ │ ├── linear.py │ │ ├── losses.py │ │ └── rmsnorm.py ├── mup │ ├── __init__.py │ ├── infshape.py │ ├── init.py │ ├── module.py │ ├── optim.py │ └── shape.py ├── npu │ ├── __init__.py │ ├── csrc │ │ ├── cann │ │ │ └── gmm.cpp │ │ └── inc │ │ │ └── pytorch_npu_helper.h │ ├── gmm.py │ ├── layers.py │ ├── op_builder │ │ ├── __init__.py │ │ ├── gmm_builder.py │ │ └── npu_builder.py │ └── optim.py ├── ops │ ├── __init__.py │ ├── accelerator │ │ ├── __init__.py │ │ ├── abstract_accelerator.py │ │ ├── cuda_accelerator.py │ │ └── real_accelerator.py │ ├── csrc │ │ ├── includes │ │ │ ├── conversion_utils.h │ │ │ ├── dequantization_utils.h │ │ │ ├── kernel_utils.h │ │ │ ├── memory_access_utils.h │ │ │ ├── quantization.h │ │ │ ├── quantization_optimizer.h │ │ │ ├── quantization_utils.h │ │ │ ├── quantizer.h │ │ │ └── reduction_utils.h │ │ └── quantization │ │ │ ├── dequantize.cu │ │ │ ├── pt_binding.cpp │ │ │ ├── quant_reduce.cu │ │ │ ├── quantization_optimizer.cc │ │ │ ├── quantization_optimizer.cu │ │ │ ├── quantize.cu │ │ │ └── swizzled_quantize.cu │ ├── git_version_info.py │ ├── op_builder │ │ ├── __init__.py │ │ ├── all_ops.py │ │ ├── builder.py │ │ ├── quantization_optimizer.py │ │ └── quantizer.py │ └── quantizer │ │ └── __init__.py ├── optimizers │ ├── __init__.py │ ├── adam_offload.py │ ├── adaml.py │ ├── agd.py │ ├── bf16_optimizer.py │ ├── low_bit │ │ ├── __init__.py │ │ ├── config.py │ │ ├── functional.py │ │ └── optim │ │ │ ├── __init__.py │ │ │ ├── q_adafactor.py │ │ │ ├── q_adamw.py │ │ │ ├── q_agd.py │ │ │ ├── q_came.py │ │ │ └── q_optimizer.py │ ├── utils.py │ └── wsam.py ├── pipeline_parallel │ ├── __init__.py │ ├── pipe_engine.py │ ├── pipe_module.py │ ├── pipe_partition.py │ ├── pipe_schedule.py │ ├── pipe_stage.py │ └── scheduler.py ├── protos │ ├── __init__.py │ ├── acceleration.proto │ ├── coworker.proto │ ├── protobuf_3_20_3 │ │ ├── __init__.py │ │ ├── acceleration_pb2.py │ │ ├── acceleration_pb2_grpc.py │ │ ├── coworker_pb2.py │ │ └── coworker_pb2_grpc.py │ └── protobuf_4_25_3 │ │ ├── __init__.py │ │ ├── acceleration_pb2.py │ │ ├── acceleration_pb2_grpc.py │ │ ├── coworker_pb2.py │ │ └── coworker_pb2_grpc.py ├── requirements.txt ├── rl │ ├── __init__.py │ ├── config.md │ ├── config.py │ ├── data │ │ ├── __init__.py │ │ └── data_utils.py │ ├── ds_hybrid_engine │ │ ├── __init__.py │ │ ├── ds_hook.py │ │ ├── hybrid_engine.py │ │ ├── initialize.py │ │ ├── module_inject │ │ │ ├── __init__.py │ │ │ ├── containers │ │ │ │ ├── __init__.py │ │ │ │ └── llama.py │ │ │ └── utils.py │ │ └── replace_policy.py │ ├── inference_backend │ │ ├── __init__.py │ │ └── vllm_backend.py │ ├── main.py │ ├── model_engine │ │ ├── __init__.py │ │ ├── model_engine.py │ │ └── strategy.py │ ├── model_utils │ │ ├── __init__.py │ │ ├── llama2_utils.py │ │ ├── load_init_model.py │ │ ├── model_util.py │ │ └── redis_util.py │ ├── ppo_utils │ │ ├── __init__.py │ │ └── ppo_util.py │ ├── replay_buffer │ │ ├── __init__.py │ │ └── replay_buffer.py │ └── trainer │ │ ├── __init__.py │ │ ├── ppo_trainer.py │ │ └── rl_trainer.py ├── service │ ├── __init__.py │ ├── coworker_data_service.py │ ├── data_info_service.py │ └── rpc_clients.py ├── tensor_parallel │ ├── __init__.py │ └── manual_tp.py ├── tests │ ├── common_tests │ │ ├── acc_executor_test.py │ │ ├── acc_planner_test.py │ │ ├── acc_strategy_test.py │ │ ├── adp_test.py │ │ ├── amp_test.py │ │ ├── analyser_test.py │ │ ├── 
analyzer_result_test.py │ │ ├── auto_acc_client_test.py │ │ ├── auto_acc_servicer_test.py │ │ ├── auto_accelerate_context_test.py │ │ ├── auto_accelerate_test.py │ │ ├── bf16_optimizer_test.py │ │ ├── bo_sg_test.py │ │ ├── clip_grad_norm_test.py │ │ ├── communicator_test.py │ │ ├── data_utils_test.py │ │ ├── device_context_test.py │ │ ├── dim_planner_test.py │ │ ├── distributed_mappings_test.py │ │ ├── distributed_test.py │ │ ├── distributed_transformer_test.py │ │ ├── dryrunner_test.py │ │ ├── ds_3d_parallel_optimization_test.py │ │ ├── ds_pipe_test.py │ │ ├── elastic_dataset_test.py │ │ ├── engine_client_test.py │ │ ├── fp8 │ │ │ └── test_quantized_grouped_linear.py │ │ ├── fsdp2 │ │ │ ├── device_mesh_test.py │ │ │ ├── fsdp2_test.py │ │ │ └── grad_norm_test.py │ │ ├── fsdp_lora_load_test.py │ │ ├── fsdp_moe_save_load_test.py │ │ ├── fsdp_save_auto_acc_test.py │ │ ├── fsdp_speedup_init_test.py │ │ ├── grouped_gemm_moe_test.py │ │ ├── hebo_test.py │ │ ├── hsdp_moe_save_load_test.py │ │ ├── inspector_test.py │ │ ├── kernel_extension_test.py │ │ ├── local_sgd_anomaly_detection_test.py │ │ ├── local_sgd_ddp_test.py │ │ ├── local_sgd_fsdp_test.py │ │ ├── local_sgd_gta_reducer_test.py │ │ ├── local_sgd_megatron_trainer_test.py │ │ ├── log_util_test.py │ │ ├── manual_tp_test.py │ │ ├── megatron_vocab_test.py │ │ ├── meta_init_test.py │ │ ├── meta_module_test.py │ │ ├── mixed_parallel_test.py │ │ ├── model_context_test.py │ │ ├── module_replace_optimization_test.py │ │ ├── moe_test.py │ │ ├── mup_test.py │ │ ├── normalization_test.py │ │ ├── npu_test.py │ │ ├── optim_test.py │ │ ├── optimizations_test.py │ │ ├── optimizer_offload_test.py │ │ ├── pipeline_parallel_optimization_test.py │ │ ├── pipeline_test.py │ │ ├── pipelining │ │ │ ├── model_registry.py │ │ │ ├── pipe_communicator_test.py │ │ │ ├── pipe_module_test.py │ │ │ ├── pipe_save_load_test.py │ │ │ ├── pipe_stage_test.py │ │ │ └── schedule_test.py │ │ ├── popen_redirect_io_test.py │ │ ├── precision_switch_linear_test.py │ │ ├── preloader_test.py │ │ ├── profiler_test.py │ │ ├── prompt_dataset_test.py │ │ ├── randomizer_test.py │ │ ├── scaled_linear_test.py │ │ ├── selective_checkpoint_test.py │ │ ├── semi_auto_acc_test.py │ │ ├── shm_context_test.py │ │ ├── shm_dataloader_test.py │ │ ├── sparse_tensor_test.py │ │ ├── sync_batch_norm_process_group_test.py │ │ ├── tensor_parallel_operators_test.py │ │ ├── tensor_parallel_optimization_test.py │ │ ├── test_cross_entropy.py │ │ ├── test_dynamic_profile.py │ │ ├── test_event_hook.py │ │ ├── test_loss_spike_utils.py │ │ ├── test_quantize.py │ │ ├── test_rmsnorm.py │ │ ├── tp_activation_checkpoint_test.py │ │ ├── tracer_test.py │ │ ├── trainer_test.py │ │ ├── tune_test.py │ │ ├── unordered_dataloader_test.py │ │ ├── unshuffled_batch_dataloder_test.py │ │ ├── util_test.py │ │ └── zero_optimization_test.py │ ├── glm │ │ └── modeling_glm.py │ ├── npu_ops_test │ │ └── gmm_test.py │ ├── rl_tests │ │ ├── rl_config_test.py │ │ ├── rl_deepspeed_llama2_test.py │ │ ├── rl_llama_model_util_test.py │ │ ├── rl_load_init_model_test.py │ │ ├── rl_model_engine_test.py │ │ ├── rl_ppo_util_test.py │ │ └── rl_replay_buffer_test.py │ ├── test_define_rl_models │ │ ├── __init__.py │ │ ├── independent_models │ │ │ ├── __init__.py │ │ │ ├── hg_model_def.yaml │ │ │ ├── model_def.yaml │ │ │ ├── model_definition.py │ │ │ └── strategy.py │ │ └── share_weights_models │ │ │ ├── actor_critic_ref.py │ │ │ └── config.yaml │ ├── test_modules │ │ ├── __init__.py │ │ ├── test_distributed_selfattn.py │ │ ├── test_flash_attn.py │ │ 
├── test_fused_layernorm.py │ │ ├── test_fused_qkv.py │ │ ├── test_linear.py │ │ ├── test_mixprecision_layer.py │ │ ├── test_moe_ddp.py │ │ └── test_moelayer.py │ ├── toy_modules │ │ ├── __init__.py │ │ ├── toy_for_moe.py │ │ └── toy_module.py │ ├── tp_modules │ │ ├── atorch_mlp.py │ │ ├── fairscale_mlp.py │ │ └── model_args.py │ └── utils │ │ ├── __init__.py │ │ ├── test_numberic_check.py │ │ ├── test_util.py │ │ └── test_utils.py ├── trainer │ ├── __init__.py │ ├── args.py │ ├── atorch_args.py │ ├── atorch_profiler.py │ ├── atorch_trainer.py │ ├── atorch_trainer_v2.py │ ├── base │ │ ├── __init__.py │ │ ├── async_save │ │ │ ├── __init__.py │ │ │ └── megatron_async_save │ │ │ │ ├── __init__.py │ │ │ │ └── megatron_async_torch_save.py │ │ ├── atorch_container.py │ │ ├── atorch_module.py │ │ ├── atorch_train_engine.py │ │ ├── checkpoint.py │ │ ├── ckptloader.py │ │ ├── ckptsaver.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── dist_checkpointing │ │ │ ├── __init__.py │ │ │ └── strategies │ │ │ │ ├── __init__.py │ │ │ │ ├── async_torch_save_strategy.py │ │ │ │ └── async_utils.py │ │ ├── inferface.py │ │ ├── optimizer.py │ │ ├── scheduler.py │ │ └── train_step.py │ ├── debug_utils │ │ ├── __init__.py │ │ └── debug_wrappers.py │ ├── event_util.py │ ├── fsdp │ │ ├── __init__.py │ │ ├── dcp_forked.py │ │ ├── fsdp_ckpt_loader.py │ │ ├── fsdp_ckpt_saver.py │ │ ├── fsdp_shard_async_saver.py │ │ └── fsdp_shard_dcp_loader.py │ ├── megatron │ │ ├── __init__.py │ │ ├── megatron_async_save.py │ │ ├── megatron_ckpt_loader.py │ │ ├── megatron_ckpt_saver.py │ │ ├── megatron_dataloader.py │ │ ├── megatron_train_step.py │ │ └── megatron_wrapper.py │ ├── models │ │ ├── __init__.py │ │ └── atorch_model.py │ ├── state.py │ ├── trainer_callback.py │ └── utils.py └── utils │ ├── __init__.py │ ├── args.py │ ├── clip_grad.py │ ├── config.py │ ├── dataclass_utils.py │ ├── dev_utils.py │ ├── ds_pipe_utils.py │ ├── dynamic_profiler │ ├── __init__.py │ ├── _dynamic_profile.py │ └── _file_monitor.py │ ├── fa_util.py │ ├── fsdp_async_ckpt_util.py │ ├── fsdp_init_util.py │ ├── fsdp_save_util.py │ ├── gc.py │ ├── grad_scaler.py │ ├── graph_transform_utils.py │ ├── grouped_gemm_util.py │ ├── hooks.py │ ├── ib_monitor.py │ ├── import_util.py │ ├── inspector │ ├── __init__.py │ ├── analyze_log.py │ └── hooks.py │ ├── loss_spike_utils.py │ ├── meta_model_utils.py │ ├── meta_overrides.py │ ├── metric_util.py │ ├── moe_util.py │ ├── numberic_checker.py │ ├── parse_trace_json.py │ ├── patch_fairscale.py │ ├── patch_fsdp_param.py │ ├── patch_te.py │ ├── path_utils.py │ ├── pipe_file_utils.py │ ├── prof.py │ ├── rank_reorder │ ├── __init__.py │ ├── reorder.py │ └── util.py │ ├── shape_prop.py │ ├── sharding_spec.py │ ├── sparse.py │ ├── spec_prop.py │ ├── te_checkpoint.py │ ├── timer.py │ ├── tracer.py │ ├── trainer_utils.py │ ├── version.py │ └── virtual_optimizer │ ├── __init__.py │ └── megatron_virtual_optimizer.py ├── dev ├── docker │ ├── Dockerfile-ubuntu2004-pt210 │ ├── README.md │ ├── base │ │ ├── .condarc │ │ ├── Dockerfile │ │ ├── Dockerfile-pt21 │ │ ├── docker.Makefile │ │ ├── pip.conf │ │ └── requirements.txt │ └── handle_driver_compat.sh └── scripts │ ├── build.sh │ ├── build_image_atorch_dev.sh │ ├── build_proto.sh │ ├── import_atorch_after_build.py │ ├── pre-commit.sh │ ├── render_setup.py │ └── test_whl_import.sh ├── docs ├── FA-glm.md ├── README-AGD.md ├── README-EDiT.md ├── README-LOSS-SPIKE-UTIL.md ├── README-WSAM.md ├── auto_accelerate_api.md ├── developer_guide.md ├── feature_required_packages.md └── img │ ├── 
agd_beale.gif │ ├── agd_nanogpt.png │ ├── atorch.png │ ├── atorch_fig.png │ ├── edit_illustration.png │ └── wsam_traj.png ├── examples ├── async_save │ ├── fsdp_shard_async_save.py │ └── fsdp_shard_load.py ├── atorch_trainer_v2 │ ├── baseline_megatron │ │ └── pretrain_llama2_7b.sh │ ├── gpt2_config.yaml │ ├── llama2_7b_config.yaml │ ├── pretrain_atorch_trainer_megatron.py │ └── run.sh ├── auto_accelerate │ ├── README.md │ ├── data.py │ ├── modeling.py │ ├── train.py │ ├── train_gpt2_entry.sh │ ├── train_llama_entry.sh │ ├── train_toy_distributed_entry.sh │ └── train_toy_fully_automatic_entry.sh ├── llama2 │ ├── README.md │ ├── bayes_opt_sg_llama2.py │ ├── bayes_opt_sg_llama2_entry.sh │ ├── dataset_model.sh │ ├── ds_3d_llama2.py │ ├── ds_3d_llama2_entry.sh │ ├── example_utils.py │ ├── fsdp_llama2.py │ ├── fsdp_llama2_entry.sh │ ├── llama2_dummy_data_13b.sh │ ├── llama2_dummy_data_70b.sh │ ├── llama2_dummy_data_7b.sh │ ├── llama2_fsdp2_7b.sh │ ├── llama2_pp.sh │ ├── llama2_pp_uneven_partition.sh │ ├── requirements.txt │ ├── train_llama2_dummy_data.py │ └── train_llama2_with_pp.py ├── llama2_7b_ATorchTrainer │ ├── README.md │ ├── deepspeed_configs │ │ └── ds_config.json │ ├── instruction_dataset_utils.py │ ├── llama2_7b_trainer_entry.sh │ ├── llama2_7b_trainer_lora_entry.sh │ ├── llama2_clm.py │ ├── llama2_clm_atorch_trainer.py │ ├── llama2_trainer.py │ ├── prepare_dataset_and_weight.sh │ └── requirements.txt ├── local_sgd │ ├── atorch_trainer_megatron │ │ ├── gpt2_config.yaml │ │ ├── llama2_7b_config.yaml │ │ ├── pretrain_atorch_trainer_megatron_local_sgd.py │ │ └── run.sh │ └── auto_accelerate │ │ ├── __init__.py │ │ ├── data.py │ │ ├── modeling.py │ │ ├── run_local_sgd.sh │ │ └── train.py ├── moe │ ├── cal_mfu.py │ ├── moe_modules.py │ ├── npu_run_moe.sh │ ├── run_moe_fsdp1_ep.sh │ ├── run_moe_fsdp2_ep.sh │ ├── run_moe_pp.sh │ ├── run_moe_pp_fsdp2.sh │ ├── run_moe_pp_fsdp2_ep.sh │ ├── run_moe_pp_fsdp2_uneven_partition.sh │ ├── train_moe_dummy_data.py │ └── train_moe_with_pp.py ├── nanoGPTATorch │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── agd.py │ ├── assets │ │ ├── gpt2_124M_loss.png │ │ └── nanogpt.jpg │ ├── configurator.py │ ├── model.py │ ├── nadamw.py │ ├── openwebtext │ │ ├── prepare.py │ │ └── readme.md │ ├── prepare_dataset.sh │ ├── sample.py │ ├── scaling_laws.ipynb │ ├── train.py │ ├── train_atorch.py │ ├── train_atorch_entry.sh │ └── transformer_sizing.ipynb └── optimizer │ ├── README.md │ ├── __init__.py │ ├── main.py │ ├── model.py │ ├── train_agd_entry.sh │ ├── train_wsam_entry.sh │ └── utils.py ├── pytest.ini └── setup.py.tpl /.clang-format: -------------------------------------------------------------------------------- 1 | #Run manually to reformat a file: 2 | #clang-format -i --style=file 3 | BasedOnStyle: Google 4 | DerivePointerAlignment: false 5 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, E266, W503, E741, F824 3 | max-line-length = 120 4 | per-file-ignores = __init__.py:F401 atorch/distributed/distributed.py:F401 5 | 6 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # root directory 2 | * @skydoorkai @adamantboy @hxdtest @nash635 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: 
-------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: report 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the unexpected case: 15 | 1. What kind of training? [e.g. FSDP] 16 | 2. The command used? [e.g. dlrover-run xxxxx xxxx] 17 | 3. When and where? 18 | 4. See error 19 | 20 | **Logs or Screenshots** 21 | Logs (necessary) or screenshots to help explain your problem. 22 | 23 | **Expected behavior** 24 | A clear and concise description of what you expected to happen. 25 | 26 | **APP Info (please complete the following information):** 27 | - DLRover: [e.g. 0.3.8] 28 | - Torch: [e.g. 2.1.2] 29 | 30 | **ENV Info (please complete the following information):** 31 | - Platform: [e.g. ubuntu xxx] 32 | - Python: [e.g. 3.8.1] 33 | - GRPC: [e.g. 1.5.x] 34 | 35 | **HARDWARE Info (please complete the following information):** 36 | - Device: [e.g. GPU A100 / NPU Ascend 910] 37 | 38 | **Additional context** 39 | Add any other context about the problem here. 40 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md:
4 | title: '' 5 | labels: question 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /.github/actions/atorch-pre-commit/action.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: atorch-pre-commit 3 | description: run pre-commit to check codes for atorch 4 | runs: 5 | using: 'docker' 6 | image: "easydl/atorch:aci" 7 | args: 8 | - "/bin/bash" 9 | - "-c" 10 | - "sh dev/scripts/pre-commit.sh" 11 | -------------------------------------------------------------------------------- /.github/actions/atorch-python-test/action.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: atorch-python-test 3 | description: run pytest to execute python test cases of atorch python 4 | runs: 5 | using: 'docker' 6 | image: "registry.cn-hangzhou.aliyuncs.com/atorch/atorch-open-20240430:pt210" 7 | args: 8 | - "/bin/bash" 9 | - "-c" 10 | - "pip install dlrover[torch]==0.4.0 --no-deps \ 11 | && echo -e 'import math\ninf = math.inf\nnan = math.nan\nstring_classes = \ 12 | (str, bytes)' > /opt/conda/lib/python3.8/site-packages/torch/_six.py \ 13 | && pip install dependency_injector \ 14 | && PYTHONPATH=. pytest atorch/tests/common_tests" 15 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ### What changes were proposed in this pull request? 2 | 3 | Please describe the changes you have made or proposed in this pull request. 4 | 5 | ### Why are the changes needed? 6 | 7 | Explain the purpose or motivation behind these changes. What problem are you trying to solve? 8 | 9 | ### Does this PR introduce any user-facing change? 10 | 11 | Specify whether this pull request introduces any changes that users will directly interact with or notice. 12 | 13 | ### How was this patch tested? 14 | 15 | Detail the testing process you have undertaken to ensure the changes in this pull request are valid and working as intended. 16 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: CI 3 | 4 | on: 5 | pull_request: 6 | workflow_dispatch: 7 | push: 8 | branches: [master] 9 | 10 | jobs: 11 | python-test: 12 | runs-on: self-hosted 13 | steps: 14 | # This step checks out a copy of your repository. 15 | - uses: actions/checkout@v3 16 | with: 17 | clean: false 18 | # This step references the directory that contains the action. 19 | - uses: ./.github/actions/atorch-python-test 20 | pre-commit: 21 | runs-on: ubuntu-latest 22 | steps: 23 | # This step checks out a copy of your repository. 24 | - uses: actions/checkout@v3 25 | # This step references the directory that contains the action. 
26 | - uses: ./.github/actions/atorch-pre-commit 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .idea* 3 | *egg-info* 4 | dist 5 | build 6 | *~ 7 | *__pycache__* 8 | *.pyc 9 | .mypy_cache 10 | .DS_Store 11 | .cache 12 | .bazelrc 13 | .build_platform 14 | .platform_version 15 | bazel-bin 16 | bazel-out 17 | bazel-testlogs 18 | bazel-xpu_timer 19 | *.whl 20 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | multi_line_output=3 3 | line_length=120 4 | known_third_party = accelerate,agd,apex,datasets,deepspeed,dependency_injector,distutils,dlrover,einops,evaluate,example_utils,fairscale,flash_attn,google,grpc,instruction_dataset_utils,matplotlib,megatron,model,model_registry,moe_modules,networkx,numpy,packaging,pandas,peft,psutil,pytest,redis,safetensors,scipy,seaborn,sklearn,tiktoken,torch,torch_npu,torchvision,tqdm,transformers,triton,typing_extensions,utils,wrapt,yaml 5 | include_trailing_comma=True 6 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | - repo: https://github.com/pre-commit/mirrors-isort 4 | rev: v5.10.1 5 | hooks: 6 | - id: isort 7 | exclude: _pb2.py|_pb2_grpc.py 8 | args: [--settings-path, atorch, "--profile", "black"] 9 | - repo: https://github.com/psf/black 10 | rev: 22.6.0 11 | hooks: 12 | - id: black 13 | exclude: _pb2.py|_pb2_grpc.py 14 | args: [--line-length=120] 15 | - repo: https://github.com/pre-commit/pre-commit-hooks 16 | rev: v2.4.0 17 | hooks: 18 | - id: flake8 19 | exclude: __init__.py|_pb2.py|_pb2_grpc.py 20 | args: [ 21 | "--max-line-length=120", 22 | "--ignore=E721,W503,E203,E266,E741,F824", 23 | ] 24 | - repo: https://github.com/pre-commit/mirrors-mypy 25 | rev: v0.981 26 | hooks: 27 | - id: mypy 28 | exclude: _pb2.py|_pb2_grpc.py|auto/engine/servicer.py 29 | args: [--ignore-missing-imports, --follow-imports=skip, --namespace-packages, --no-strict-optional, --show-error-codes] 30 | additional_dependencies: ["types_requests", "types-PyYAML"] 31 | -------------------------------------------------------------------------------- /LEGAL.md: -------------------------------------------------------------------------------- 1 | Legal Disclaimer 2 | 3 | Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail. 
4 | 5 | 法律免责声明 6 | 7 | 关于代码注释部分，中文注释为官方版本，其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致，当中文注释与其它语言注释存在不一致时，请以中文注释为准。 -------------------------------------------------------------------------------- /atorch/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from importlib.metadata import version 4 | 5 | from .distributed.distributed import coworker_size, init_distributed, local_rank, rank, reset_distributed, world_size 6 | 7 | try: 8 | __version__ = version("atorch") 9 | except ImportError: 10 | __version__ = "0.0.1dev" 11 | 12 | os.environ["PIPPY_PIN_DEVICE"] = "0" 13 | 14 | # Patch with the atorch addon if it exists and is not disabled by the ATORCH_DISABLE_ADDON env. 15 | disable_addon = False 16 | disable_addon_env = os.getenv("ATORCH_DISABLE_ADDON") 17 | if disable_addon_env is not None and disable_addon_env.lower() in ["true", "t", "1", "y", "yes"]: 18 | disable_addon = True 19 | 20 | if disable_addon: 21 | logging.warning("atorch_addon disabled by env ATORCH_DISABLE_ADDON.") 22 | 23 | addon_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "atorch_addon.py") 24 | 25 | if not disable_addon and os.path.exists(addon_file): 26 | try: 27 | import atorch.atorch_addon 28 | except ImportError: 29 | logging.warning("Failed to import atorch_addon!") 30 | -------------------------------------------------------------------------------- /atorch/amp/__init__.py: -------------------------------------------------------------------------------- 1 | from atorch.common.log_utils import default_logger as logger 2 | 3 | try: 4 | from .amp import initialize, load_state_dict, master_params, scale_loss, state_dict 5 | from .hook import sample_list_to_type 6 | except ImportError: 7 | logger.info("Apex not available") 8 | -------------------------------------------------------------------------------- /atorch/amp/hook.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from apex.amp import _initialize 4 | from torch import Tensor 5 | 6 | to_type_original = _initialize.to_type 7 | 8 | 9 | def sample_list_to_type(dtype, t): 10 | """ 11 | Hook `_initialize.to_type`. The original `to_type` only handles the case 12 | that `t` is a torch.Tensor. `sample_list_to_type` can also handle 13 | the case that `t` is a list or a dict.
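    For example (an illustrative sketch; the tensors are made up):

        batch = {"x": torch.ones(2), "y": torch.ones(2, dtype=torch.long)}
        out = sample_list_to_type(torch.float16, batch)
        # out["x"] is cast to float16; out["y"] is unchanged because it is
        # not a floating-point tensor.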
14 | """ 15 | if isinstance(t, Dict): 16 | for k, v in t.items(): 17 | if isinstance(v, Tensor): 18 | if v.is_floating_point(): 19 | t[k] = v.to(dtype) 20 | return t 21 | elif isinstance(t, List): 22 | for i, elem in enumerate(t): 23 | if isinstance(elem, Tensor): 24 | if elem.is_floating_point(): 25 | t[i] = elem.to(dtype) 26 | return t 27 | else: 28 | return to_type_original(dtype, t) 29 | 30 | 31 | _initialize.to_type = sample_list_to_type 32 | -------------------------------------------------------------------------------- /atorch/auto/__init__.py: -------------------------------------------------------------------------------- 1 | from .accelerate import auto_accelerate 2 | from .clip_grad_norm import clip_grad_norm 3 | -------------------------------------------------------------------------------- /atorch/auto/analyser/__init__.py: -------------------------------------------------------------------------------- 1 | from .analyser import Analyser 2 | -------------------------------------------------------------------------------- /atorch/auto/auto_accelerate_context.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | 4 | class AutoAccelerateContext: 5 | """ 6 | AutoAccelerateContext is a global storage for auto_accelerate. 7 | Use AutoAccelerateContext.add_ac_attr to add/update an attribute with name and value. 8 | To access an added attribute with attr_name, use AutoAccelerateContext.attr_name. 9 | Use AutoAccelerateContext.reset to delete all attrs added by add_ac_attr. 10 | """ 11 | 12 | # Number of times the function has been called 13 | counter = 0 14 | 15 | @classmethod 16 | def add_ac_attr(cls, name, value): 17 | if hasattr(cls, name): 18 | cls.name = value 19 | else: 20 | setattr(cls, name, value) 21 | 22 | @classmethod 23 | def reset(cls): 24 | reset_white_list = {"counter", "skip_dryrun"} 25 | method_list = inspect.getmembers(cls, predicate=inspect.ismethod) 26 | method_name_list = [method_tuple[0] for method_tuple in method_list] 27 | for attr in dir(cls): 28 | if attr in reset_white_list: 29 | continue 30 | if not (attr.startswith("__") and attr.endswith("__")) and attr not in method_name_list: 31 | delattr(cls, attr) 32 | -------------------------------------------------------------------------------- /atorch/auto/dry_runner/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/auto/dry_runner/__init__.py -------------------------------------------------------------------------------- /atorch/auto/engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/auto/engine/__init__.py -------------------------------------------------------------------------------- /atorch/auto/engine/analyser_result.py: -------------------------------------------------------------------------------- 1 | class AnalyserResult(object): 2 | """Store analyzer's result""" 3 | 4 | def __init__(self): 5 | self._res = {} 6 | 7 | def get(self, key): 8 | """Return the value of the key""" 9 | _, res = self._search_recursively(self._res, key) 10 | return res 11 | 12 | def _search_recursively(self, res, key): 13 | """DepthFirstSearch the key in res and return the corresponding value""" 14 | if not isinstance(res, dict): 15 | # reach the 
leaf node 16 | return False, None 17 | 18 | if key in res: 19 | return True, res[key] 20 | 21 | for k in res: 22 | found, current_res = self._search_recursively(res[k], key) 23 | if found: 24 | return True, current_res 25 | 26 | return False, None 27 | 28 | def put(self, key, val): 29 | self._res[key] = val 30 | 31 | def update(self, res): 32 | """Update all result 33 | Args: 34 | res(dict): use res to update value 35 | """ 36 | self._res.update(res) 37 | -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/auto/engine/sg_algo/__init__.py -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/README.md: -------------------------------------------------------------------------------- 1 | # hebo 2 | 3 | This module is from Huawei [HEBO](https://github.com/huawei-noah/HEBO). 4 | 5 | Here, we remove the dependency of pytorch. 6 | -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | if not hasattr(np, "float"): 4 | setattr(np, "float", np.float32) 5 | if not hasattr(np, "object"): 6 | setattr(np, "object", object) 7 | -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/acq_optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/auto/engine/sg_algo/hebo/acq_optimizers/__init__.py -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/acquisitions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/auto/engine/sg_algo/hebo/acquisitions/__init__.py -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/design_space/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/auto/engine/sg_algo/hebo/design_space/__init__.py -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/design_space/categorical_param.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .param import Parameter 4 | 5 | 6 | class CategoricalPara(Parameter): 7 | def __init__(self, param): 8 | super().__init__(param) 9 | self.cates = list(param["categories"]) 10 | try: 11 | self._categories_dict = {k: v for v, k in enumerate(self.cates)} 12 | except TypeError: # there are unhashable types 13 | self._categories_dict = None 14 | self.lb = 0 15 | self.ub = len(self.cates) - 1 16 | 17 | def sample(self, num=1): 18 | assert num > 0 19 | return np.random.choice(self.cates, num, replace=True) 20 | 21 | def transform(self, x: np.ndarray): 22 | if 
self._categories_dict: 23 | ret = np.array(list(map(lambda a: self._categories_dict[a], x))) 24 | else: 25 | # otherwise, we fall back to searching in an array 26 | ret_li = list(map(lambda a: np.where(self.cates == a)[0][0], x)) 27 | ret = np.array(ret_li) 28 | return ret.astype(float) 29 | 30 | def inverse_transform(self, x): 31 | return np.array([self.cates[x_] for x_ in x.round().astype(int)]) 32 | 33 | @property 34 | def is_numeric(self): 35 | return False 36 | 37 | @property 38 | def is_discrete(self): 39 | return True 40 | 41 | @property 42 | def is_discrete_after_transform(self): 43 | return True 44 | 45 | @property 46 | def opt_lb(self): 47 | return self.lb 48 | 49 | @property 50 | def opt_ub(self): 51 | return self.ub 52 | 53 | @property 54 | def num_uniqs(self): 55 | return len(self.cates) 56 | -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/design_space/param.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | class Parameter(ABC): 8 | def __init__(self, param_dict): 9 | self.param_dict = param_dict 10 | self.name = param_dict["name"] 11 | pass 12 | 13 | @abstractmethod 14 | def sample(self, num=1) -> pd.DataFrame: 15 | pass 16 | 17 | @abstractmethod 18 | def transform(self, x: np.ndarray) -> np.ndarray: 19 | pass 20 | 21 | @abstractmethod 22 | def inverse_transform(self, x: np.ndarray) -> np.ndarray: 23 | pass 24 | 25 | @property 26 | @abstractmethod 27 | def is_numeric(self) -> bool: 28 | pass 29 | 30 | @property 31 | @abstractmethod 32 | def is_discrete(self) -> bool: 33 | """ 34 | Integer and categorical variable 35 | """ 36 | pass 37 | 38 | @property 39 | @abstractmethod 40 | def is_discrete_after_transform(self) -> bool: 41 | pass 42 | 43 | @property 44 | def is_categorical(self) -> bool: 45 | return not self.is_numeric 46 | 47 | @property 48 | @abstractmethod 49 | def opt_lb(self) -> float: 50 | pass 51 | 52 | @property 53 | @abstractmethod 54 | def opt_ub(self) -> float: 55 | pass 56 | -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/auto/engine/sg_algo/hebo/models/__init__.py -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/models/gauss_process/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/auto/engine/sg_algo/hebo/models/gauss_process/__init__.py -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/models/layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.preprocessing import OneHotEncoder 3 | 4 | 5 | class OneHotTransform(object): 6 | def __init__(self, num_uniqs): 7 | self.num_uniqs = num_uniqs 8 | 9 | @property 10 | def num_out_list(self): 11 | return self.num_uniqs 12 | 13 | @property 14 | def num_out(self) -> int: 15 | return sum(self.num_uniqs) 16 | 17 | def __call__(self, xe): 18 | return np.concatenate( 19 | [ 20 | 
OneHotEncoder(categories=[list(np.arange(self.num_uniqs[i]))]) 21 | .fit_transform(xe[:, i].reshape([-1, 1])) 22 | .toarray() 23 | for i in range(xe.shape[1]) 24 | ], 25 | axis=1, 26 | ).astype("float") 27 | -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/models/model_factory.py: -------------------------------------------------------------------------------- 1 | from atorch.auto.engine.sg_algo.hebo.models.base_model import BaseModel 2 | from atorch.auto.engine.sg_algo.hebo.models.gauss_process.gpy_wgp import GPyGP 3 | from atorch.auto.engine.sg_algo.hebo.models.random_forest.rf import RF 4 | 5 | model_dict = {"gpy": GPyGP, "rf": RF} 6 | 7 | model_names = [k for k in model_dict.keys()] 8 | 9 | 10 | def get_model_class(model_name: str): 11 | 12 | assert model_name in model_dict 13 | model_class = model_dict[model_name] 14 | return model_class 15 | 16 | 17 | def get_model(model_name: str, *params, **conf) -> BaseModel: 18 | model_class = get_model_class(model_name) 19 | return model_class(*params, **conf) 20 | -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/models/random_forest/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/auto/engine/sg_algo/hebo/models/random_forest/__init__.py -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/models/random_forest/rf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.ensemble import RandomForestRegressor 3 | 4 | from atorch.auto.engine.sg_algo.hebo.models.base_model import BaseModel 5 | from atorch.auto.engine.sg_algo.hebo.models.layers import OneHotTransform 6 | from atorch.auto.engine.sg_algo.hebo.models.util import filter_nan 7 | 8 | 9 | class RF(BaseModel): 10 | def __init__(self, num_cont, num_enum, num_out, **conf): 11 | super().__init__(num_cont, num_enum, num_out, **conf) 12 | self.n_estimators = self.conf.get("n_estimators", 100) 13 | self.rf = RandomForestRegressor(n_estimators=self.n_estimators) 14 | self.est_noise = np.zeros(self.num_out) 15 | if self.num_enum > 0: 16 | self.one_hot = OneHotTransform(self.conf["num_uniqs"]) 17 | 18 | def xtrans(self, Xc: np.ndarray, Xe: np.ndarray) -> np.ndarray: 19 | if self.num_enum == 0: 20 | return Xc 21 | else: 22 | Xe_one_hot = self.one_hot(Xe) 23 | if Xc is None: 24 | Xc = np.zeros((Xe.shape[0], 0)) 25 | return np.concatenate([Xc, Xe_one_hot], axis=1) 26 | 27 | def fit(self, Xc: np.ndarray, Xe: np.ndarray, y: np.ndarray): 28 | Xc, Xe, y = filter_nan(Xc, Xe, y, "all") 29 | Xtr = self.xtrans(Xc, Xe) 30 | ytr = y.reshape(-1) 31 | self.rf.fit(Xtr, ytr) 32 | var = (self.rf.predict(Xtr).reshape(-1) - ytr) ** 2 33 | self.est_noise = np.mean(var).reshape(self.num_out) 34 | 35 | @property 36 | def noise(self): 37 | return self.est_noise 38 | 39 | def predict(self, Xc: np.ndarray, Xe: np.ndarray): 40 | X = self.xtrans(Xc, Xe) 41 | mean = self.rf.predict(X).reshape(-1, 1) 42 | preds = [] 43 | for estimator in self.rf.estimators_: 44 | preds.append(estimator.predict(X).reshape([-1, 1])) 45 | var = np.var(np.concatenate(preds, axis=1), axis=1) 46 | return mean.reshape([-1, 1]), var.reshape([-1, 1]) + self.noise 47 | 
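# A minimal usage sketch (illustrative only; the data, sizes, and `n_estimators`
# value are made up): build the RF surrogate through the model factory, fit it
# on continuous inputs, and query its posterior mean/variance.
#
#     import numpy as np
#     from atorch.auto.engine.sg_algo.hebo.models.model_factory import get_model
#
#     model = get_model("rf", num_cont=2, num_enum=0, num_out=1, n_estimators=10)
#     Xc, y = np.random.randn(16, 2), np.random.randn(16, 1)
#     model.fit(Xc, None, y)  # Xe is None because num_enum == 0
#     mean, var = model.predict(np.random.randn(4, 2), None)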
-------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/models/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def filter_nan(x, xe, y, keep_rule="any"): 5 | assert x is None or np.isfinite(x).all() 6 | assert xe is None or np.isfinite(xe).all() 7 | assert np.isfinite(y).any(), "No valid data in the dataset" 8 | 9 | if keep_rule == "any": 10 | valid_id = np.isfinite(y).any(axis=1) 11 | else: 12 | valid_id = np.isfinite(y).all(axis=1) 13 | x_filtered = x[valid_id] if x is not None else None 14 | xe_filtered = xe[valid_id] if xe is not None else None 15 | y_filtered = y[valid_id] 16 | return x_filtered, xe_filtered, y_filtered 17 | -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/auto/engine/sg_algo/hebo/optimizers/__init__.py -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/optimizers/abstract_optimizer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from atorch.auto.engine.sg_algo.hebo.design_space.design_space import DesignSpace 7 | 8 | 9 | class AbstractOptimizer(ABC): 10 | support_parallel_opt = False 11 | support_constraint = False 12 | support_multi_objective = False 13 | support_combinatorial = False 14 | support_contextual = False 15 | 16 | def __init__(self, space: DesignSpace) -> None: 17 | self.space = space 18 | 19 | @abstractmethod 20 | def suggest(self, n_suggestions=1, fix_input: dict = None): 21 | """ 22 | Perform optimisation and give recommendations using data observed so far 23 | --------------------- 24 | n_suggestions: number of recommendations in this iteration 25 | 26 | fix_input: parameters NOT to be optimized, but rather fixed; this 27 | can be used for contextual BO.
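        A typical driver loop (an illustrative sketch; `opt` is any concrete
        subclass and `evaluate` is a user-supplied objective returning an
        (n, 1) array):

            rec = opt.suggest(n_suggestions=1)
            opt.observe(rec, evaluate(rec))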
28 | """ 29 | pass 30 | 31 | @abstractmethod 32 | def observe(self, x: pd.DataFrame, y: np.ndarray): 33 | """ 34 | Observe new data 35 | """ 36 | pass 37 | 38 | @property 39 | @abstractmethod 40 | def best_x(self) -> pd.DataFrame: 41 | pass 42 | 43 | @property 44 | @abstractmethod 45 | def best_y(self) -> float: 46 | pass 47 | -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/hebo/optimizers/util.py: -------------------------------------------------------------------------------- 1 | from atorch.auto.engine.sg_algo.hebo.design_space.design_space import DesignSpace 2 | 3 | 4 | def parse_space_from_bayesmark(api_config) -> DesignSpace: 5 | """ 6 | Parse design space of bayesmark (https://github.com/uber/bayesmark) 7 | """ 8 | space = DesignSpace() 9 | params = [] 10 | for param_name in api_config: 11 | param_conf = api_config[param_name] 12 | param_type = param_conf["type"] 13 | param_space = param_conf.get("space", None) 14 | param_range = param_conf.get("range", None) 15 | param_values = param_conf.get("values", None) 16 | 17 | bo_param_conf = {"name": param_name} 18 | if param_type == "int": 19 | bo_param_conf["type"] = "int" 20 | bo_param_conf["lb"] = param_range[0] 21 | bo_param_conf["ub"] = param_range[1] 22 | elif param_type == "bool": 23 | bo_param_conf["type"] = "bool" 24 | elif param_type in ("cat", "ordinal"): 25 | bo_param_conf["type"] = "cat" 26 | bo_param_conf["categories"] = list(set(param_values)) 27 | elif param_type == "real": 28 | if param_space in ("log", "logit"): 29 | bo_param_conf["type"] = "pow" 30 | bo_param_conf["base"] = 10 31 | bo_param_conf["lb"] = param_range[0] 32 | bo_param_conf["ub"] = param_range[1] 33 | else: 34 | bo_param_conf["type"] = "num" 35 | bo_param_conf["lb"] = param_range[0] 36 | bo_param_conf["ub"] = param_range[1] 37 | else: 38 | assert False, "type %s not handled in API" % param_type 39 | params.append(bo_param_conf) 40 | space.parse(params) 41 | return space 42 | -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/sg_algo_lib.py: -------------------------------------------------------------------------------- 1 | try: 2 | from atorch.auto.engine.sg_algo.bayes_opt_sg import BOAlgorithm 3 | 4 | bo_algo_is_available = True 5 | except ModuleNotFoundError: 6 | bo_algo_is_available = False 7 | from atorch.auto.engine.sg_algo.combination_sg import CombinationAlgorithm 8 | 9 | 10 | class StrategyGenerationAlgorithmLibrary(object): 11 | """ 12 | Each strategy generation (SG) algorithm is a StrategyGenerationAlgorithm 13 | instance, which can be called to generate new strategies. 14 | """ 15 | 16 | def __init__(self): 17 | self.algorithms = {} 18 | self.add_algorithms() 19 | 20 | def add_algorithms(self): 21 | algo = CombinationAlgorithm() 22 | self.algorithms[algo.name] = algo 23 | if bo_algo_is_available: 24 | bo_algo = BOAlgorithm() 25 | self.algorithms[bo_algo.name] = bo_algo 26 | 27 | def __getitem__(self, name): 28 | if name in self.algorithms: 29 | return self.algorithms[name] 30 | return None 31 | -------------------------------------------------------------------------------- /atorch/auto/engine/sg_algo/sg_algorithm.py: -------------------------------------------------------------------------------- 1 | class StrategyGenerationAlgorithm(object): 2 | """A strategy generation (SG) algorithm implementation. 3 | Call strategy_generate with executor to generate candidate strategies. 
4 | strategy_generate can be called multiple times to generate strategies in 5 | multiple stages. 6 | """ 7 | 8 | def __init__(self, name=None): 9 | self.name = name 10 | self.is_done = False 11 | 12 | def strategy_generate(self, _): 13 | """ 14 | Input: executor, which contains the optimization method and strategies. 15 | The output is a 3-tuple: 16 | is_done: bool indicating if the algorithm finishes after this call. 17 | tasks: None or list(task), new tasks to execute. 18 | new_strategy_num: int for the number of new strategies added. 19 | """ 20 | self.is_done = True 21 | return self.is_done, None, 0 22 | 23 | def __call__(self, executor): 24 | return self.strategy_generate(executor) 25 | -------------------------------------------------------------------------------- /atorch/auto/engine/task.py: -------------------------------------------------------------------------------- 1 | class TaskType: 2 | ANALYSE = "ANALYSE" 3 | TUNE = "TUNE" 4 | FINISH = "FINISH" 5 | SETUP_PARALLEL_GROUP = "SETUP_PARALLEL_GROUP" 6 | DRYRUN = "DRYRUN" 7 | WAIT = "WAIT" 8 | FAIL = "FAIL" 9 | 10 | 11 | class TaskProcessMode: 12 | """ 13 | A task may be run on one or more processes. 14 | For example, an ANALYSE task may use ONE_PROCESS, while a DRYRUN, 15 | a SETUP_PARALLEL_GROUP task or a FINISH task will use ALL_PROCESS. 16 | A TUNE task may use ONE_PROCESS, or ALL_PROCESS. 17 | TODO: support custom parallel group 18 | """ 19 | 20 | ONE_PROCESS = "ONE_PROCESS" 21 | ALL_PROCESS = "ALL_PROCESS" 22 | 23 | 24 | class TaskStatus: 25 | PENDING = 0 26 | ASSIGNING = 1 27 | RUNNING = 2 28 | CANCELLED = 3 29 | FAILED = 4 30 | SUCCEEDED = 5 31 | 32 | 33 | class Task(object): 34 | """ 35 | This Task definition is a superset of Task in atorch.auto.task. 36 | This Task also includes status, process mode, and task result.
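    For example (illustrative; `task_info` is whatever payload the task type expects):
        t = Task(TaskType.DRYRUN, task_info, process_mode=TaskProcessMode.ALL_PROCESS)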
37 | Task status change: 38 | "ONE_PROCESS": PENDING -> RUNNING -> SUCCEEDED/FAILED 39 | "ALL_PROCESS": PENDING -> ASSIGNING-> RUNNING -> SUCCEEDED/FAILED 40 | """ 41 | 42 | def __init__( 43 | self, 44 | task_type, 45 | task_info, 46 | task_id=-1, 47 | strategy_id=-1, 48 | process_mode=TaskProcessMode.ONE_PROCESS, 49 | time_limit=None, 50 | task_status=TaskStatus.PENDING, 51 | ): 52 | self.task_id = task_id 53 | self.task_type = task_type 54 | self.task_info = task_info 55 | self.time_limit = time_limit 56 | self.process_mode = process_mode 57 | self.process_assigned = [] # list of process id 58 | self.status = task_status 59 | self.task_result = None 60 | self.strategy_id = strategy_id 61 | 62 | def add_process(self, process_id): 63 | self.process_assigned.append(process_id) 64 | -------------------------------------------------------------------------------- /atorch/auto/opt_lib/__init__.py: -------------------------------------------------------------------------------- 1 | from .optimization_library import OptimizationLibrary 2 | -------------------------------------------------------------------------------- /atorch/auto/opt_lib/dynamo_backends.py: -------------------------------------------------------------------------------- 1 | # listing all backends supported by torch dynamo 2 | import enum 3 | 4 | 5 | class DynamoBackends(enum.Enum): 6 | NO = "NO" 7 | EAGER = "EAGER" 8 | AOT_EAGER = "AOT_EAGER" 9 | INDUCTOR = "INDUCTOR" 10 | NVFUSER = "NVFUSER" 11 | AOT_NVFUSER = "AOT_NVFUSER" 12 | AOT_CUDAGRAPHS = "AOT_CUDAGRAPHS" 13 | OFI = "OFI" 14 | FX2TRT = "FX2TRT" 15 | ONNXRT = "ONNXRT" 16 | IPEX = "IPEX" 17 | -------------------------------------------------------------------------------- /atorch/auto/opt_lib/half_optimization.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | 5 | from atorch.auto.opt_lib.optimization import Optimization 6 | from atorch.common.log_utils import default_logger as logger 7 | from atorch.common.util_func import data_float_to_dtype 8 | 9 | 10 | class HalfOptimization(Optimization): 11 | """HalfOptimization will convert model to half (fp16 or bf16) 12 | config is a string, "fp16" (default) or "bf16". 13 | """ 14 | 15 | checkpoint_funcs_before_overrided = None 16 | 17 | def __init__(self): 18 | super().__init__("half", "half", False) 19 | 20 | def tune(self, model_context, config=None, strategy=None, apply_transform=True, time_limit=None): 21 | if apply_transform: 22 | model_context = self.transform(model_context, config) 23 | return True, config, model_context 24 | 25 | def transform(self, model_context, config="fp16"): 26 | model_context.add_wrapper("half", HalfOptimization.apply_wrapper, wrapper_config=config, is_pre_wrapper=True) 27 | return model_context 28 | 29 | @staticmethod 30 | def apply_wrapper(model_context, wrapper_name, wrapper_config=None): 31 | # wrapper_config should be one of "fp16", "bf16". 32 | if wrapper_config not in ("fp16", "bf16"): 33 | logger.error("Invalid config for half optimization. 
Should be fp16 or bf16 but got %s", wrapper_config) 34 | dtype = torch.float16 if wrapper_config == "fp16" else torch.bfloat16 35 | from atorch.utils.meta_model_utils import custom_transform_model_keep_checkpoint_name 36 | 37 | model_context.model = custom_transform_model_keep_checkpoint_name(model_context.model, lambda m: m.to(dtype)) 38 | 39 | def inputs_to_dtype(data, device, dtype, ori_func): 40 | data = ori_func(data, device) 41 | return data_float_to_dtype(data, dtype) 42 | 43 | model_context.prepare_input = partial(inputs_to_dtype, dtype=dtype, ori_func=model_context.prepare_input) 44 | 45 | return model_context 46 | -------------------------------------------------------------------------------- /atorch/auto/opt_lib/shard_planners/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_stage_planner import BaseStagePlanner, split_into_nstages_equal_size 2 | from .base_tp_planner import BaseTensorParallelPlanner 3 | from .mip_tp_planner import MIPTensorParallelPlanner 4 | from .topology import DeviceTopology, SimpleTopology 5 | -------------------------------------------------------------------------------- /atorch/auto/task.py: -------------------------------------------------------------------------------- 1 | class Task(object): 2 | def __init__(self, id, task_type, process_mode="ONE_PROCESS", task_info=None): 3 | self.id = id 4 | self.type = task_type 5 | self.process_mode = process_mode 6 | if task_type in ["TUNE", "DRYRUN", "FINISH"]: 7 | self.strategy = task_info 8 | elif task_type == "SETUP_PARALLEL_GROUP": 9 | self.parallel_group_info = task_info 10 | elif task_type == "ANALYSE": 11 | self.analysis_method = task_info 12 | -------------------------------------------------------------------------------- /atorch/checkpoint/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/checkpoint/__init__.py -------------------------------------------------------------------------------- /atorch/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/common/__init__.py -------------------------------------------------------------------------------- /atorch/common/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from atorch.common.singleton import SingletonMeta 4 | 5 | 6 | def parse_bool_env(name, default="False"): 7 | str_env = os.getenv(name, default) 8 | if str_env in ["1", "True", "true"]: 9 | return True 10 | elif str_env in ["0", "False", "false"]: 11 | return False 12 | else: 13 | raise ValueError(f"Env {name} should be True or False") 14 | 15 | 16 | class EnvSetting(metaclass=SingletonMeta): 17 | # TODO: configure this number in the moe module instead, which would be more comprehensive 18 | MOE_FSDP_PREFETCH_NUM = int(os.getenv("MOE_FSDP_PREFETCH_NUM", 1)) 19 | MOE_NPU_DISABLE_ARGSORT_REPLACE = parse_bool_env("MOE_NPU_DISABLE_ARGSORT_REPLACE") 20 | MOE_DISABLE_SHARED_EXPERT_OVERLAP = parse_bool_env("MOE_DISABLE_SHARED_EXPERT_OVERLAP") 21 | DISABLE_CHECKPOINT_PATCH = parse_bool_env("ATORCH_DISABLE_CHECKPOINT_PATCH") 22 | MOE_NPU_DISABLE_FUSED_KERNEL = parse_bool_env("MOE_NPU_DISABLE_FUSED_KERNEL", "True") 23 | MOE_NV_DISABLE_FUSED_KERNEL =
parse_bool_env("MOE_NV_DISABLE_FUSED_KERNEL", "False") 24 | MOE_REPLACE_MINDSPEED_ALLGATHER_TOKENDISPATCHER_INDEX = parse_bool_env( 25 | "MOE_REPLACE_MINDSPEED_ALLGATHER_TOKENDISPATCHER_INDEX", default="True" 26 | ) 27 | MOE_MLP_PREFIX = parse_bool_env("MOE_MLP_PREFIX", "True") 28 | DISABLE_TE_LINEAR_PATCHING = parse_bool_env("DISABLE_TE_LINEAR_PATCHING") 29 | DEBUG = parse_bool_env("ATORCH_DEBUG", "False") 30 | FORCE_FSDP2_RESHARD_AFTER_FORWARD = parse_bool_env("FORCE_FSDP2_RESHARD_AFTER_FORWARD", "False") 31 | CLOSE_FSDP2_BACKWARD_PREFETCH = parse_bool_env("CLOSE_FSDP2_BACKWARD_PREFETCH", "False") 32 | 33 | # FP8 34 | FORCE_QUANTIZE_PER_MICROBATCH = parse_bool_env("FORCE_QUANTIZE_PER_MICROBATCH", "False") 35 | -------------------------------------------------------------------------------- /atorch/common/singleton.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | 4 | class SingletonMeta(type): 5 | """ 6 | The Singleton class can be implemented in different ways in Python. Some 7 | possible methods include: base class, decorator, metaclass. We will use the 8 | metaclass because it is best suited for this purpose. 9 | Adapted from ColossalAI https://github.com/hpcaitech/ColossalAI 10 | """ 11 | 12 | _instances: Dict = {} 13 | 14 | def __call__(cls, *args, **kwargs): 15 | """ 16 | Possible changes to the value of the `__init__` argument do not affect 17 | the returned instance. 18 | """ 19 | if cls not in cls._instances: 20 | instance = super().__call__(*args, **kwargs) 21 | cls._instances[cls] = instance 22 | else: 23 | assert ( 24 | len(args) == 0 and len(kwargs) == 0 25 | ), f"{cls.__name__} is a singleton class and a instance has been created." 26 | return cls._instances[cls] 27 | -------------------------------------------------------------------------------- /atorch/communication/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/communication/__init__.py -------------------------------------------------------------------------------- /atorch/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .coworker_dataset import build_coworker_dataloader 2 | from .data_utils import expand_batch_dim, get_sample_batch 3 | 4 | try: 5 | from .elastic_dataloader import build_coworker_dataloader_with_elasticdl, get_elastic_dataloader 6 | except TypeError: 7 | print("protobuf version mismatch. 
elastic_dataloader cannot be used") 8 | build_coworker_dataloader_with_elasticdl = None 9 | get_elastic_dataloader = None 10 | from .preloader import GpuPreLoader, data_to_device 11 | from .shm_context import ShmData, create_coworker_shm_context 12 | from .shm_dataloader import ShmDataloader, create_shm_dataloader 13 | from .unordered_dataloader import UnorderedDataLoader 14 | -------------------------------------------------------------------------------- /atorch/data/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.distributed import DistributedSampler 4 | 5 | from atorch.common.util_func import recursively_apply 6 | from atorch.distributed.distributed import is_pipe_first_stage, is_pipe_last_stage 7 | 8 | 9 | def fast_batch_copy(data): 10 | def do_copy(data): 11 | return torch.from_numpy(data.numpy().copy()) 12 | 13 | return recursively_apply(do_copy, data, error_on_other_type=True) 14 | 15 | 16 | def get_sample_batch(dataset, dataloader_args, num=1): 17 | new_args = {"num_workers": 0} 18 | for key in dataloader_args: 19 | if key in ["num_workers", "prefetch_factor"]: 20 | # multi-process is not needed. 21 | continue 22 | if key == "sampler" and isinstance(dataloader_args["sampler"], DistributedSampler): 23 | # no need for sampler if it is default 24 | continue 25 | new_args[key] = dataloader_args[key] 26 | dataloader = DataLoader(dataset, **new_args) 27 | batches = [] 28 | for idx, data in enumerate(dataloader): 29 | batches.append(data) 30 | if idx == num - 1: 31 | break 32 | if num == 1: 33 | return batches[0] 34 | else: 35 | return batches 36 | 37 | 38 | def expand_batch_dim(data, batch_size=1): 39 | def expand(data, batch_size=1): 40 | shape_list = list(data.shape) 41 | shape_list.insert(0, batch_size) 42 | return data.expand(*shape_list) 43 | 44 | return recursively_apply(expand, data, batch_size=batch_size, error_on_other_type=False) 45 | 46 | 47 | def get_batch(data_iterator): 48 | if (not is_pipe_first_stage(ignore_virtual=True)) and (not is_pipe_last_stage(ignore_virtual=True)): 49 | return None 50 | 51 | # TODO: 52 | # consider tp and cp 53 | 54 | if data_iterator is not None: 55 | batch = next(data_iterator) 56 | else: 57 | batch = None 58 | 59 | return batch 60 | -------------------------------------------------------------------------------- /atorch/data/unshuffled_batch_dataloader.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.distributed as dist 4 | from torch.utils.data import Sampler 5 | 6 | 7 | class DistributedUnshuffledBatchSampler(Sampler): 8 | def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None): 9 | if num_replicas is None: 10 | if not dist.is_available(): 11 | raise RuntimeError("Requires distributed package to be available") 12 | num_replicas = dist.get_world_size() 13 | if rank is None: 14 | if not dist.is_available(): 15 | raise RuntimeError("Requires distributed package to be available") 16 | rank = dist.get_rank() 17 | if batch_size is None: 18 | raise RuntimeError("Requires batch_size to be available") 19 | 20 | self.dataset = dataset 21 | self.num_replicas = num_replicas 22 | self.rank = rank 23 | self.batch_size = batch_size 24 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 25 | self.total_size = self.num_samples * self.num_replicas 26 | 27 | def __iter__(self): 28 | indices = [] 
29 | batch_num = self.num_samples // self.batch_size 30 | for i in range(batch_num): 31 | start_pos = self.rank * self.batch_size + self.num_replicas * self.batch_size * i 32 | end_pos = start_pos + self.batch_size 33 | indices.extend(range(start_pos, end_pos)) 34 | return iter(indices) 35 | 36 | def __len__(self): 37 | return self.num_samples 38 | -------------------------------------------------------------------------------- /atorch/data_parallel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/data_parallel/__init__.py -------------------------------------------------------------------------------- /atorch/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed import ( 2 | coworker_local_rank, 3 | coworker_num_per_node, 4 | create_sequence_parallel_group, 5 | destroy_sequence_parallel_group, 6 | get_sequence_parallel_group, 7 | get_sequence_parallel_rank, 8 | get_sequence_parallel_size, 9 | init_distributed, 10 | is_coworker, 11 | is_distributed, 12 | local_rank, 13 | node_size, 14 | nproc_per_node, 15 | rank, 16 | seq_all_to_all, 17 | use_coworker, 18 | worker_local_rank, 19 | worker_num_per_node, 20 | world_size, 21 | ) 22 | -------------------------------------------------------------------------------- /atorch/distributed/elastic_controller.py: -------------------------------------------------------------------------------- 1 | from atorch.common.log_utils import default_logger as logger 2 | 3 | try: 4 | from elasticai_api.pytorch.DDP_controller import DDPController 5 | except ImportError: 6 | logger.warning("Please install elasticai_api >= 1.4.2.") 7 | 8 | 9 | class ElasticController(DDPController): 10 | def __init__(self, data_shard_service): 11 | super(ElasticController, self).__init__(data_shard_service) 12 | 13 | 14 | def elastic_controller(data_shard_service): 15 | return ElasticController(data_shard_service) 16 | -------------------------------------------------------------------------------- /atorch/distributed/hooks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | from contextlib import closing 4 | 5 | from torch.distributed.elastic.agent.server.api import SimpleElasticAgent, _get_fq_hostname 6 | 7 | try: 8 | from torch.distributed.elastic.agent.server.api import _get_socket_with_port 9 | except (ImportError, ModuleNotFoundError): 10 | from torch.distributed.elastic.utils.distributed import get_socket_with_port as _get_socket_with_port 11 | 12 | 13 | def hook_set_master_addr_port(args=None): 14 | def _hook(store, master_addr, master_port, local_dir=None): 15 | """ 16 | PyTorch uses the master node's hostname as the MASTER_ADDR of the process group. However, the hostname may not be resolvable 17 | in some Kubernetes environments. This function gets the master's IP address from the POD_IP environment variable and 18 | sets it as MASTER_ADDR. 
19 | """ 20 | if master_port is None: 21 | sock = _get_socket_with_port() 22 | with closing(sock): 23 | master_port = sock.getsockname()[1] 24 | 25 | if master_addr is None: 26 | if local_dir is not None: 27 | master_addr = local_dir 28 | else: 29 | master_addr = os.environ.get("POD_IP") or socket.gethostbyname(_get_fq_hostname()) 30 | 31 | store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8")) 32 | store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8")) 33 | 34 | # hook SimpleElasticAgent._set_master_addr_port 35 | if hasattr(SimpleElasticAgent, "_set_master_addr_port"): 36 | setattr(SimpleElasticAgent, "_set_master_addr_port", staticmethod(_hook)) 37 | elif args and args.local_addr is None: 38 | args.local_addr = os.environ.get("POD_IP") or socket.getfqdn() 39 | -------------------------------------------------------------------------------- /atorch/distributed/mesh.py: -------------------------------------------------------------------------------- 1 | from atorch.common.log_utils import default_logger as logger 2 | from atorch.utils.import_util import is_torch_npu_available 3 | 4 | try: 5 | from torch.distributed.device_mesh import init_device_mesh 6 | except ImportError: 7 | init_device_mesh = None 8 | 9 | 10 | def build_mesh(slicing_dim, pg_name_prefix="", reverse_mesh_pg_order=True, device_type="cuda"): 11 | device_type = "npu" if is_torch_npu_available() else device_type 12 | 13 | dims = [] 14 | names = [] 15 | for item in slicing_dim: 16 | name = pg_name_prefix + item[0] 17 | d = item[1] 18 | if d > 1: 19 | dims.append(d) 20 | names.append(name) 21 | if reverse_mesh_pg_order: 22 | dims.reverse() 23 | names.reverse() 24 | logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") 25 | names = tuple(names) 26 | return init_device_mesh(device_type, dims, mesh_dim_names=names) 27 | -------------------------------------------------------------------------------- /atorch/fault_tolerance/__init__.py: -------------------------------------------------------------------------------- 1 | from atorch.fault_tolerance.hanging_detector import HangingDetector 2 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/kernels/extensions/__init__.py -------------------------------------------------------------------------------- /atorch/kernels/extensions/abstract_extension.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Callable, Optional 3 | 4 | 5 | class AbstractExtension(ABC): 6 | def __init__(self): 7 | self._name = str(self.__class__.__name__).replace("Extension", "") 8 | 9 | @property 10 | def name(self): 11 | return self._name 12 | 13 | @abstractmethod 14 | def is_available(self) -> bool: 15 | pass 16 | 17 | @abstractmethod 18 | def load(self) -> Optional[Callable[..., Any]]: 19 | pass 20 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/flash_atten_1_extension.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .abstract_extension import AbstractExtension 4 | 5 | 6 | class _FlashAttn1Extension(AbstractExtension): 7 | def is_available(self) -> bool: 8 | available = False 9 | 
try: 10 | import flash_attn_1 # noqa F401 11 | 12 | available = torch.cuda.is_available() 13 | except (ImportError, ModuleNotFoundError): 14 | available = False 15 | return available 16 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/flash_atten_extension.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .abstract_extension import AbstractExtension 4 | 5 | 6 | class _FlashAttnExtension(AbstractExtension): 7 | def is_available(self) -> bool: 8 | available = False 9 | try: 10 | import flash_attn # noqa F401 11 | 12 | available = torch.cuda.is_available() 13 | except (ImportError, ModuleNotFoundError): 14 | available = False 15 | return available 16 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/flash_attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/kernels/extensions/flash_attention/__init__.py -------------------------------------------------------------------------------- /atorch/kernels/extensions/flash_attention/dropout_add_layer_norm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | from ..flash_atten_extension import _FlashAttnExtension 4 | 5 | 6 | class DropoutAddLayernormExtension(_FlashAttnExtension): 7 | def is_available(self) -> bool: 8 | available = super().is_available() 9 | 10 | if not available: 11 | return False 12 | 13 | try: 14 | from flash_attn.ops.layer_norm import dropout_add_layer_norm # noqa 15 | 16 | available = True 17 | except (ImportError, ModuleNotFoundError): 18 | available = False 19 | 20 | return available 21 | 22 | def load(self) -> Optional[Callable[..., Any]]: 23 | if not self.is_available(): 24 | return None 25 | 26 | from flash_attn.ops.layer_norm import dropout_add_layer_norm # noqa 27 | 28 | return dropout_add_layer_norm 29 | 30 | 31 | dropout_add_layer_norm = DropoutAddLayernormExtension().load() 32 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/flash_attention/flash_attn_3_func_ext.py: -------------------------------------------------------------------------------- 1 | from atorch.utils.fa_util import patch_fa_interface_to_autocast 2 | from atorch.utils.import_util import is_flash_attn_3_avaliable 3 | 4 | from ..abstract_extension import AbstractExtension 5 | 6 | 7 | class FlashAttnFunc3Extension(AbstractExtension): 8 | def is_available(self) -> bool: 9 | return is_flash_attn_3_avaliable() 10 | 11 | def load(self): 12 | if not self.is_available(): 13 | return None, None 14 | 15 | import flash_attn_interface 16 | 17 | patch_fa_interface_to_autocast(flash_attn_interface) 18 | 19 | from flash_attn_interface import flash_attn_func, flash_attn_varlen_func # noqa 20 | 21 | return flash_attn_func, flash_attn_varlen_func 22 | 23 | 24 | flash_attn_func_3, flash_attn_varlen_func_3 = FlashAttnFunc3Extension().load() 25 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/flash_attention/flash_attn_cross_entropy.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | from ..flash_atten_extension import 
_FlashAttnExtension 4 | 5 | 6 | class FlashAttnCrossEntropyExtension(_FlashAttnExtension): 7 | def is_available(self) -> bool: 8 | available = super().is_available() 9 | 10 | if not available: 11 | return False 12 | 13 | try: 14 | from flash_attn.losses.cross_entropy import CrossEntropyLoss # noqa 15 | 16 | available = True 17 | except (ImportError, ModuleNotFoundError): 18 | available = False 19 | 20 | return available 21 | 22 | def load(self) -> Optional[Callable[..., Any]]: 23 | if not self.is_available(): 24 | return None 25 | 26 | from flash_attn.losses.cross_entropy import CrossEntropyLoss # noqa 27 | 28 | return CrossEntropyLoss 29 | 30 | 31 | FlashAttnCrossEntropyLoss = FlashAttnCrossEntropyExtension().load() 32 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/flash_attention/flash_attn_func_ext.py: -------------------------------------------------------------------------------- 1 | from atorch.utils.fa_util import patch_fa_interface_to_autocast 2 | 3 | from ..flash_atten_extension import _FlashAttnExtension 4 | 5 | 6 | class FlashAttnFuncExtension(_FlashAttnExtension): 7 | def is_available(self) -> bool: 8 | available = super().is_available() 9 | 10 | if not available: 11 | return False 12 | 13 | try: 14 | from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func # noqa 15 | 16 | available = True 17 | except (ImportError, ModuleNotFoundError): 18 | available = False 19 | 20 | return available 21 | 22 | def load(self): 23 | if not self.is_available(): 24 | return None, None 25 | 26 | import flash_attn.flash_attn_interface 27 | 28 | patch_fa_interface_to_autocast(flash_attn.flash_attn_interface) 29 | 30 | from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func # noqa 31 | 32 | return flash_attn_func, flash_attn_varlen_func 33 | 34 | 35 | flash_attn_func, flash_attn_varlen_func = FlashAttnFuncExtension().load() 36 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/flash_attention_1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/kernels/extensions/flash_attention_1/__init__.py -------------------------------------------------------------------------------- /atorch/kernels/extensions/flash_attention_1/dropout_add_layer_norm_1.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | from ..flash_atten_1_extension import _FlashAttn1Extension 4 | 5 | 6 | class DropoutAddLayernorm1Extension(_FlashAttn1Extension): 7 | def is_available(self) -> bool: 8 | available = super().is_available() 9 | 10 | if not available: 11 | return False 12 | 13 | try: 14 | from flash_attn_1.ops.layer_norm import dropout_add_layer_norm # noqa 15 | 16 | available = True 17 | except (ImportError, ModuleNotFoundError): 18 | available = False 19 | 20 | return available 21 | 22 | def load(self) -> Optional[Callable[..., Any]]: 23 | if not self.is_available(): 24 | return None 25 | 26 | from flash_attn_1.ops.layer_norm import dropout_add_layer_norm # noqa 27 | 28 | return dropout_add_layer_norm 29 | 30 | 31 | dropout_add_layer_norm_1 = DropoutAddLayernorm1Extension().load() 32 | -------------------------------------------------------------------------------- 
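The flash-attention extension files above all follow one contract: the module-level symbol (flash_attn_func, FlashAttnCrossEntropyLoss, dropout_add_layer_norm_1, and so on) holds the loaded kernel when the backing package imports cleanly on a CUDA machine, and None otherwise. A minimal consumption sketch of that contract follows; the tensor shapes and the SDPA fallback are illustrative, not part of ATorch:

import torch
import torch.nn.functional as F

from atorch.kernels.extensions.flash_attention.flash_attn_func_ext import flash_attn_func

# (batch, seqlen, heads, head_dim) is the layout flash-attn expects; sizes here are made up
q = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

if flash_attn_func is not None:
    out = flash_attn_func(q, k, v, causal=True)  # fused flash-attention path
else:
    # extension unavailable: fall back to PyTorch SDPA, which wants (batch, heads, seq, dim)
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    ).transpose(1, 2)

Guarding on the None sentinel, rather than wrapping every call site in try/except ImportError, is what the extension classes buy: availability is decided once, at import time.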
/atorch/kernels/extensions/flash_attention_1/flash_attn_func_ext_1.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | from atorch.utils.fa_util import patch_fa_interface_to_autocast 4 | 5 | from ..flash_atten_1_extension import _FlashAttn1Extension 6 | 7 | 8 | class FlashAttnFunc1Extension(_FlashAttn1Extension): 9 | def is_available(self) -> bool: 10 | available = super().is_available() 11 | 12 | if not available: 13 | return False 14 | 15 | try: 16 | from flash_attn_1.flash_attn_interface import flash_attn_unpadded_func # noqa 17 | 18 | available = True 19 | except (ImportError, ModuleNotFoundError): 20 | available = False 21 | 22 | return available 23 | 24 | def load(self) -> Optional[Callable[..., Any]]: 25 | if not self.is_available(): 26 | return None 27 | 28 | import flash_attn_1.flash_attn_interface # noqa 29 | 30 | patch_fa_interface_to_autocast(flash_attn_1.flash_attn_interface) 31 | 32 | from flash_attn_1.flash_attn_interface import flash_attn_unpadded_func # noqa 33 | 34 | return flash_attn_unpadded_func 35 | 36 | 37 | flash_attn_unpadded_func_1 = FlashAttnFunc1Extension().load() 38 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/grouped_gemm_exts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/kernels/extensions/grouped_gemm_exts/__init__.py -------------------------------------------------------------------------------- /atorch/kernels/extensions/grouped_gemm_exts/grouped_gemm_gmm.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Any, Callable, Optional, Tuple 3 | 4 | import torch 5 | from torch.cuda.amp.autocast_mode import _cast, autocast 6 | 7 | from ..abstract_extension import AbstractExtension 8 | 9 | 10 | # patch fn to handle autocast 11 | @functools.lru_cache 12 | def _cast_fn(fn): 13 | @functools.wraps(fn) 14 | def new_fn(*args, **kwargs): 15 | if torch.is_autocast_enabled(): 16 | cur_dtype = torch.get_autocast_gpu_dtype() 17 | with autocast(enabled=False): 18 | return fn(*_cast(args, cur_dtype), **_cast(kwargs, cur_dtype)) 19 | else: 20 | return fn(*args, **kwargs) 21 | 22 | return new_fn 23 | 24 | 25 | class GroupedGEMMExtension(AbstractExtension): 26 | def is_available(self) -> bool: 27 | 28 | available = False 29 | try: 30 | import grouped_gemm # noqa 31 | import torch 32 | 33 | available = torch.cuda.is_available() 34 | except (ImportError, ModuleNotFoundError): 35 | available = False 36 | return available 37 | 38 | def load(self) -> Optional[Callable[..., Any]]: 39 | if not self.is_available(): 40 | return None 41 | 42 | gmm, _ = self._load_with_ext_package() 43 | return gmm 44 | 45 | def _load_with_ext_package(self) -> Optional[Tuple[Callable[..., Any], Any]]: 46 | import grouped_gemm as gg 47 | 48 | gmm = _cast_fn(gg.ops.gmm) 49 | return gmm, gg 50 | 51 | 52 | gmm = GroupedGEMMExtension().load() 53 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/npu/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/kernels/extensions/npu/__init__.py 
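The _cast_fn wrapper in grouped_gemm_gmm.py above exists because out-of-tree ops such as gg.ops.gmm are invisible to PyTorch's autocast dispatcher, so the wrapper recreates autocast by hand: read the active autocast dtype, disable autocast, and cast the arguments itself. A standalone sketch of the same pattern using only public torch APIs (the wrapped op is a stand-in, not an ATorch function):

import functools

import torch


def autocast_compatible(fn):
    """Manually apply autocast semantics to an op that bypasses autocast dispatch."""

    @functools.wraps(fn)
    def wrapped(*args):
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            # run the raw op outside autocast, with floating-point inputs cast by hand
            with torch.autocast("cuda", enabled=False):
                cast = tuple(
                    a.to(dtype) if torch.is_tensor(a) and a.is_floating_point() else a
                    for a in args
                )
                return fn(*cast)
        return fn(*args)

    return wrapped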
-------------------------------------------------------------------------------- /atorch/kernels/extensions/npu/adamw_npu.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | from ..npu_extension import _NpuExtension 4 | 5 | 6 | class FusedAdamwNpuExtension(_NpuExtension): 7 | def is_available(self) -> bool: 8 | 9 | available = super().is_available() 10 | if not available: 11 | return False 12 | 13 | try: 14 | import torch_npu # noqa 15 | 16 | available = hasattr(torch_npu, "npu_apply_adam_w") 17 | except (ImportError, ModuleNotFoundError): 18 | available = False 19 | return available 20 | 21 | def load(self) -> Optional[Callable[..., Any]]: 22 | if not self.is_available(): 23 | return None 24 | 25 | import torch_npu # noqa 26 | 27 | return torch_npu.npu_apply_adam_w 28 | 29 | 30 | npu_apply_adam_w = FusedAdamwNpuExtension().load() 31 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/npu/flash_attention_npu.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | from ..npu_extension import _NpuExtension 4 | 5 | 6 | class FlashAttentionNpuExtension(_NpuExtension): 7 | def is_available(self) -> bool: 8 | available = super().is_available() 9 | 10 | if not available: 11 | return False 12 | 13 | try: 14 | import torch_npu # noqa 15 | 16 | available = hasattr(torch_npu, "npu_fusion_attention") 17 | except (ImportError, ModuleNotFoundError): 18 | available = False 19 | return available 20 | 21 | def load(self) -> Optional[Callable[..., Any]]: 22 | if not self.is_available(): 23 | return None 24 | 25 | import torch_npu # noqa 26 | 27 | return torch_npu.npu_fusion_attention 28 | 29 | 30 | npu_fusion_attention = FlashAttentionNpuExtension().load() 31 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/npu/fused_cross_entropy_npu.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | from ..npu_extension import _NpuExtension 4 | 5 | 6 | class FusedCrossEntropyNpuExtension(_NpuExtension): 7 | def is_available(self) -> bool: 8 | available = super().is_available() 9 | 10 | if not available: 11 | return False 12 | 13 | try: 14 | import mindspeed # noqa 15 | 16 | # Note: compile ops and npu_fuse_cross_entropy_loss in ant mindspeed 17 | from mindspeed import ops # noqa 18 | from mindspeed.ops import npu_fuse_cross_entropy_loss # noqa 19 | 20 | available = hasattr(mindspeed.ops.npu_fuse_cross_entropy_loss, "npu_fuse_cross_entropy_loss") 21 | except (ImportError, ModuleNotFoundError): 22 | available = False 23 | return available 24 | 25 | def load(self) -> Optional[Callable[..., Any]]: 26 | if not self.is_available(): 27 | return None 28 | 29 | import mindspeed # noqa 30 | 31 | # Note: compile ops and npu_fuse_cross_entropy_loss in ant mindspeed 32 | from mindspeed import ops # noqa 33 | from mindspeed.ops import npu_fuse_cross_entropy_loss # noqa 34 | 35 | return mindspeed.ops.npu_fuse_cross_entropy_loss.npu_fuse_cross_entropy_loss 36 | 37 | 38 | npu_fuse_cross_entropy_loss = FusedCrossEntropyNpuExtension().load() 39 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/npu/fused_permute_npu.py: -------------------------------------------------------------------------------- 1 | from 
typing import Any, Callable, Optional 2 | 3 | from ..npu_extension import _NpuExtension 4 | 5 | 6 | class FusedPermuteNpuExtension(_NpuExtension): 7 | def is_available(self) -> bool: 8 | available = super().is_available() 9 | 10 | if not available: 11 | return False 12 | 13 | try: 14 | import mindspeed # noqa 15 | 16 | # Note: compile ops and npu_moe_token_permute in ant mindspeed 17 | from mindspeed import ops # noqa 18 | from mindspeed.ops import npu_moe_token_permute # noqa 19 | 20 | available = hasattr(mindspeed.ops.npu_moe_token_permute, "npu_moe_token_permute") 21 | except (ImportError, ModuleNotFoundError): 22 | available = False 23 | return available 24 | 25 | def load(self) -> Optional[Callable[..., Any]]: 26 | if not self.is_available(): 27 | return None 28 | 29 | import mindspeed # noqa 30 | 31 | # Note: compile ops and npu_moe_token_permute in ant mindspeed 32 | from mindspeed import ops # noqa 33 | from mindspeed.ops import npu_moe_token_permute # noqa 34 | 35 | return mindspeed.ops.npu_moe_token_permute.npu_moe_token_permute 36 | 37 | 38 | npu_fused_permute = FusedPermuteNpuExtension().load() 39 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/npu/fused_unpermute_npu.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | from ..npu_extension import _NpuExtension 4 | 5 | 6 | class FusedUnpermuteNpuExtension(_NpuExtension): 7 | def is_available(self) -> bool: 8 | available = super().is_available() 9 | 10 | if not available: 11 | return False 12 | 13 | try: 14 | import mindspeed # noqa 15 | 16 | # Note: compile ops and npu_moe_token_unpermute in ant mindspeed 17 | from mindspeed import ops # noqa 18 | from mindspeed.ops import npu_moe_token_unpermute # noqa 19 | 20 | available = hasattr(mindspeed.ops.npu_moe_token_unpermute, "npu_moe_token_unpermute") 21 | except (ImportError, ModuleNotFoundError): 22 | available = False 23 | return available 24 | 25 | def load(self) -> Optional[Callable[..., Any]]: 26 | if not self.is_available(): 27 | return None 28 | 29 | import mindspeed # noqa 30 | 31 | # Note: compile ops and npu_moe_token_unpermute in ant mindspeed 32 | from mindspeed import ops # noqa 33 | from mindspeed.ops import npu_moe_token_unpermute # noqa 34 | 35 | return mindspeed.ops.npu_moe_token_unpermute.npu_moe_token_unpermute 36 | 37 | 38 | npu_fused_unpermute = FusedUnpermuteNpuExtension().load() 39 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/npu/rms_norm_npu.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | from ..npu_extension import _NpuExtension 4 | 5 | 6 | class RMSNormNpuExtension(_NpuExtension): 7 | def is_available(self) -> bool: 8 | available = super().is_available() 9 | 10 | if not available: 11 | return False 12 | 13 | try: 14 | import torch_npu # noqa 15 | 16 | available = hasattr(torch_npu, "npu_rms_norm") 17 | except (ImportError, ModuleNotFoundError): 18 | available = False 19 | return available 20 | 21 | def load(self) -> Optional[Callable[..., Any]]: 22 | if not self.is_available(): 23 | return None 24 | 25 | import torch_npu # noqa 26 | 27 | return torch_npu.npu_rms_norm 28 | 29 | 30 | npu_rms_norm = RMSNormNpuExtension().load() 31 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/npu_extension.py: 
-------------------------------------------------------------------------------- 1 | from atorch.utils.import_util import is_torch_npu_available 2 | 3 | from .abstract_extension import AbstractExtension 4 | 5 | 6 | class _NpuExtension(AbstractExtension): 7 | def is_available(self) -> bool: 8 | available = False 9 | try: 10 | available = is_torch_npu_available() 11 | except (ImportError, ModuleNotFoundError): 12 | available = False 13 | return available 14 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/te/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/kernels/extensions/te/__init__.py -------------------------------------------------------------------------------- /atorch/kernels/extensions/te/moe_func.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from ..abstract_extension import AbstractExtension 4 | 5 | 6 | def _permute_helper(tokens, indices, num_out_tokens: int = None, fn=None): 7 | if num_out_tokens is None: 8 | # num_out_tokens required by te 9 | num_out_tokens = tokens.shape[0] 10 | if indices.dim() == 1: 11 | # 2D indices required by te 12 | indices = indices.view(-1, 1) 13 | return fn(tokens, indices, num_out_tokens) 14 | 15 | 16 | class TeMoeExtension(AbstractExtension): 17 | def is_available(self) -> bool: 18 | try: 19 | from transformer_engine.pytorch import moe_permute, moe_unpermute # noqa: F401 20 | 21 | return True 22 | except (ImportError, ModuleNotFoundError): 23 | return False 24 | 25 | def load(self): 26 | if not self.is_available(): 27 | return None, None 28 | 29 | from transformer_engine.pytorch import moe_permute, moe_unpermute 30 | 31 | permute_func = partial(_permute_helper, fn=moe_permute) 32 | return permute_func, moe_unpermute 33 | 34 | 35 | te_permute, te_unpermute = TeMoeExtension().load() 36 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/torch_xla_extension.py: -------------------------------------------------------------------------------- 1 | from .abstract_extension import AbstractExtension 2 | 3 | 4 | class _TorchxlaExtension(AbstractExtension): 5 | def is_available(self) -> bool: 6 | available = False 7 | try: 8 | import torch_xla # noqa 9 | 10 | available = True 11 | except (ImportError, ModuleNotFoundError): 12 | available = False 13 | return available 14 | -------------------------------------------------------------------------------- /atorch/kernels/extensions/xla/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/kernels/extensions/xla/__init__.py -------------------------------------------------------------------------------- /atorch/kernels/patches/__init__.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | 3 | from atorch.kernels.patches.patch_llama3_fa3 import llama_fa3_attention_forward 4 | from atorch.utils.import_util import is_flash_attn_3_avaliable 5 | from atorch.utils.version import package_version_smaller_than 6 | 7 | 8 | def apply_fa3_to_llama3(): 9 | if not is_flash_attn_3_avaliable(): 10 | raise ModuleNotFoundError(f"Please install flash-attention-3") 
11 | 12 | if package_version_smaller_than("transformers", "4.43.0"): 13 | raise NotImplementedError( 14 | f"transformers version should be at least 4.43.0, but current version is {transformers.__version__}" 15 | ) 16 | from transformers.utils import is_flash_attn_2_available 17 | 18 | if not is_flash_attn_2_available(): 19 | raise ModuleNotFoundError("flash-attention-2 is needed when using flash-attention-3") 20 | 21 | transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = llama_fa3_attention_forward 22 | -------------------------------------------------------------------------------- /atorch/kernels/triton_jit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/kernels/triton_jit/__init__.py -------------------------------------------------------------------------------- /atorch/kernels/triton_jit/rope.py: -------------------------------------------------------------------------------- 1 | from atorch.utils.import_util import is_liger_kernel_available 2 | 3 | 4 | def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): 5 | if is_liger_kernel_available(): 6 | from liger_kernel.ops.rope import LigerRopeFunction 7 | 8 | return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim) 9 | else: 10 | raise RuntimeError("liger_rotary_pos_emb is not available") 11 | -------------------------------------------------------------------------------- /atorch/kernels/triton_jit/swiglu.py: -------------------------------------------------------------------------------- 1 | from atorch.utils.import_util import is_liger_kernel_available 2 | 3 | 4 | def liger_fused_silu(a, b): 5 | """Compute silu(a) * b""" 6 | if is_liger_kernel_available(): 7 | from liger_kernel.ops.swiglu import LigerSiLUMulFunction 8 | 9 | return LigerSiLUMulFunction.apply(a, b) 10 | else: 11 | raise RuntimeError("atorch fused silu is not available") 12 | -------------------------------------------------------------------------------- /atorch/kernels/triton_jit/triton_import_lib.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | 4 | class Library(object): 5 | constexpr = int 6 | autotune = lambda *args, **kwargs: wraps  # noqa: E731 7 | Config = lambda *args, **kwargs: None  # noqa: E731 8 | jit = wraps 9 | heuristics = lambda *args, **kwargs: wraps  # noqa: E731 10 | -------------------------------------------------------------------------------- /atorch/local_sgd/DDP/__init__.py: -------------------------------------------------------------------------------- 1 | from .outer_optim_model_averager import OuterOptimPeriodicModelAverager 2 | from .stateful_post_localSGD_optimizer import StatefulPostLocalSGDOptimizer 3 | -------------------------------------------------------------------------------- /atorch/local_sgd/DDP/stateful_post_localSGD_optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.distributed.optim.post_localSGD_optimizer import PostLocalSGDOptimizer 2 | 3 | from .outer_optim_model_averager import OuterOptimPeriodicModelAverager 4 | 5 | 6 | class StatefulPostLocalSGDOptimizer(PostLocalSGDOptimizer): 7 | def state_dict(self): 8 | post_local_sgd_sd = super().state_dict() 9 | if isinstance(self.averager, OuterOptimPeriodicModelAverager): 10 | averager_sd = self.averager.state_dict() 11 | 
post_local_sgd_sd["averager"] = averager_sd 12 | 13 | return post_local_sgd_sd 14 | 15 | def load_state_dict(self, state_dict): 16 | averager_sd = state_dict.pop("averager", None) 17 | if averager_sd is not None and isinstance(self.averager, OuterOptimPeriodicModelAverager): 18 | self.averager.load_state_dict(averager_sd) 19 | 20 | super().load_state_dict(state_dict) 21 | -------------------------------------------------------------------------------- /atorch/local_sgd/FSDP/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from atorch.common.log_utils import default_logger as logger 4 | from atorch.utils.version import torch_version # noqa: E402 5 | 6 | 7 | def patch_local_sgd_to_fsdp(): 8 | if torch_version()[:2] == (2, 1): # type: ignore 9 | from .torch_2_1_0 import patch_local_sgd_to_fsdp 10 | 11 | patch_local_sgd_to_fsdp() 12 | elif torch_version()[0] == 2 and torch_version()[1] >= 4: # type: ignore 13 | from .torch_2_4_0 import patch_local_sgd_to_fsdp 14 | 15 | patch_local_sgd_to_fsdp() 16 | else: 17 | raise ValueError("Only pytorch 2.1.x and >=2.4.x supports local sgd!") 18 | logger.info(f"Local SGD hacked on Pytorch {torch.__version__}!") 19 | -------------------------------------------------------------------------------- /atorch/local_sgd/FSDP/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/local_sgd/FSDP/core/__init__.py -------------------------------------------------------------------------------- /atorch/local_sgd/FSDP/torch_2_1_0/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..core._state_dict_utils import ( 4 | _load_local_sgd_state_dict, 5 | _local_sgd_state_dict, 6 | _pre_state_dict_hook, 7 | _save_local_sgd_state_dict, 8 | ) 9 | from ._init_utils import fsdp_inits 10 | from ._runtime_utils import _init_streams, _reduce_grad, _share_state_and_init_handle_attrs, _unshard, forward 11 | 12 | 13 | def patch_local_sgd_to_fsdp(): 14 | torch.distributed.fsdp._runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs 15 | torch.distributed.fsdp._runtime_utils._init_streams = _init_streams 16 | torch.distributed.fsdp._runtime_utils._reduce_grad = _reduce_grad 17 | torch.distributed.fsdp._runtime_utils._unshard = _unshard 18 | 19 | torch.distributed.fsdp._state_dict_utils._pre_state_dict_hook = _pre_state_dict_hook 20 | 21 | torch.distributed.fsdp.FullyShardedDataParallel.__init__ = fsdp_inits 22 | torch.distributed.fsdp.FullyShardedDataParallel.forward = forward 23 | torch.distributed.fsdp.FullyShardedDataParallel.local_sgd_state_dict = staticmethod(_local_sgd_state_dict) 24 | torch.distributed.fsdp.FullyShardedDataParallel.save_local_sgd_state_dict = staticmethod(_save_local_sgd_state_dict) 25 | torch.distributed.fsdp.FullyShardedDataParallel.load_local_sgd_state_dict = staticmethod(_load_local_sgd_state_dict) 26 | -------------------------------------------------------------------------------- /atorch/local_sgd/FSDP/torch_2_4_0/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..core._state_dict_utils import ( 4 | _load_local_sgd_state_dict, 5 | _local_sgd_state_dict, 6 | _pre_state_dict_hook, 7 | _save_local_sgd_state_dict, 8 | ) 9 | from ._init_utils import 
fsdp_inits 10 | from ._runtime_utils import _init_streams, _reduce_grad, _share_state_and_init_handle_attrs, _unshard, forward 11 | 12 | 13 | def patch_local_sgd_to_fsdp(): 14 | torch.distributed.fsdp._runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs 15 | torch.distributed.fsdp._runtime_utils._init_streams = _init_streams 16 | torch.distributed.fsdp._runtime_utils._reduce_grad = _reduce_grad 17 | torch.distributed.fsdp._runtime_utils._unshard = _unshard 18 | 19 | torch.distributed.fsdp._state_dict_utils._pre_state_dict_hook = _pre_state_dict_hook 20 | 21 | torch.distributed.fsdp.FullyShardedDataParallel.__init__ = fsdp_inits 22 | torch.distributed.fsdp.FullyShardedDataParallel.forward = forward 23 | torch.distributed.fsdp.FullyShardedDataParallel.local_sgd_state_dict = staticmethod(_local_sgd_state_dict) 24 | torch.distributed.fsdp.FullyShardedDataParallel.save_local_sgd_state_dict = staticmethod(_save_local_sgd_state_dict) 25 | torch.distributed.fsdp.FullyShardedDataParallel.load_local_sgd_state_dict = staticmethod(_load_local_sgd_state_dict) 26 | -------------------------------------------------------------------------------- /atorch/local_sgd/README.md: -------------------------------------------------------------------------------- 1 | # Unified Local SGD (EDiT) Support 2 | 3 | ATorch supports local SGD (EDiT) integration with mainstream parallelism strategies. For details, refer to [EDiT](../../docs/README-EDiT.md) 4 | -------------------------------------------------------------------------------- /atorch/local_sgd/__init__.py: -------------------------------------------------------------------------------- 1 | # NOTE since megatron impl is very version dependent, do not import anything megatron here 2 | from .configs import GTAConfig, LocalSGDConfig, OuterOptimizerConfig 3 | from .DDP import OuterOptimPeriodicModelAverager, StatefulPostLocalSGDOptimizer 4 | from .FSDP import patch_local_sgd_to_fsdp 5 | -------------------------------------------------------------------------------- /atorch/local_sgd/configs/__init__.py: -------------------------------------------------------------------------------- 1 | from .configs import GTAConfig, LocalSGDConfig, OuterOptimizerConfig 2 | -------------------------------------------------------------------------------- /atorch/local_sgd/configs/configs.py: -------------------------------------------------------------------------------- 1 | # Here we accommodate configs for local sgd/reducer 2 | from dataclasses import dataclass, field 3 | from typing import Optional, Type 4 | 5 | from torch.optim import Optimizer 6 | from typing_extensions import Literal 7 | 8 | 9 | @dataclass 10 | class LocalSGDConfig: 11 | # If use_async=False, normal local SGD is used and ranks are synced by step count; 12 | # if use_async=True, async local SGD is used and ranks are synced by elapsed time. 
13 | use_async: bool = False 14 | # Normal Local SGD 15 | local_sgd_sync_interval: int = 1 16 | # Async Local SGD 17 | local_sgd_sync_time: float = 600 # seconds 18 | min_total_global_steps: int = 100 19 | use_step_weight: bool = False 20 | step_weight_ratio: float = 0.5 21 | # General parameters 22 | local_sgd_warmup_steps: int = 0 23 | gradient_accumulation_steps: int = 1 24 | clip_pseudo_grad: Optional[float] = None 25 | pseudo_gradnorm_reduce: bool = False 26 | weight_softmax_temperature: Optional[float] = None 27 | # anomaly detection related 28 | skip_anomaly: bool = False 29 | ewma_alpha: float = 0.02 30 | ewma_warmup_steps: int = 120 31 | ewma_threshold: int = 3 32 | cpu_offload: bool = False 33 | is_debug: bool = False 34 | 35 | 36 | @dataclass 37 | class OuterOptimizerConfig: 38 | outer_optim_class: Optional[Type[Optimizer]] = None 39 | outer_optim_kwargs: dict = field(default_factory=dict) 40 | 41 | 42 | @dataclass 43 | class GTAConfig: 44 | reducer: Optional[Literal["linear", "gta"]] = None 45 | consensus_method: Optional[Literal["sum", "count"]] = None 46 | sparsification_method: Optional[Literal["magnitude", "random", "rescaled_random"]] = None 47 | normalize: bool = True 48 | density: float = 1.0 49 | int8_mask: bool = False 50 | -------------------------------------------------------------------------------- /atorch/local_sgd/megatron/__init__.py: -------------------------------------------------------------------------------- 1 | from packaging import version 2 | 3 | from atorch.common.log_utils import default_logger as logger 4 | from atorch.utils.import_util import is_megatron_lm_available 5 | from atorch.utils.version import get_megatron_version, is_megatron_version_bigger_than 6 | 7 | if is_megatron_lm_available(): 8 | import megatron 9 | import megatron.training 10 | import megatron.training.arguments 11 | from megatron import core as megatron_core 12 | from megatron.core.package_info import __version__ as megatron_version 13 | 14 | if is_megatron_version_bigger_than("0.9.0"): 15 | # these also imports megatron 16 | from .optimizer import get_megatron_optimizer 17 | from .parallel_state import initialize_model_parallel 18 | from .timers import Timers as LSDTimers 19 | from .training import get_model 20 | else: 21 | logger.info(f"Local SGD is not supported on Megatron {get_megatron_version()}") 22 | 23 | from .arguments import local_sgd_args_provider 24 | 25 | 26 | def _set_lsd_timers(args): 27 | from megatron.training.global_vars import _ensure_var_is_not_initialized 28 | 29 | _ensure_var_is_not_initialized(megatron.training.global_vars._GLOBAL_TIMERS, "timers") 30 | megatron.training.global_vars._GLOBAL_TIMERS = LSDTimers(args.timing_log_level, args.timing_log_option) 31 | 32 | 33 | def patch_megatron_for_local_sgd(): 34 | logger.warning("Local SGD Megatron patches must be applied before megatron is initialized") 35 | if is_megatron_lm_available() and is_megatron_version_bigger_than("0.9.0"): 36 | # patch the Timers 37 | megatron_core.Timers = LSDTimers 38 | # patch initializer 39 | megatron_core.parallel_state.initialize_model_parallel = initialize_model_parallel 40 | megatron.training.training.get_model = get_model 41 | megatron.training.global_vars._set_timers = _set_lsd_timers 42 | -------------------------------------------------------------------------------- /atorch/local_sgd/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .anomaly_detection import OnlineDynamicEWMA 2 | from .reduce_methods import 
GTAReducer, LinearReducer, TensorReducer 3 | -------------------------------------------------------------------------------- /atorch/local_sgd/utils/reduce_methods/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import TensorReducer 2 | from .generalized_task_arithmetic import GTAReducer 3 | from .linear import LinearReducer 4 | -------------------------------------------------------------------------------- /atorch/local_sgd/utils/reduce_methods/linear.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | from .base import TensorReducer 7 | 8 | 9 | class LinearReducer(TensorReducer): 10 | def __init__( 11 | self, 12 | process_group: dist.ProcessGroup, 13 | normalize: bool = True, 14 | weight_softmax_temperature: Optional[float] = None, 15 | ): 16 | super().__init__(process_group, weight_softmax_temperature) 17 | # use all gather based softmax in case a softmax temperature is given 18 | self.normalize = normalize and self.weight_softmax_temperature is None 19 | 20 | def _reduce_tensor(self, tensor: torch.Tensor, **kwargs): 21 | weight = self._refine_weight(tensor.device, tensor.dtype, **kwargs) 22 | 23 | with torch.no_grad(): 24 | tensor *= weight 25 | if self.normalize: 26 | if weight is not None: 27 | divisor = weight.clone() 28 | dist.all_reduce(divisor, op=dist.ReduceOp.SUM, group=self.process_group) 29 | divisor[divisor.abs() < 1e-8] = 1.0 30 | else: 31 | divisor = dist.get_world_size(group=self.process_group) 32 | 33 | tensor /= divisor 34 | dist.all_reduce(tensor, group=self.process_group) 35 | -------------------------------------------------------------------------------- /atorch/local_sgd/utils/reduce_methods/sparsify.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import torch 4 | 5 | 6 | class SparsificationMethod(str, Enum): 7 | magnitude = "magnitude" 8 | random = "random" 9 | rescaled_random = "rescaled_random" 10 | 11 | 12 | def magnitude(tensor: torch.Tensor, density: float) -> torch.Tensor: 13 | """Masks out the smallest values, retaining a proportion of `density`.""" 14 | if density >= 1: 15 | return tensor 16 | 17 | k = int(density * tensor.view(-1).shape[0]) 18 | 19 | assert k > 0, "not gonna zero out the whole tensor buddy" 20 | mask = torch.zeros_like(tensor) 21 | w = tensor.abs().view(-1) 22 | if w.device.type == "cpu": 23 | w = w.float() 24 | topk = torch.topk(w, k=k, largest=True) 25 | mask.view(-1)[topk.indices] = 1 26 | mask = mask.to(device=tensor.device) 27 | 28 | return tensor * mask 29 | 30 | 31 | def bernoulli(tensor: torch.Tensor, density: float, rescale: bool = True) -> torch.Tensor: 32 | if density >= 1: 33 | return tensor 34 | 35 | if (tensor.device.type != "cpu") or tensor.dtype == torch.bfloat16: 36 | work_dtype = tensor.dtype 37 | else: 38 | # torch.bernoulli not implemented for float16 on CPU, upcast to float32 39 | work_dtype = torch.float32 40 | 41 | mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density, dtype=work_dtype)).to(device=tensor.device) 42 | res = tensor.to(work_dtype) * mask 43 | if rescale: 44 | res /= density 45 | return res.to(tensor.dtype) 46 | 47 | 48 | def sparsify(tensor: torch.Tensor, density: float, method: SparsificationMethod) -> torch.Tensor: 49 | if method == SparsificationMethod.magnitude: 50 | return magnitude(tensor, density=density) 51 | elif method 
== SparsificationMethod.random: 52 | return bernoulli(tensor, density=density, rescale=False) 53 | elif method == SparsificationMethod.rescaled_random: 54 | return bernoulli(tensor, density=density, rescale=True) 55 | else: 56 | raise NotImplementedError(method) 57 | -------------------------------------------------------------------------------- /atorch/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/modules/__init__.py -------------------------------------------------------------------------------- /atorch/modules/distributed_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/modules/distributed_modules/__init__.py -------------------------------------------------------------------------------- /atorch/modules/distributed_modules/compilers/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipe_compiler.distributed_pippy_compiler import DeviceSafeDriver, SafeStage, pippy_compiler 2 | from .tp_compiler.tp_compiler import tp_compiler 3 | -------------------------------------------------------------------------------- /atorch/modules/distributed_modules/compilers/pipe_compiler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/modules/distributed_modules/compilers/pipe_compiler/__init__.py -------------------------------------------------------------------------------- /atorch/modules/distributed_modules/compilers/tp_compiler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/modules/distributed_modules/compilers/tp_compiler/__init__.py -------------------------------------------------------------------------------- /atorch/modules/distributed_modules/compilers/tp_compiler/dtensor_compiler.py: -------------------------------------------------------------------------------- 1 | # A tensor parallel compiler that compiles a model into a DTensor implementation. 2 | # Module parameters are sharded with the distribute_module API. 3 | # Tensors flowing through the graph are resharded with the DTensor.redistribute API. 
4 | -------------------------------------------------------------------------------- /atorch/modules/distributed_modules/materialize_modules.py: -------------------------------------------------------------------------------- 1 | from atorch.utils.meta_model_utils import is_meta, reload_meta_module 2 | 3 | from .layers import ATorchTPLayer 4 | 5 | 6 | def materialize_modules_to_device(model, device="cpu"): 7 | # Base case: the model is not on meta device, so it can be moved directly 8 | if not is_meta(model): 9 | model.to(device) 10 | else: 11 | if isinstance(model, ATorchTPLayer): 12 | model.reset_parameters() 13 | model.to(device) 14 | else: 15 | # If the model is not an instance of ATorchTPLayer 16 | # we have to check its submodules 17 | # and see if any of them are instances of ATorchTPLayer 18 | has_ATorchTPLayer = any(isinstance(module, ATorchTPLayer) for _, module in model.named_modules()) 19 | 20 | # If the model doesn't contain any ATorchTPLayer 21 | if not has_ATorchTPLayer: 22 | # We can safely reload the meta modules and move the whole model to the device 23 | reload_meta_module(model, device) 24 | else: 25 | # Otherwise we need to process each child separately 26 | for name, child in model.named_children(): 27 | materialize_modules_to_device(child, device) 28 | 29 | if is_meta(model): 30 | reload_meta_module(model, device, delete_ckpt_name=False) 31 | -------------------------------------------------------------------------------- /atorch/modules/distributed_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed_attention import DistributedSelfAttention, DistributedSoftmax 2 | -------------------------------------------------------------------------------- /atorch/modules/fp8/__init__.py: -------------------------------------------------------------------------------- 1 | from .precision_switchable_linear import LinearPrecision, PrecisionSwitchableLinear 2 | from .scaled_linear import ScaledLinear, fp8_valid_shape_check 3 | from .utils import ( 4 | get_fp8_module_count, 5 | set_linear_modules_pre_cast_input_fp8_current_scaling, 6 | set_linear_modules_precision, 7 | ) 8 | -------------------------------------------------------------------------------- /atorch/modules/moe/__init__.py: -------------------------------------------------------------------------------- 1 | from .ddp import MoEMixtureDistributedDataParallel 2 | from .moe_layer import ( 3 | Experts, 4 | MOEGroupContext, 5 | MOELayer, 6 | get_experts_ddp_process_group, 7 | get_experts_process_group, 8 | set_experts_process_group, 9 | ) 10 | from .switch_gating import SwitchGate 11 | from .topk_gating import TopkGate 12 | -------------------------------------------------------------------------------- /atorch/modules/normalization/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | try: 4 | from apex.normalization import FusedLayerNorm as LayerNorm 5 | from apex.parallel import SyncBatchNorm 6 | except (ImportError, ModuleNotFoundError) as e: 7 | warnings.warn("Tried to use apex LayerNorm but the import failed: %s" % e) 8 | from torch.nn import LayerNorm as LayerNorm 9 | from torch.nn import SyncBatchNorm 10 | from .layernorm import AtorchLayerNorm 11 | 12 | __all__ = ["LayerNorm", "SyncBatchNorm", "AtorchLayerNorm"] 13 | -------------------------------------------------------------------------------- /atorch/modules/normalization/layernorm.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from atorch.kernels import atorch_layer_norm 4 | from atorch.utils.import_util import is_triton_available 5 | 6 | 7 | class AtorchLayerNorm(torch.nn.LayerNorm): 8 | def __init__(self, *args, **kwargs): 9 | if not is_triton_available(): 10 | raise RuntimeError("Triton is not installed. AtorchLayerNorm needs it") 11 | super().__init__(*args, **kwargs) 12 | 13 | def forward(self, input): 14 | return atorch_layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) 15 | -------------------------------------------------------------------------------- /atorch/modules/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import CrossEntropyLoss 2 | -------------------------------------------------------------------------------- /atorch/modules/transformer/losses.py: -------------------------------------------------------------------------------- 1 | from atorch.utils.import_util import is_triton_available 2 | 3 | if is_triton_available(): 4 | from .cross_entropy import AtorchCrossEntropyLoss as CrossEntropyLoss # noqa F401 5 | else: 6 | from torch.nn import CrossEntropyLoss # type:ignore # noqa F401 7 | -------------------------------------------------------------------------------- /atorch/modules/transformer/rmsnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from atorch.utils.import_util import is_torch_npu_available, is_triton_available 4 | 5 | if is_triton_available(): 6 | from atorch.kernels.triton_jit.rmsnorm_kernel import AtorchRmsNormFunc 7 | 8 | if is_torch_npu_available(): 9 | from atorch.kernels import npu_rms_norm 10 | 11 | 12 | class AtorchRmsNorm(torch.nn.Module): 13 | def __init__(self, hidden_size, eps=1e-06, dtype=torch.float32, reset_fn=None): 14 | if not is_triton_available() and not is_torch_npu_available(): 15 | raise NotImplementedError("No backend found, AtorchRmsNorm is not available") 16 | super().__init__() 17 | self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype)) 18 | self.variance_epsilon = eps 19 | self.dtype = dtype 20 | self.reset_fn = reset_fn 21 | 22 | def forward(self, x): 23 | if is_torch_npu_available(): 24 | return npu_rms_norm(x, self.weight, self.variance_epsilon)[0] 25 | return AtorchRmsNormFunc.apply(x, self.weight, self.variance_epsilon) 26 | 27 | def reset_parameters(self): 28 | if self.reset_fn is not None: 29 | self.reset_fn(self.weight) 30 | return 31 | torch.nn.init.ones_(self.weight) 32 | -------------------------------------------------------------------------------- /atorch/mup/__init__.py: -------------------------------------------------------------------------------- 1 | from .module import MupLinear, MupModule, OutputLayer, QKVLayer, QLayer, SharedOutputLayer 2 | from .optim import MuAdam, MuAdamParamGroupsAdjust, MuSGD, MuSGDParamGroupsAdjust 3 | from .shape import make_base_shapes, save_base_shapes, set_base_shapes 4 | -------------------------------------------------------------------------------- /atorch/npu/op_builder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/npu/op_builder/__init__.py --------------------------------------------------------------------------------
/atorch/npu/op_builder/gmm_builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .npu_builder import NPUOpBuilder 4 | 5 | 6 | class GMMOpBuilder(NPUOpBuilder): 7 | OP_NAME = "grouped_matmul" 8 | OP_PROTO = ( 9 | "npu_gmm.List(Tensor x, Tensor weight, *, Tensor? bias, int[]? group_list, int? group_type) -> Tensor", 10 | "npu_gmm.Tensor(Tensor x, Tensor weight, *, Tensor? bias, Tensor? group_list, int? group_type) -> Tensor", 11 | ) 12 | TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split(".")[:2]) 13 | 14 | def __init__(self): 15 | super(GMMOpBuilder, self).__init__(self.OP_NAME) 16 | self.register_op_proto(self.OP_PROTO) 17 | 18 | def sources(self): 19 | return ["npu/csrc/cann/gmm.cpp"] 20 | 21 | def include_paths(self): 22 | paths = super().include_paths() 23 | paths += ["npu/csrc/inc"] 24 | return paths 25 | 26 | def cxx_args(self): 27 | args = super().cxx_args() 28 | args += [ 29 | "-Wno-sign-compare", 30 | "-Wno-deprecated-declarations", 31 | "-Wno-return-type", 32 | "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'", 33 | ] 34 | if (self.TORCH_MAJOR, self.TORCH_MINOR) >= (2, 1):  # tuple compare also covers e.g. torch 3.0 35 | cpp_std = " -std=c++17" 36 | else: 37 | cpp_std = " -std=c++14" 38 | args.append(cpp_std) 39 | return args 40 | -------------------------------------------------------------------------------- /atorch/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/ops/__init__.py -------------------------------------------------------------------------------- /atorch/ops/accelerator/__init__.py: -------------------------------------------------------------------------------- 1 | # Modifications Copyright 2023 AntGroups, Inc. 2 | # ATorch Team 3 | 4 | # Copyright (c) Microsoft Corporation. 5 | # SPDX-License-Identifier: Apache-2.0 6 | 7 | # DeepSpeed Team 8 | 9 | from .abstract_accelerator import BaseAccelerator 10 | from .real_accelerator import get_accelerator, set_accelerator 11 | -------------------------------------------------------------------------------- /atorch/ops/csrc/includes/kernel_utils.h: -------------------------------------------------------------------------------- 1 | // Modifications Copyright 2023 AntGroups, Inc. 2 | // ATorch Team 3 | 4 | // Copyright (c) Microsoft Corporation. 5 | // SPDX-License-Identifier: Apache-2.0 6 | 7 | // DeepSpeed Team 8 | 9 | /* 10 | Centralized header file for preprocessor macros and constants 11 | used throughout the codebase.
12 | */ 13 | 14 | #pragma once 15 | 16 | #include <cuda.h> 17 | 18 | #define HD_INLINE __host__ __device__ __forceinline__ 19 | #define D_INLINE __device__ __forceinline__ 20 | 21 | #ifdef __HIP_PLATFORM_HCC__ 22 | 23 | // constexpr variant of warpSize for templating 24 | constexpr int hw_warp_size = 64; 25 | #define HALF_PRECISION_AVAILABLE = 1 26 | #include <hip/hip_cooperative_groups.h> 27 | 28 | #else // !__HIP_PLATFORM_HCC__ 29 | 30 | // constexpr variant of warpSize for templating 31 | constexpr int hw_warp_size = 32; 32 | 33 | #if __CUDA_ARCH__ >= 530 34 | #define HALF_PRECISION_AVAILABLE = 1 35 | #define PTX_AVAILABLE 36 | #endif // __CUDA_ARCH__ >= 530 37 | 38 | #if __CUDA_ARCH__ >= 800 39 | #define ASYNC_COPY_AVAILABLE 40 | #endif // __CUDA_ARCH__ >= 800 41 | 42 | #include <cooperative_groups.h> 43 | 44 | #endif // __HIP_PLATFORM_HCC__ 45 | 46 | inline int next_pow2(const int val) { 47 | int rounded_val = val - 1; 48 | rounded_val |= rounded_val >> 1; 49 | rounded_val |= rounded_val >> 2; 50 | rounded_val |= rounded_val >> 4; 51 | rounded_val |= rounded_val >> 8; 52 | return rounded_val + 1; 53 | } 54 | -------------------------------------------------------------------------------- /atorch/ops/csrc/includes/quantization.h: -------------------------------------------------------------------------------- 1 | // Modifications Copyright 2023 AntGroups, Inc. 2 | // ATorch Team 3 | 4 | // Copyright (c) Microsoft Corporation. 5 | // SPDX-License-Identifier: Apache-2.0 6 | 7 | // DeepSpeed Team 8 | 9 | #pragma once 10 | 11 | #include <cuda_fp16.h> 12 | 13 | #include "kernel_utils.h" 14 | 15 | namespace quantize { 16 | 17 | enum class Type { Symmetric, Asymmetric }; 18 | 19 | struct PackedInt4 { 20 | int8_t high : 4; 21 | int8_t low : 4; 22 | }; 23 | 24 | HD_INLINE bool requires_offset(Type qType) { 25 | return qType == Type::Asymmetric; 26 | } 27 | 28 | } // namespace quantize 29 | 30 | void launch_quant(int8_t* output_data, float* params, const __half* input_data, 31 | const int groups, const int elems_per_group, 32 | const int num_bits, const quantize::Type quant_type, 33 | cudaStream_t stream); 34 | 35 | template <typename T> 36 | void launch_dequantize_kernel(T* dequant_data, const int8_t* q_data, 37 | const float* q_params, quantize::Type q_type, 38 | int num_bits, int elems_per_group, 39 | int total_elems, cudaStream_t stream); 40 | 41 | void launch_swizzled_quant(int8_t* q_data, float* q_scales, 42 | const __half* input_data, int num_bits, 43 | quantize::Type q_type, int groups, 44 | int elems_per_group, int pipelining, int nodes, 45 | int devices_per_node, cudaStream_t stream); 46 | 47 | void launch_dequant_reduce(int8_t* reduced_data, float* reduced_scales, 48 | const int8_t* input_data, const float* input_scales, 49 | int num_gpus, int num_bits, 50 | quantize::Type quant_type, int out_groups, 51 | int elems_per_out_group, int elems_per_in_tensor, 52 | int groups_per_in_tensor, int elems_per_in_group, 53 | cudaStream_t stream); 54 | -------------------------------------------------------------------------------- /atorch/ops/csrc/includes/quantization_optimizer.h: -------------------------------------------------------------------------------- 1 | // Modifications Copyright 2023 AntGroups, Inc. 2 | 3 | // Copyright (c) Tsinghua Statistical Artificial Intelligence & Learning Group.
4 | // SPDX-License-Identifier: Apache-2.0 5 | 6 | #ifndef ATORCH_OPS_CSRC_INCLUDES_QUANTIZATION_OPTIMIZER_H_ 7 | #define ATORCH_OPS_CSRC_INCLUDES_QUANTIZATION_OPTIMIZER_H_ 8 | 9 | // Helper for type check 10 | #define CHECK_CUDA_TENSOR_DIM_TYPE(name, n_dim, type) \ 11 | TORCH_CHECK(name.device().is_cuda(), #name " must be a CUDA tensor!"); \ 12 | TORCH_CHECK(name.is_contiguous(), #name " must be contiguous!"); \ 13 | TORCH_CHECK(name.dim() == n_dim, \ 14 | "The dimension of " #name " is not correct!"); \ 15 | TORCH_CHECK(name.dtype() == type, "The type of " #name " is not correct!"); \ 16 | 17 | // Helper for type check 18 | #define CHECK_CUDA_TENSOR_TYPE(name, type) \ 19 | TORCH_CHECK(name.device().is_cuda(), #name " must be a CUDA tensor!"); \ 20 | TORCH_CHECK(name.is_contiguous(), #name " must be contiguous!"); \ 21 | TORCH_CHECK(name.dtype() == type, "The type of " #name " is not correct!"); \ 22 | 23 | // Helper for type check 24 | #define CHECK_CUDA_TENSOR_FLOAT(name) \ 25 | TORCH_CHECK(name.device().is_cuda(), #name " must be a CUDA tensor!"); \ 26 | TORCH_CHECK(name.is_contiguous(), #name " must be contiguous!"); \ 27 | TORCH_CHECK(name.dtype() == torch::kFloat32 \ 28 | || name.dtype() == torch::kFloat16, \ 29 | "The type of " #name " is not kFloat32 or kFloat16!"); \ 30 | 31 | // Helper for type check 32 | #define CHECK_CUDA_TENSOR_DIM_FLOAT(name, n_dim) \ 33 | TORCH_CHECK(name.device().is_cuda(), #name " must be a CUDA tensor!"); \ 34 | TORCH_CHECK(name.is_contiguous(), #name " must be contiguous!"); \ 35 | TORCH_CHECK(name.dim() == n_dim, \ 36 | "The dimension of " #name " is not correct!"); \ 37 | TORCH_CHECK(name.dtype() == torch::kFloat32 \ 38 | || name.dtype() == torch::kFloat16, \ 39 | "The type of " #name " is not kFloat32 or kFloat16!"); \ 40 | 41 | #endif // ATORCH_OPS_CSRC_INCLUDES_QUANTIZATION_OPTIMIZER_H_ 42 | -------------------------------------------------------------------------------- /atorch/ops/csrc/includes/quantizer.h: -------------------------------------------------------------------------------- 1 | // Modifications Copyright 2023 AntGroups, Inc. 2 | // ATorch Team 3 | 4 | // Copyright (c) Microsoft Corporation. 5 | // SPDX-License-Identifier: Apache-2.0 6 | 7 | // DeepSpeed Team 8 | 9 | #pragma once 10 | 11 | #ifdef __HIP_PLATFORM_HCC__ 12 | #include <hip/hip_cooperative_groups.h> 13 | #else 14 | #include <cooperative_groups.h> 15 | #endif 16 | 17 | #include <cuda.h> 18 | #include <cuda_fp16.h> 19 | #include <stdio.h> 20 | #include <stdlib.h> 21 | 22 | #include <cassert> 23 | #include <iostream> 24 | -------------------------------------------------------------------------------- /atorch/ops/git_version_info.py: -------------------------------------------------------------------------------- 1 | # Modifications Copyright 2023 AntGroups, Inc. 2 | # ATorch Team 3 | 4 | # Copyright (c) Microsoft Corporation.
5 | # SPDX-License-Identifier: Apache-2.0 6 | 7 | # DeepSpeed Team 8 | 9 | try: 10 | # This is populated by setup.py 11 | from atorch.ops.git_version_info_installed import ( 12 | compatible_ops, 13 | git_branch, 14 | git_hash, 15 | installed_ops, 16 | torch_info, 17 | version, 18 | ) 19 | except ModuleNotFoundError: 20 | import os 21 | 22 | if os.path.isfile("version.txt"): 23 | # Will be missing from checkouts that haven't been installed (e.g., readthedocs) 24 | version = open("version.txt", "r").read().strip() 25 | else: 26 | version = "0.0.0" 27 | git_hash = "[none]" 28 | git_branch = "[none]" 29 | 30 | from atorch.ops.op_builder.all_ops import ALL_OPS 31 | 32 | installed_ops = dict.fromkeys(ALL_OPS.keys(), False) 33 | compatible_ops = dict.fromkeys(ALL_OPS.keys(), False) 34 | torch_info = {"version": "0.0", "cuda_version": "0.0", "hip_version": "0.0"} 35 | -------------------------------------------------------------------------------- /atorch/ops/op_builder/all_ops.py: -------------------------------------------------------------------------------- 1 | # Modifications Copyright 2023 AntGroups, Inc. 2 | # ATorch Team 3 | 4 | # Copyright (c) Microsoft Corporation. 5 | # SPDX-License-Identifier: Apache-2.0 6 | 7 | # DeepSpeed Team 8 | 9 | import importlib 10 | import os 11 | import pkgutil 12 | 13 | try: 14 | # during installation time accelerator is visible, otherwise return atorch.ops.accelerator 15 | from atorch.ops.accelerator import get_accelerator 16 | except ImportError: 17 | raise ImportError("Import get_accelerator error") 18 | 19 | # List of all available ops 20 | 21 | # reflect all builder names into __op_builders__ 22 | op_builder_dir = get_accelerator().op_builder_dir() 23 | op_builder_module = importlib.import_module(op_builder_dir) 24 | __op_builders__ = [] 25 | 26 | for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]): # type: ignore 27 | # avoid self references 28 | if module_name != "all_ops" and module_name != "builder": 29 | module = importlib.import_module("{}.{}".format(op_builder_dir, module_name)) 30 | for member_name in module.__dir__(): 31 | if member_name.endswith("Builder"): 32 | # append builder to __op_builders__ list 33 | builder = get_accelerator().create_op_builder(member_name) 34 | __op_builders__.append(builder) 35 | 36 | ALL_OPS = {op.name: op for op in __op_builders__ if op is not None} 37 | -------------------------------------------------------------------------------- /atorch/ops/op_builder/quantization_optimizer.py: -------------------------------------------------------------------------------- 1 | # Modifications Copyright 2023 AntGroups, Inc. 2 | # ATorch Team 3 | 4 | # Copyright (c) Microsoft Corporation. 
5 | # SPDX-License-Identifier: Apache-2.0 6 | 7 | # DeepSpeed Team 8 | 9 | from .builder import CUDAOpBuilder 10 | 11 | 12 | class QuantizationOptimizerBuilder(CUDAOpBuilder): 13 | BUILD_VAR = "ATorch_BUILD_QUANTIZATION_OPTIMIZER" 14 | NAME = "quantization_optimizer" 15 | 16 | def __init__(self, name=None): 17 | name = self.NAME if name is None else name 18 | super().__init__(name=name) 19 | 20 | def absolute_name(self): 21 | return f"atorch.ops.quantization_optimizer.{self.NAME}_op" 22 | 23 | def sources(self): 24 | return [ 25 | "csrc/quantization/quantization_optimizer.cc", 26 | "csrc/quantization/quantization_optimizer.cu", 27 | ] 28 | 29 | def include_paths(self): 30 | return ["csrc/includes"] 31 | 32 | def extra_ldflags(self): 33 | return ["-lcurand"] 34 | -------------------------------------------------------------------------------- /atorch/ops/op_builder/quantizer.py: -------------------------------------------------------------------------------- 1 | # Modifications Copyright 2023 AntGroups, Inc. 2 | # ATorch Team 3 | 4 | # Copyright (c) Microsoft Corporation. 5 | # SPDX-License-Identifier: Apache-2.0 6 | 7 | # DeepSpeed Team 8 | 9 | from .builder import CUDAOpBuilder 10 | 11 | 12 | class QuantizerBuilder(CUDAOpBuilder): 13 | BUILD_VAR = "ATorch_BUILD_QUANTIZER" 14 | NAME = "quantizer" 15 | 16 | def __init__(self, name=None): 17 | name = self.NAME if name is None else name 18 | super().__init__(name=name) 19 | 20 | def absolute_name(self): 21 | return f"atorch.ops.quantizer.{self.NAME}_op" 22 | 23 | def sources(self): 24 | return [ 25 | "csrc/quantization/pt_binding.cpp", 26 | "csrc/quantization/quantize.cu", 27 | "csrc/quantization/dequantize.cu", 28 | "csrc/quantization/swizzled_quantize.cu", 29 | "csrc/quantization/quant_reduce.cu", 30 | ] 31 | 32 | def include_paths(self): 33 | return ["csrc/includes"] 34 | 35 | def extra_ldflags(self): 36 | return ["-lcurand"] 37 | -------------------------------------------------------------------------------- /atorch/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .adaml import AdamL 2 | from .agd import AGD 3 | from .bf16_optimizer import BF16Optimizer 4 | from .wsam import WeightedSAM 5 | -------------------------------------------------------------------------------- /atorch/optimizers/low_bit/__init__.py: -------------------------------------------------------------------------------- 1 | from .optim import Q_AGD, Q_CAME, Q_Adafactor, Q_AdamW 2 | -------------------------------------------------------------------------------- /atorch/optimizers/low_bit/config.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | 3 | config = configparser.ConfigParser() 4 | 5 | config["M"] = { 6 | "ENABLE": "True", 7 | "THRESHOLD": "4096", 8 | "BITS": "4", 9 | "SCALE_TYPE": "group", 10 | "QUANT_TYPE": "nonlinear", 11 | "ROUND_TYPE": "real-nearest", 12 | "GROUP_SIZE": "128", 13 | "SIGNED": "True", 14 | } 15 | 16 | config["SQM"] = { 17 | "ENABLE": "True", 18 | "THRESHOLD": "4096", 19 | "BITS": "4", 20 | "SCALE_TYPE": "rank1", 21 | "QUANT_TYPE": "power-1", 22 | "ROUND_TYPE": "real-nearest", 23 | "GROUP_SIZE": "128", 24 | "SIGNED": "False", 25 | } 26 | 27 | 28 | def get_config(q_bits): 29 | config["M"]["BITS"] = str(q_bits) 30 | config["SQM"]["BITS"] = str(q_bits) 31 | 32 | return config 33 | -------------------------------------------------------------------------------- /atorch/optimizers/low_bit/optim/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .q_adafactor import Q_Adafactor 2 | from .q_adamw import Q_AdamW 3 | from .q_agd import Q_AGD 4 | from .q_came import Q_CAME 5 | -------------------------------------------------------------------------------- /atorch/optimizers/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def disable_running_stats(model): 7 | def _disable(module): 8 | if isinstance(module, nn.BatchNorm2d): 9 | module.backup_momentum = module.momentum 10 | module.momentum = 0 11 | 12 | model.apply(_disable) 13 | 14 | 15 | def enable_running_stats(model): 16 | def _enable(module): 17 | if isinstance(module, nn.BatchNorm2d) and hasattr(module, "backup_momentum"): 18 | module.momentum = module.backup_momentum 19 | 20 | model.apply(_enable) 21 | 22 | 23 | def whether_to_sync(model, sync=False): 24 | if not sync: 25 | return model.no_sync() 26 | else: 27 | return contextlib.ExitStack() 28 | -------------------------------------------------------------------------------- /atorch/pipeline_parallel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/pipeline_parallel/__init__.py -------------------------------------------------------------------------------- /atorch/protos/__init__.py: -------------------------------------------------------------------------------- 1 | from atorch.utils.version import package_version_bigger_than 2 | 3 | 4 | if package_version_bigger_than("protobuf", "3.20.3"): 5 | from .protobuf_4_25_3 import acceleration_pb2, acceleration_pb2_grpc 6 | from .protobuf_4_25_3 import coworker_pb2, coworker_pb2_grpc 7 | else: 8 | from .protobuf_3_20_3 import acceleration_pb2, acceleration_pb2_grpc 9 | from .protobuf_3_20_3 import coworker_pb2, coworker_pb2_grpc 10 | 11 | 12 | __all__ = ["acceleration_pb2", "acceleration_pb2_grpc", "coworker_pb2", "coworker_pb2_grpc"] -------------------------------------------------------------------------------- /atorch/protos/acceleration.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package proto; 4 | 5 | import "google/protobuf/empty.proto"; 6 | 7 | message GetAutoAccelerationTaskRequest { 8 | int32 process_id = 1; 9 | } 10 | 11 | message OptimizationMethod { 12 | string name = 1; 13 | bytes config = 2; 14 | bool tunable = 3; 15 | } 16 | 17 | message Strategy { 18 | repeated OptimizationMethod opt = 1; 19 | } 20 | 21 | message AnalysisMethod { 22 | repeated string names = 1; 23 | } 24 | 25 | message AutoAccelerationTask { 26 | int32 task_id = 1; 27 | string task_type = 2; 28 | string process_mode = 3; 29 | oneof task_info { 30 | Strategy strategy = 4; 31 | AnalysisMethod analysis_method = 5; 32 | bytes parallel_group_info = 6; 33 | } 34 | int32 time_limit = 7; 35 | } 36 | 37 | message AutoAccelerationTaskResult { 38 | int32 task_id = 1; 39 | int32 process_id = 2; 40 | bool status = 3; 41 | oneof result { 42 | Strategy strategy = 4; 43 | bytes model_meta = 5; 44 | bytes dryrun_result = 6; 45 | } 46 | string task_type = 7; 47 | } 48 | 49 | service AutoAccelerationService { 50 | rpc get_task(GetAutoAccelerationTaskRequest) returns (AutoAccelerationTask); 51 | rpc report_task_result(AutoAccelerationTaskResult) 52 | returns (google.protobuf.Empty); 53 | } 54 | 
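Editorial note: the AutoAccelerationService above is the contract between the engine servicer and its worker processes (see atorch/auto/engine/servicer.py and atorch/auto/engine_client.py in the tree). A minimal sketch, using the stubs already shipped in atorch/protos, of fetching a task; the server address and port are illustrative assumptions:

import grpc

from atorch.protos import acceleration_pb2, acceleration_pb2_grpc

# connect to a running engine servicer (address/port assumed for illustration)
channel = grpc.insecure_channel("localhost:50051")
stub = acceleration_pb2_grpc.AutoAccelerationServiceStub(channel)

# ask the engine for the next task assigned to this worker process
task = stub.get_task(acceleration_pb2.GetAutoAccelerationTaskRequest(process_id=0))
print(task.task_id, task.task_type, task.process_mode)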
-------------------------------------------------------------------------------- /atorch/protos/coworker.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package protos; 4 | 5 | import "google/protobuf/empty.proto"; 6 | 7 | message BatchData { 8 | bytes data = 1; 9 | } 10 | 11 | message DataInfo { 12 | string coworker_addr = 1; 13 | int32 batch_num = 2; 14 | } 15 | 16 | service CoworkerRpcService { 17 | rpc get_batch_data(google.protobuf.Empty) returns (BatchData); 18 | } 19 | 20 | service DataInfoService { 21 | rpc report_data_info(DataInfo) returns (google.protobuf.Empty); 22 | rpc get_data_info(google.protobuf.Empty) returns (DataInfo); 23 | } -------------------------------------------------------------------------------- /atorch/protos/protobuf_3_20_3/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/protos/protobuf_3_20_3/__init__.py -------------------------------------------------------------------------------- /atorch/protos/protobuf_4_25_3/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/protos/protobuf_4_25_3/__init__.py -------------------------------------------------------------------------------- /atorch/protos/protobuf_4_25_3/coworker_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: coworker.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import symbol_database as _symbol_database 8 | from google.protobuf.internal import builder as _builder 9 | # @@protoc_insertion_point(imports) 10 | 11 | _sym_db = _symbol_database.Default() 12 | 13 | 14 | from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 15 | 16 | 17 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x63oworker.proto\x12\x06protos\x1a\x1bgoogle/protobuf/empty.proto\"\x19\n\tBatchData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"4\n\x08\x44\x61taInfo\x12\x15\n\rcoworker_addr\x18\x01 \x01(\t\x12\x11\n\tbatch_num\x18\x02 \x01(\x05\x32Q\n\x12\x43oworkerRpcService\x12;\n\x0eget_batch_data\x12\x16.google.protobuf.Empty\x1a\x11.protos.BatchData2\x8a\x01\n\x0f\x44\x61taInfoService\x12<\n\x10report_data_info\x12\x10.protos.DataInfo\x1a\x16.google.protobuf.Empty\x12\x39\n\rget_data_info\x12\x16.google.protobuf.Empty\x1a\x10.protos.DataInfob\x06proto3') 18 | 19 | _globals = globals() 20 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 21 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'coworker_pb2', _globals) 22 | if _descriptor._USE_C_DESCRIPTORS == False: 23 | 24 | DESCRIPTOR._options = None 25 | _globals['_BATCHDATA']._serialized_start=55 26 | _globals['_BATCHDATA']._serialized_end=80 27 | _globals['_DATAINFO']._serialized_start=82 28 | _globals['_DATAINFO']._serialized_end=134 29 | _globals['_COWORKERRPCSERVICE']._serialized_start=136 30 | _globals['_COWORKERRPCSERVICE']._serialized_end=217 31 | _globals['_DATAINFOSERVICE']._serialized_start=220 32 | 
_globals['_DATAINFOSERVICE']._serialized_end=358 33 | # @@protoc_insertion_point(module_scope) 34 | -------------------------------------------------------------------------------- /atorch/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers>=4.31.0 2 | networkx 3 | pynvml>=11.4.1 4 | grpcio 5 | pyarrow>=12.0.0 6 | pandas>=2.0.1 7 | tensorboard>=2.11.0 8 | dlrover[torch] 9 | protobuf 10 | safetensors 11 | -------------------------------------------------------------------------------- /atorch/rl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/rl/__init__.py -------------------------------------------------------------------------------- /atorch/rl/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_utils import create_dataset 2 | -------------------------------------------------------------------------------- /atorch/rl/ds_hybrid_engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/rl/ds_hybrid_engine/__init__.py -------------------------------------------------------------------------------- /atorch/rl/ds_hybrid_engine/module_inject/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/rl/ds_hybrid_engine/module_inject/__init__.py -------------------------------------------------------------------------------- /atorch/rl/ds_hybrid_engine/module_inject/containers/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama import DS_LLAMAContainer, LLAMALayerPolicy 2 | -------------------------------------------------------------------------------- /atorch/rl/ds_hybrid_engine/module_inject/utils.py: -------------------------------------------------------------------------------- 1 | from deepspeed.utils import log_dist 2 | 3 | 4 | # helper function to map between DS policies and DS containers 5 | def policy_to_ds_container(**kwargs): 6 | 7 | from .containers import DS_LLAMAContainer, LLAMALayerPolicy 8 | 9 | policy_to_container = { 10 | LLAMALayerPolicy: DS_LLAMAContainer, 11 | } 12 | 13 | container = None 14 | policy = kwargs["policy"] 15 | assert policy is not None, "Policy cannot be None" 16 | policy_type = type(policy) 17 | 18 | if policy_type not in policy_to_container: 19 | log_dist(f"Policy type {policy_type} not supported", [0]) 20 | else: 21 | container = policy_to_container[policy_type](**kwargs) 22 | 23 | return container 24 | -------------------------------------------------------------------------------- /atorch/rl/ds_hybrid_engine/replace_policy.py: -------------------------------------------------------------------------------- 1 | # copy from deepspeed https://github.com/microsoft/DeepSpeed/ 2 | from .module_inject.containers import LLAMALayerPolicy 3 | 4 | replace_policies = [ 5 | LLAMALayerPolicy, 6 | ] 7 | 8 | # non-transformer-based policies 9 | generic_policies = [] # type: ignore 10 | -------------------------------------------------------------------------------- /atorch/rl/inference_backend/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/rl/inference_backend/__init__.py -------------------------------------------------------------------------------- /atorch/rl/inference_backend/vllm_backend.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from atorch.common.log_utils import default_logger as logger 4 | 5 | try: 6 | from vllm import LLM, SamplingParams 7 | except Exception: 8 | logger.warning("vllm not installed") 9 | 10 | 11 | class VLLMComm: 12 | def __init__(self, ip, port): 13 | from vllm_comm import vllm_comm 14 | 15 | self.client = vllm_comm.vllmClient(ip, port) 16 | self.client.create_session() 17 | 18 | def send_data(self, data): 19 | self.client.send_data(data.data_ptr(), data.numel() * 2) # byte count assumes 2-byte (fp16/bf16) elements 20 | self.client.delete_session() 21 | self.client.create_session() 22 | 23 | 24 | class VLLMBackend: 25 | def __init__(self, checkpoint_path=None, gen_kwargs=None, gpu_memory_utilization=0.4, dtype="bfloat16"): 26 | self.gen_kwargs = gen_kwargs or {} # gen_kwargs defaults to None; fall back to an empty dict 27 | max_tokens = self.gen_kwargs.get("max_new_tokens", 500) 28 | self.sampling_params = SamplingParams(temperature=0, max_tokens=max_tokens) 29 | self.llm = LLM( 30 | model=checkpoint_path, 31 | trust_remote_code=True, 32 | gpu_memory_utilization=gpu_memory_utilization, 33 | tensor_parallel_size=1, 34 | dtype=dtype, 35 | ) 36 | for param in self.llm.llm_engine.workers[0].model.parameters(): # release vLLM's own weight storage; weights come from the training model later 37 | param.data = torch.empty(0, dtype=param.dtype, device=param.device) 38 | torch.cuda.empty_cache() 39 | self.tokenizer = None 40 | 41 | def set_train_model_weights(self, train_model): 42 | for p1, p2 in zip(self.llm.llm_engine.workers[0].model.parameters(), train_model.parameters()): 43 | p1.data = p2.data 44 | 45 | def set_tokenizer(self, tokenizer): 46 | self.llm.llm_engine.tokenizer = tokenizer 47 | 48 | def generate(self, prompts): 49 | return self.llm.generate(prompts, sampling_params=self.sampling_params) 50 | -------------------------------------------------------------------------------- /atorch/rl/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import atorch 4 | from atorch.rl.config import AtorchRLConfig 5 | from atorch.rl.data import create_dataset 6 | from atorch.rl.model_engine import ModelEngine 7 | from atorch.rl.trainer.ppo_trainer import PPOTrainer 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description="Process arguments") 12 | parser.add_argument("--config_file", type=str, default="my_config.yml", required=False) 13 | return parser.parse_args() 14 | 15 | 16 | def rl_train(args): 17 | config = AtorchRLConfig.load_yaml(args.config_file) 18 | 19 | atorch.init_distributed() 20 | 21 | # create model, optimizer, tokenizer, etc.
22 | engine = ModelEngine(config) 23 | 24 | # create prompt dataset 25 | dataset = create_dataset(config) 26 | 27 | # init trainer 28 | trainer = PPOTrainer(engine, dataset, config) 29 | 30 | trainer.train() 31 | 32 | 33 | if __name__ == "__main__": 34 | args = parse_args() 35 | rl_train(args) 36 | -------------------------------------------------------------------------------- /atorch/rl/model_engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_engine import ModelEngine, ModelEngineState 2 | -------------------------------------------------------------------------------- /atorch/rl/model_engine/strategy.py: -------------------------------------------------------------------------------- 1 | # This file defines some commonly used optimization strategies. 2 | from atorch.utils.import_util import import_module_from_py_file 3 | 4 | 5 | # use a constant to refer to actor/critic/reward_model/ref_model 6 | def get_strategy(file_path, model_type="actor"): 7 | """ 8 | Actor/critic/ref_model/reward_model strategies 9 | may be defined in the same file or in separate files. 10 | If they are defined in the same file, each strategy 11 | must be declared explicitly. For example: 12 | actor_strategy = [] 13 | critic_strategy = [] 14 | ref_model_strategy = [] 15 | reward_model_strategy = [] 16 | """ 17 | module = import_module_from_py_file(file_path) 18 | strategy = None 19 | role_strategy = "{}_strategy".format(model_type) 20 | if hasattr(module, role_strategy): 21 | strategy = getattr(module, role_strategy) 22 | elif hasattr(module, "strategy"): 23 | strategy = getattr(module, "strategy") 24 | else: 25 | # if the user doesn't define a strategy, atorch won't modify anything 26 | strategy = "torch_native" 27 | return strategy 28 | 29 | 30 | def ddp_strategy(): 31 | return ["parallel_mode"] 32 | 33 | 34 | def zero1_strategy(): 35 | return ["parallel_mode", "zero1"] 36 | -------------------------------------------------------------------------------- /atorch/rl/model_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/rl/model_utils/__init__.py -------------------------------------------------------------------------------- /atorch/rl/ppo_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/rl/ppo_utils/__init__.py -------------------------------------------------------------------------------- /atorch/rl/replay_buffer/__init__.py: -------------------------------------------------------------------------------- 1 | from .replay_buffer import ReplayBuffer 2 | -------------------------------------------------------------------------------- /atorch/rl/replay_buffer/replay_buffer.py: -------------------------------------------------------------------------------- 1 | from atorch.common.log_utils import default_logger as logger 2 | from atorch.rl.data.data_utils import RLTrainingDataset 3 | 4 | 5 | class ReplayBuffer: 6 | def __init__(self, config, element_keys=None): 7 | self.config = config 8 | self.element_keys = element_keys 9 | self.data = {} 10 | self.num = 0 11 | 12 | # Reset buffer 13 | def reset(self): 14 | for k in self.data.keys(): 15 | self.data[k] = [] 16 | self.num = 0 17 | 18 | def add_samples(self, samples): 19 |
assert isinstance(samples, list) 20 | for sample in samples: 21 | self.add_sample(sample) 22 | 23 | # Add a sample or update a sample with index. 24 | def add_sample(self, sample, index=None): 25 | sample_keys = [k for k in sample.keys()] 26 | if self.element_keys is not None: 27 | assert set(sample_keys).issubset(set(self.element_keys)), "sample keys must be a subset of the replay buffer's element_keys" 28 | if index is not None: 29 | sample_exist = True 30 | for k in sample_keys: 31 | if len(self.data.get(k, [])) <= index: 32 | logger.warning("failed to update a sample with index {}".format(index)) 33 | sample_exist = False 34 | break 35 | if sample_exist: 36 | for k in sample_keys: 37 | self.data[k][index] = sample.get(k) 38 | else: 39 | for k in sample_keys: 40 | new_sample = sample.get(k) 41 | if k not in self.data.keys(): 42 | self.data[k] = [new_sample] 43 | else: 44 | self.data[k].append(new_sample) 45 | self.num += 1 46 | 47 | # Sync buffer in process_group using allgather. 48 | def sync(self, process_group=None): 49 | pass 50 | 51 | # Create a dataset 52 | def create_dataset(self): 53 | return RLTrainingDataset(self) 54 | -------------------------------------------------------------------------------- /atorch/rl/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .rl_trainer import RLTrainer 2 | -------------------------------------------------------------------------------- /atorch/rl/trainer/ppo_trainer.py: -------------------------------------------------------------------------------- 1 | from atorch.rl.trainer import RLTrainer 2 | 3 | 4 | class PPOTrainer(RLTrainer): 5 | def __init__(self, model_engine, dataset, config): 6 | super().__init__(model_engine, dataset, config) 7 | 8 | def make_experience(self, data): 9 | pass 10 | 11 | def rl_training(self): 12 | pass 13 | -------------------------------------------------------------------------------- /atorch/service/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/service/__init__.py -------------------------------------------------------------------------------- /atorch/tensor_parallel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/tensor_parallel/__init__.py -------------------------------------------------------------------------------- /atorch/tests/common_tests/acc_executor_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | from atorch.auto.engine.acceleration_engine import AccelerationEngine 5 | from atorch.auto.engine.task import TaskType 6 | 7 | os.environ["BO_SG_MAX_IETR"] = "2" 8 | 9 | 10 | def process_task(tasks, executor): 11 | for (task, process_id) in tasks: 12 | result = None 13 | if task.task_type == TaskType.ANALYSE: 14 | result = {"model_params_num": 10000, "model_params_mb": 40000} 15 | if task.task_type == TaskType.TUNE: 16 | result = task.task_info 17 | if task.task_type == TaskType.DRYRUN: 18 | result = {"throughput": min(task.task_id + 2.0, 2)} 19 | if task.task_type != TaskType.WAIT: 20 | executor.report_task_result(task.task_id, process_id, True, result) 21 | 22 | 23 | class TestExecutor(unittest.TestCase): 24 | def test_executor(self): 25 | device_context =
{"node_num": 1, "nproc_per_node": 2} 26 | 27 | executor = AccelerationEngine.create_executor(device_context=device_context) 28 | 29 | process_running = [True for _ in range(2)] 30 | while any(process_running): 31 | tasks = [] 32 | for idx, status in enumerate(process_running): 33 | if status: 34 | task = executor.get_task(idx) 35 | tasks.append((task, idx)) 36 | if task.task_type == TaskType.FINISH or task.task_type == TaskType.FAIL: 37 | self.assertTrue(task.task_type != TaskType.FAIL) 38 | process_running[idx] = False 39 | process_task(tasks, executor) 40 | self.assertTrue(executor.strategy_infos.num_strategy > 1) 41 | -------------------------------------------------------------------------------- /atorch/tests/common_tests/analyzer_result_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from atorch.auto.engine.analyser_result import AnalyserResult 4 | 5 | 6 | class AnalyzerResultTest(unittest.TestCase): 7 | def test_analyzer_result(self): 8 | result = { 9 | "analyse_basic": { 10 | "model_params_num": {"layer_0": 10, "layer_1": 20}, 11 | "model_params_mb": {"layer_0": 0.1, "layer_1": 0.2}, 12 | }, 13 | "analyse_dynamic": {"fixed_data_size": True, "data_size": 1024}, 14 | } 15 | 16 | analyzer_result = AnalyserResult() 17 | analyzer_result.update(result) 18 | self.assertEqual( 19 | analyzer_result.get("analyse_basic"), 20 | { 21 | "model_params_num": {"layer_0": 10, "layer_1": 20}, 22 | "model_params_mb": {"layer_0": 0.1, "layer_1": 0.2}, 23 | }, 24 | ) 25 | 26 | self.assertEqual( 27 | analyzer_result.get("model_params_num"), 28 | {"layer_0": 10, "layer_1": 20}, 29 | ) 30 | 31 | self.assertEqual(analyzer_result.get("data_size"), 1024) 32 | self.assertEqual(analyzer_result.get("analyse_transformer"), None) 33 | 34 | 35 | if __name__ == "__main__": 36 | unittest.main() 37 | -------------------------------------------------------------------------------- /atorch/tests/common_tests/device_context_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from atorch.auto.device_context import DeviceContext 6 | 7 | 8 | class DeviceContextTest(unittest.TestCase): 9 | def test_device_cpu(self): 10 | dc = DeviceContext() 11 | context = dc.detect() 12 | self.assertGreater(context["cpu_num_per_node"], 0) 13 | self.assertGreater(context["cpu_memory_per_node"], 0) 14 | 15 | @unittest.skipIf(not torch.cuda.is_available(), "Skip for non gpu device.") 16 | def test_device_gpu(self): 17 | dc = DeviceContext() 18 | context = dc.detect() 19 | self.assertGreater(context["total_gpu"], 0) 20 | self.assertGreater(context["gpu_memory"], 0) 21 | -------------------------------------------------------------------------------- /atorch/tests/common_tests/log_util_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import unittest 3 | 4 | from atorch.common.log_utils import Timer, TimeStats 5 | 6 | 7 | class LogUitlTests(unittest.TestCase): 8 | def test_timer(self): 9 | timer = Timer("test") 10 | timer.start() 11 | time.sleep(1) 12 | timer.end() 13 | self.assertAlmostEqual(timer.elapsed_time, 1, places=1) 14 | 15 | timestate = TimeStats("test") 16 | timestate[timer.name] = timer.elapsed_time 17 | with Timer("forward", timestate): 18 | time.sleep(1) 19 | 20 | self.assertAlmostEqual(timestate["forward"], 1, places=1) 21 | 22 | 23 | if __name__ == "__main__": 24 | unittest.main() 25 | 
-------------------------------------------------------------------------------- /atorch/tests/common_tests/optimizer_offload_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import, unicode_literals 3 | 4 | import unittest 5 | 6 | import torch 7 | import torch.nn 8 | 9 | try: 10 | from atorch.optimizers.adam_offload import PartitionAdam 11 | 12 | is_apex_available = True 13 | except (ModuleNotFoundError, ImportError): 14 | is_apex_available = False 15 | 16 | 17 | @unittest.skipIf(not torch.cuda.is_available(), "Offload optimizer need apex and gpu") 18 | @unittest.skipIf(not is_apex_available, "PartitionAdam import error, need apex and gpu") 19 | class OffloadOptimizerTest(unittest.TestCase): 20 | def test_optimizer_state_dict(self): 21 | device = torch.device("cuda") 22 | model = torch.nn.Linear(10, 20).to(device) 23 | optimizer = PartitionAdam([{"params": model.parameters()}]) 24 | 25 | x = torch.randn(10, 10, device=device) 26 | y = model(x) 27 | dy = torch.randn_like(y) 28 | y.backward(dy) 29 | 30 | optimizer.state_dict() 31 | optimizer.step() 32 | 33 | 34 | if __name__ == "__main__": 35 | unittest.main() 36 | -------------------------------------------------------------------------------- /atorch/tests/common_tests/popen_redirect_io_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import, annotations, unicode_literals 3 | 4 | import tempfile 5 | import unittest 6 | from subprocess import Popen 7 | 8 | 9 | class TestRedirectIO(unittest.TestCase): 10 | def testIO(self): 11 | """Write stdout/stderr to fd and check result""" 12 | with tempfile.NamedTemporaryFile() as stdout, tempfile.NamedTemporaryFile() as stderr: 13 | cmd = [ 14 | "python", 15 | "-c", 16 | "import sys;print(1, file=sys.stderr, flush=True);print(2, file=sys.stdout, flush=True)", 17 | ] 18 | process = Popen(cmd, shell=False, stdout=stdout, stderr=stderr, universal_newlines=True) 19 | process.wait() 20 | stdout.seek(0) 21 | stderr.seek(0) 22 | self.assertEqual(b"1\n", stderr.readline()) 23 | self.assertEqual(b"2\n", stdout.readline()) 24 | 25 | 26 | if __name__ == "__main__": 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /atorch/tests/common_tests/sparse_tensor_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import torch 5 | import torch.distributed as dist 6 | import torch.multiprocessing as mp 7 | 8 | from atorch.common.util_func import find_free_port 9 | from atorch.utils.sparse import all_reduce_sparse 10 | 11 | 12 | def run_all_reduce_sparse(rank, size, backend="gloo"): 13 | """All-reduce two sparse tensors across ranks and verify the merged result on each rank.""" 14 | dist.init_process_group(backend, rank=rank, world_size=size) 15 | if rank == 0: 16 | data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.1, 1.2]]).to_sparse(2) 17 | result = all_reduce_sparse(data) 18 | else: 19 | data = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]).to_sparse(2) 20 | result = all_reduce_sparse(data) 21 | 22 | indices = torch.tensor([[1, 1], [1, 2]]) 23 | values = torch.tensor([1.1000, 3.2000]) 24 | tensor_size = torch.tensor([2, 3]) 25 | assert torch.equal(result.indices(), indices) 26 | assert torch.equal(result.values(), values) 27 | assert torch.equal(torch.tensor(result.size()), tensor_size) 28 | 29 | 30 | class
SparseTensorCommunicationTest(unittest.TestCase): 31 | @unittest.skipIf(True, "Failed on gpu") 32 | def test_all_reduce_sparse(self): 33 | os.environ["MASTER_ADDR"] = "localhost" 34 | os.environ["MASTER_PORT"] = str(find_free_port()) 35 | world_size = 2 36 | mp.spawn(run_all_reduce_sparse, args=(world_size,), nprocs=world_size, join=True) 37 | 38 | 39 | if __name__ == "__main__": 40 | unittest.main() 41 | -------------------------------------------------------------------------------- /atorch/tests/common_tests/test_cross_entropy.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import pytest 4 | 5 | torch = pytest.importorskip("torch", "2.0.0") 6 | 7 | from atorch.modules.transformer import losses # noqa: E402 8 | 9 | 10 | @unittest.skipIf(not torch.cuda.is_available(), "cuda is not available") 11 | class TestCrossEntropy(unittest.TestCase): 12 | def test_cross_entropy(self): 13 | batch = 2 14 | seq = 1024 15 | hidden = 32000 16 | loss_fn_pt = torch.nn.CrossEntropyLoss() 17 | loss_fn_atorch = losses.CrossEntropyLoss() 18 | torch.random.manual_seed(0) 19 | dtypes = [torch.float16, torch.float32] 20 | if torch.cuda.is_bf16_supported(): 21 | dtypes.append(torch.bfloat16) 22 | for dtype in dtypes: 23 | rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4) 24 | with torch.device("cuda"): 25 | input_gt = torch.randn(batch * seq, hidden).requires_grad_(True) 26 | input_pt = input_gt.clone().detach().to(dtype).requires_grad_(True) 27 | input_atorch = input_gt.clone().detach().to(dtype).requires_grad_(True) 28 | target = torch.empty(batch * seq, dtype=torch.long).random_(hidden) 29 | loss_pt = loss_fn_pt(input_pt.float(), target) 30 | loss_atorch = loss_fn_atorch(input_atorch, target) 31 | loss_pt.backward() 32 | loss_atorch.backward() 33 | self.assertTrue(torch.allclose(loss_pt, loss_atorch, rtol=1e-5, atol=1e-6)) 34 | self.assertTrue(torch.allclose(input_pt.grad, input_atorch.grad, rtol=rtol, atol=atol)) 35 | 36 | 37 | if __name__ == "__main__": 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /atorch/tests/common_tests/test_loss_spike_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import torch 5 | 6 | from atorch.utils.loss_spike_utils import TokenLossSpike 7 | 8 | 9 | @unittest.skipIf(torch.cuda.is_available(), "Skip on gpu as cpu test is enough.") 10 | class LossSpikeTest(unittest.TestCase): 11 | def init(self): 12 | self.min_iter = 100 13 | self.min_loss = 4.0 14 | 15 | sample_data_paths = [("wikipedia", "corpus/base"), ("wikipedia", "corpus/base"), ("wikipedia", "corpus/base")] 16 | each_sample_len = 2 17 | 18 | if not os.path.exists("loss_spike"): 19 | os.mkdir("loss_spike") 20 | 21 | self.loss_ins = TokenLossSpike( 22 | "loss_spike", 23 | sample_data_paths, 24 | each_sample_len, 25 | self.min_iter, 26 | self.min_loss, 27 | ) 28 | 29 | def setUp(self): 30 | self.init() 31 | 32 | def test_save(self): 33 | self.loss_ins.save_loss( 34 | "test_loss.txt", 35 | 4.05, 36 | 103, 37 | losses_str="2.44,2.33,4.05", 38 | sample_infos_str="20-1-1385697-14158189-2,20-1-1385697-14158189-2,20-1-1385697-14158189-2", 39 | ) 40 | 41 | def test_decode(self): 42 | self.loss_ins.decode_loss_spike("res.txt", None) 43 | 44 | def test_parse(self): 45 | self.loss_ins.parse_sample_content( 46 | losses_str="2.44,2.33,4.05", 47 | 
sample_infos_str="20-1-1385697-14158189-2,20-1-1385697-14158189-2,20-1-1385697-14158189-2", 48 | tokenizer=None, 49 | ) 50 | 51 | def test_fetch(self): 52 | self.loss_ins.fetch("20-1-1385697-14158189-2") 53 | -------------------------------------------------------------------------------- /atorch/tests/common_tests/unshuffled_batch_dataloder_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from torch.utils.data import Dataset 4 | 5 | from atorch.data.unshuffled_batch_dataloader import DistributedUnshuffledBatchSampler 6 | 7 | 8 | class _TestDataset(Dataset): 9 | def __init__(self, data_size=32): 10 | self.size = data_size 11 | self.data = [i for i in range(data_size)] 12 | 13 | def __len__(self): 14 | return self.size 15 | 16 | def __getitem__(self, idx): 17 | return self.data[idx] 18 | 19 | 20 | class DistributedUnshuffledBatchSamplerTest(unittest.TestCase): 21 | def test_unshuffled_batch_sampler(self): 22 | dataset = _TestDataset(data_size=64) 23 | num_replicas = 8 24 | rank = 0 25 | batch_size = 4 26 | indices = [0, 1, 2, 3, 32, 33, 34, 35] 27 | sampler = DistributedUnshuffledBatchSampler( 28 | dataset, num_replicas=num_replicas, rank=rank, batch_size=batch_size 29 | ) 30 | res = [] 31 | for i in sampler: 32 | res.append(i) 33 | self.assertListEqual(res, indices) 34 | 35 | 36 | if __name__ == "__main__": 37 | unittest.main() 38 | -------------------------------------------------------------------------------- /atorch/tests/test_define_rl_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/tests/test_define_rl_models/__init__.py -------------------------------------------------------------------------------- /atorch/tests/test_define_rl_models/independent_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/tests/test_define_rl_models/independent_models/__init__.py -------------------------------------------------------------------------------- /atorch/tests/test_define_rl_models/independent_models/hg_model_def.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | actor: 3 | model_path: /home/glm-large-chinese 4 | model_cls: transformers.AutoModelForSeq2SeqLM 5 | model_params: 6 | features_in: 10 7 | features_out: 10 8 | train_strategy: null 9 | inference_strategy: null 10 | critic: 11 | model_path: /home/glm-large-chinese 12 | model_cls: transformers.AutoModelForSeq2SeqLM 13 | model_params: 14 | dims_in: 2 15 | dims_out: 2 16 | train_strategy: null 17 | inference_strategy: null 18 | ref_model: 19 | model_path: /home/glm-large-chinese 20 | model_cls: transformers.AutoModelForSeq2SeqLM 21 | model_params: 22 | dims_in: 2 23 | dims_out: 2 24 | inference_strategy: null 25 | reward_model: 26 | model_path: /home/glm-large-chinese 27 | model_cls: transformers.AutoModelForSeq2SeqLM 28 | model_params: 29 | dims_in: 2 30 | dims_out: 2 31 | inference_strategy: null 32 | train: 33 | seq_length: 1024 34 | batch_size: 4 35 | epoch: 1 36 | num_rollouts: 4 37 | generation: 38 | batch_size: 4 39 | epoch: 10 40 | gen_kwargs: 41 | max_new_tokens: 512 42 | top_k: 0 43 | top_p: 1.0 44 | do_sample: false 45 | gen_experience_kwargs: 46 | max_new_tokens: 512 47 | 
do_sample: false 48 | temperature: 1.0 49 | top_k: 50 50 | top_p: 0.95 51 | 52 | tokenizer: 53 | tokenizer_path: /home/glm-large-chinese 54 | params: 55 | truncation_side: right 56 | method: 57 | PPOConfig: 58 | ppo_epoch: 2 59 | init_kl_coef: 0.02 60 | gamma: 1 61 | lam: 0.95 62 | cliprange: 0.2 63 | cliprange_value: 0.2 64 | vf_coef: 0.1 65 | cliprange_reward: 50 66 | clip_ratio: true 67 | ent_coef: 0.01 68 | scale_reward: running 69 | ref_mean: null 70 | ref_std: null -------------------------------------------------------------------------------- /atorch/tests/test_define_rl_models/independent_models/model_definition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FakeActor(nn.Module): 6 | def __init__(self, features_in=10, features_out=10): 7 | super().__init__() 8 | self.linear = torch.nn.Linear(features_in, features_out) 9 | 10 | def forward(self, x): 11 | return self.linear(x) 12 | 13 | 14 | class FakeCritic(nn.Module): 15 | def __init__(self, dims_in=10, dims_out=10): 16 | super().__init__() 17 | self.linear = torch.nn.Linear(dims_in, dims_out) 18 | 19 | def forward(self, x): 20 | return self.linear(x) 21 | 22 | 23 | class FakeRewardModel(nn.Module): 24 | def __init__(self, dims_in=10, dims_out=10): 25 | super().__init__() 26 | self.linear = torch.nn.Linear(dims_in, dims_out) 27 | 28 | def forward(self, x): 29 | return self.linear(x) 30 | 31 | 32 | class FakeRefModel(nn.Module): 33 | def __init__(self, dims_in=10, dims_out=10): 34 | super().__init__() 35 | self.linear = torch.nn.Linear(dims_in, dims_out) 36 | 37 | def forward(self, x): 38 | return self.linear(x) 39 | -------------------------------------------------------------------------------- /atorch/tests/test_define_rl_models/independent_models/strategy.py: -------------------------------------------------------------------------------- 1 | import atorch 2 | 3 | p_mode = ([("data", atorch.world_size())], None) 4 | strategy = [("parallel_mode", p_mode), "amp_native", ("fsdp")] 5 | -------------------------------------------------------------------------------- /atorch/tests/test_define_rl_models/share_weights_models/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | actor_critic_ref: 3 | model_path: /home/glm-large-chinese 4 | model_cls: atorch.tests.test_define_rl_models.independent_models.actor_critic_ref.ActorCriticRef 5 | model_params: 6 | num_layers: 2 7 | train_strategy: ./benchmarks/glm_rlhf/sequential_case_share_weights/strategy.py 8 | inference_strategy: torch_native 9 | loss: atorch.rl.ppo_utils.ppo_util.loss 10 | optimizer: 11 | name: torch.optim.adam 12 | kwargs: 13 | lr: 1.0e-6 14 | betas: 15 | - 0.9 16 | - 0.95 17 | eps: 1.0e-8 18 | weight_decay: 0.01 19 | reward_model: 20 | model_path: /home/glm-large-chinese 21 | model_cls: benchmarks.glm_rlhf.sequential_case_share_weights.reward_model.reward_model.RewardModel 22 | train_strategy: ./benchmarks/glm_rlhf/sequential_case_share_weights/strategy.py 23 | train: 24 | seq_length: 1024 25 | batch_size: 4 26 | epoch: 1 27 | num_rollouts: 10 28 | generation: 29 | batch_size: 4 30 | epoch: 10 31 | gen_kwargs: 32 | max_new_tokens: 512 33 | top_k: 0 34 | top_p: 1.0 35 | do_sample: false 36 | gen_experience_kwargs: 37 | max_new_tokens: 512 38 | do_sample: false 39 | temperature: 1.0 40 | top_k: 50 41 | top_p: 0.95 42 | 43 | tokenizer: 44 | tokenizer_path: /home 45 | params: 46 | truncation_side: right 47 | method: 48 | 
PPOConfig: 49 | ppo_epoch: 2 50 | init_kl_coef: 0.02 51 | gamma: 1 52 | lam: 0.95 53 | cliprange: 0.2 54 | cliprange_value: 0.2 55 | vf_coef: 0.1 56 | cliprange_reward: 50 57 | clip_ratio: true 58 | ent_coef: 0.01 59 | scale_reward: running 60 | ref_mean: null 61 | ref_std: null 62 | -------------------------------------------------------------------------------- /atorch/tests/test_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/tests/test_modules/__init__.py -------------------------------------------------------------------------------- /atorch/tests/toy_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/tests/toy_modules/__init__.py -------------------------------------------------------------------------------- /atorch/tests/tp_modules/model_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class ModelArgs: 6 | dim: int = 512 7 | n_layers: int = 8 8 | n_heads: int = 8 9 | vocab_size: int = -1 # defined later by tokenizer 10 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 11 | norm_eps: float = 1e-5 12 | 13 | max_batch_size: int = 32 14 | max_seq_len: int = 2048 15 | -------------------------------------------------------------------------------- /atorch/tests/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/tests/utils/__init__.py -------------------------------------------------------------------------------- /atorch/tests/utils/test_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | import atorch 6 | 7 | 8 | def init_dist(rank, world_size): 9 | os.environ["LOCAL_RANK"] = str(rank) 10 | os.environ["RANK"] = str(rank) 11 | os.environ["WORLD_SIZE"] = str(world_size) 12 | os.environ["NPROC_PER_NODE"] = str(world_size) 13 | 14 | atorch.init_distributed("nccl") 15 | torch.cuda.set_device(atorch.local_rank()) 16 | -------------------------------------------------------------------------------- /atorch/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .atorch_args import AtorchArguments 2 | from .atorch_trainer import STREAMING_CKPT_DIR, AtorchTrainer 3 | -------------------------------------------------------------------------------- /atorch/trainer/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/trainer/base/__init__.py -------------------------------------------------------------------------------- /atorch/trainer/base/async_save/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/trainer/base/async_save/__init__.py
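Editorial note: init_dist above only prepares the env vars and device that atorch.init_distributed expects; it does not launch processes. A minimal sketch of driving it from torch.multiprocessing, mirroring the spawn pattern used in sparse_tensor_test.py; two visible GPUs and the nccl backend are assumptions:

import os

import torch.multiprocessing as mp

import atorch
from atorch.common.util_func import find_free_port
from atorch.tests.utils.test_util import init_dist


def _worker(rank, world_size):
    # sets RANK/WORLD_SIZE/... then calls atorch.init_distributed("nccl")
    init_dist(rank, world_size)
    assert atorch.world_size() == world_size


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(find_free_port())
    mp.spawn(_worker, args=(2,), nprocs=2, join=True)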
-------------------------------------------------------------------------------- /atorch/trainer/base/async_save/megatron_async_save/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/trainer/base/async_save/megatron_async_save/__init__.py -------------------------------------------------------------------------------- /atorch/trainer/base/atorch_module.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from torch import nn 4 | from transformers.modeling_utils import PreTrainedModel 5 | 6 | 7 | class AtorchIRModel(nn.Module): 8 | def __init__(self, model: Union[PreTrainedModel]): 9 | super().__init__()  # nn.Module must be initialized before submodules are assigned 10 | self.origin_model = model 11 | self.model = self._convert_to_IR(model) 12 | 13 | def _convert_to_IR(self, model): 14 | return model 15 | -------------------------------------------------------------------------------- /atorch/trainer/base/ckptloader.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from atorch.trainer.args import AtorchTrainingArgs 4 | 5 | 6 | class CkptLoader(ABC): 7 | @abstractmethod 8 | def load(self, resume_from_ckpt: str, model, train_args: AtorchTrainingArgs = None, **kwargs): 9 | pass 10 | -------------------------------------------------------------------------------- /atorch/trainer/base/ckptsaver.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from transformers import TrainerCallback 4 | 5 | from atorch.trainer.args import AtorchTrainingArgs 6 | 7 | 8 | class CkptSaver(ABC, TrainerCallback): 9 | @abstractmethod 10 | def save( 11 | self, iteration: int, output_dir: str, train_args: AtorchTrainingArgs, best_model_checkpoint=None, **kwargs 12 | ) -> str: 13 | pass 14 | 15 | @abstractmethod 16 | def get_interation_path(self, output_dir: str, iteration: int, **kwargs) -> str: 17 | pass 18 | -------------------------------------------------------------------------------- /atorch/trainer/base/dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | 4 | class AtorchDataloader(DataLoader): 5 | def __init__(self, dataset, sampler, **kwargs): 6 | self.dataset = dataset 7 | self.sampler = sampler 8 | 9 | def __iter__(self): 10 | pass 11 | 12 | def __len__(self): 13 | pass 14 | -------------------------------------------------------------------------------- /atorch/trainer/base/dataset.py: -------------------------------------------------------------------------------- 1 | class AtorchDataset: 2 | @classmethod 3 | def from_config(cls, distributed_type, *args, **kwargs) -> "AtorchDataset": 4 | pass 5 | 6 | def __len__(self): 7 | raise NotImplementedError() 8 | 9 | def __getitem__(self, idx): 10 | raise NotImplementedError() 11 | 12 | 13 | class OdpsDataset(AtorchDataset): 14 | def __init__(self): 15 | pass 16 | 17 | 18 | class PcacheDataset(AtorchDataset): 19 | def __init__(self): 20 | pass 21 | -------------------------------------------------------------------------------- /atorch/trainer/base/dist_checkpointing/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/trainer/base/dist_checkpointing/__init__.py -------------------------------------------------------------------------------- /atorch/trainer/base/dist_checkpointing/strategies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/trainer/base/dist_checkpointing/strategies/__init__.py -------------------------------------------------------------------------------- /atorch/trainer/base/dist_checkpointing/strategies/async_torch_save_strategy.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Tuple 3 | 4 | import torch 5 | from megatron.core.dist_checkpointing.strategies.base import SaveShardedStrategy 6 | from megatron.training.checkpointing import ensure_directory_exists 7 | 8 | from atorch.trainer.base.dist_checkpointing.strategies.async_utils import AsyncRequest 9 | 10 | 11 | class AsyncTorchSaveStrategy(SaveShardedStrategy): 12 | def __init__(self, backend: str, version: int, thread_count: int = 2): 13 | super().__init__(backend, version) 14 | self.thread_count = thread_count 15 | 16 | def async_save(self, save_fn_args) -> AsyncRequest: 17 | return self._get_save_and_finalize_callbacks(save_fn_args) 18 | 19 | def save(self, state_dict, file_name: Path): 20 | """Each async strategy can be trivially used as a sync strategy.""" 21 | async_request = self.async_save([(state_dict, file_name)]) 22 | async_request.execute_sync() 23 | 24 | def _get_save_and_finalize_callbacks(self, save_fn_args: List[Tuple]) -> AsyncRequest: 25 | def save_by_torch(save_fn_args): 26 | for save_args in save_fn_args: 27 | obj = save_args[0] 28 | file_name = save_args[1] 29 | if len(obj) > 0:  # nothing to persist for an empty state dict 30 | ensure_directory_exists(file_name) 31 | torch.save(obj, file_name) 32 | 33 | return AsyncRequest(save_by_torch, (save_fn_args,), []) 34 | 35 | def can_handle_sharded_objects(self): 36 | return False 37 | -------------------------------------------------------------------------------- /atorch/trainer/base/inferface.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict 3 | 4 | import torch 5 | 6 | 7 | class Savable(ABC): 8 | @abstractmethod 9 | def save_state(self, file_path, state_dict=None, **kwargs): 10 | pass 11 | 12 | @abstractmethod 13 | def load_state(self, file_path, **kwargs): 14 | pass 15 | 16 | 17 | class AtorchStateful(Savable): 18 | """ 19 | This class behaves as an instance of torch.distributed.checkpoint.stateful.Stateful: it implements the same 20 | methods as Stateful, so isinstance checks pass even though it does not explicitly inherit from Stateful.
21 | >>> a = SomeAtorchStateful()  # any concrete subclass; AtorchStateful itself is abstract 22 | >>> isinstance(a, Stateful) 23 | True 24 | """ 25 | 26 | @abstractmethod 27 | def state_dict(self) -> Dict[str, Any]: 28 | pass 29 | 30 | @abstractmethod 31 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 32 | pass 33 | 34 | @abstractmethod 35 | def get_save_filepath(self, output_dir, **kwargs): 36 | pass 37 | 38 | def save_state(self, file_path, **kwargs): 39 | torch.save(self.state_dict(), file_path) 40 | 41 | def load_state(self, file_path, **kwargs): 42 | self.load_state_dict(torch.load(file_path)) 43 | -------------------------------------------------------------------------------- /atorch/trainer/base/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.optim 2 | from torch.optim import Optimizer 3 | 4 | 5 | class AtorchOptimizer(Optimizer): 6 | def __init__(self, optimizer: Optimizer, scaler): 7 | self.optimizer = optimizer 8 | self.scaler = scaler 9 | 10 | @classmethod 11 | def from_config(cls, distributed_type, configs=None, *args, **kwargs): 12 | # Optimizer 13 | # Split weights in two groups, one with weight decay and the other not. 14 | if configs is None: 15 | model_named_parameters = kwargs.get("model_named_parameters") 16 | no_decay = ["bias", "layer_norm.weight"] 17 | optimizer_grouped_parameters = [ 18 | { 19 | "params": [p for n, p in model_named_parameters if not any(nd in n for nd in no_decay)], 20 | "weight_decay": kwargs.get("weight_decay", 0.0),  # configs is None in this branch, so read from kwargs 21 | }, 22 | { 23 | "params": [p for n, p in model_named_parameters if any(nd in n for nd in no_decay)], 24 | "weight_decay": 0.0, 25 | }, 26 | ] 27 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=kwargs.get("learning_rate")) 28 | else: 29 | optimizer = torch.optim.AdamW(configs) 30 | 31 | return optimizer 32 | 33 | def zero_grad(self, set_to_none=True): 34 | pass 35 | 36 | def step(self, closure=None) -> None: # type: ignore[override] 37 | pass 38 | 39 | def train(self): 40 | self.optimizer.train() 41 | 42 | def eval(self): 43 | self.optimizer.eval() 44 | -------------------------------------------------------------------------------- /atorch/trainer/base/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | from torch.optim import lr_scheduler 4 | 5 | from atorch.trainer.base.optimizer import AtorchOptimizer 6 | 7 | 8 | class AtorchScheduler(lr_scheduler.LRScheduler): # type: ignore[name-defined] 9 | def __init__(self, scheduler, optimizers: Union[AtorchOptimizer, List[AtorchOptimizer]]): 10 | self.scheduler = scheduler 11 | self.optimizers = optimizers if isinstance(optimizers, list) else [optimizers] 12 | 13 | @classmethod 14 | def from_config(cls, distributed_type, *args, **kwargs): 15 | pass 16 | 17 | def get_last_lr(self): 18 | return self.scheduler.get_last_lr() 19 | 20 | def get_lr(self): 21 | return self.scheduler.get_lr() 22 | 23 | def step(self): 24 | pass 25 | -------------------------------------------------------------------------------- /atorch/trainer/base/train_step.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class AtorchTrainStep(ABC): 5 | """ 6 | Abstract base class of train step to regulate the three main functions: get batch, loss func, and forward step. 7 | Users can customize their TrainStep by inheriting AtorchTrainStep and implementing the abstract functions.
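    A hypothetical concrete subclass, sketched for illustration only (the three `my_*` helpers are invented, not part of this module):

        class MyTrainStep(AtorchTrainStep):
            def get_batch_func(self, **kwargs):
                return my_get_batch  # callable that pulls one batch from the data iterator
            def get_loss_func(self, **kwargs):
                return my_loss_func  # callable reducing model outputs to a scalar loss
            def get_forward_step_func(self, **kwargs):
                return my_forward_step  # callable running one forward pass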
8 | """ 9 | 10 | def __init__(self): 11 | pass 12 | 13 | @abstractmethod 14 | def get_batch_func(self, **kwargs): 15 | pass 16 | 17 | @abstractmethod 18 | def get_loss_func(self, **kwargs): 19 | pass 20 | 21 | @abstractmethod 22 | def get_forward_step_func(self, **kwargs): 23 | pass 24 | -------------------------------------------------------------------------------- /atorch/trainer/debug_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/trainer/debug_utils/__init__.py -------------------------------------------------------------------------------- /atorch/trainer/debug_utils/debug_wrappers.py: -------------------------------------------------------------------------------- 1 | import wrapt 2 | 3 | from atorch.common.log_utils import default_logger as logger 4 | 5 | TORCH_SAVE_WRAPPED = False 6 | 7 | 8 | @wrapt.decorator 9 | def log_torch_save(wrapped, instance, args, kwargs): 10 | file_name = None 11 | if args is not None and len(args) > 1: 12 | file_name = args[1] 13 | 14 | if kwargs is not None: 15 | file_name = kwargs.get("f", file_name) 16 | 17 | file_name = file_name or "unknown" 18 | 19 | logger.info(f"Start to {wrapped.__name__} with file_name={file_name}") 20 | result = wrapped(*args, **kwargs) 21 | logger.info(f"Finish {wrapped.__name__} with file_name={file_name}, result={result}") 22 | return result 23 | 24 | 25 | def wrap_torch_save(): 26 | global TORCH_SAVE_WRAPPED 27 | if not TORCH_SAVE_WRAPPED: 28 | import torch 29 | 30 | torch.save = log_torch_save(torch.save) 31 | TORCH_SAVE_WRAPPED = True 32 | -------------------------------------------------------------------------------- /atorch/trainer/fsdp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/trainer/fsdp/__init__.py -------------------------------------------------------------------------------- /atorch/trainer/fsdp/fsdp_ckpt_loader.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from atorch.trainer.args import AtorchTrainingArgs 4 | from atorch.trainer.base.ckptloader import CkptLoader 5 | from atorch.trainer.fsdp.fsdp_ckpt_saver import ExtraState 6 | 7 | 8 | class FsdpCkptLoader(CkptLoader, ABC): 9 | @abstractmethod 10 | def load( # type: ignore[override] 11 | self, 12 | resume_from_ckpt, 13 | model, 14 | train_args: AtorchTrainingArgs = None, 15 | optimizer=None, 16 | extra_state: ExtraState = None, 17 | ckpt_step=None, 18 | **kwargs 19 | ) -> int: 20 | """ 21 | 22 | Args: 23 | resume_from_ckpt: checkpoint folder to load from, normally the parent of the ckpt iteration path. 24 | model: the FSDP model to load the state dict into; it should have the same structure as the saved model. 25 | train_args: atorch trainer args. 26 | optimizer: the FSDP optimizer to load the state dict into; it should have the same structure as the saved optimizer. 27 | extra_state: the customized state to load from the ckpt. 28 | ckpt_step: load a certain step from the ckpt root path. If None, the iteration will be read from the 29 | trace metadata file (e.g.
latest_checkpointed_iteration.txt) 30 | **kwargs: 31 | 32 | Returns: 33 | iteration step as int 34 | 35 | """ 36 | pass 37 | -------------------------------------------------------------------------------- /atorch/trainer/megatron/__init__.py: -------------------------------------------------------------------------------- 1 | from .megatron_dataloader import AtorchMegatronDataloader 2 | from .megatron_train_step import BertTrainStep, GPTTrainStep, MegatronTrainStep, T5TrainStep 3 | from .megatron_wrapper import AtorchMegatronEngine 4 | -------------------------------------------------------------------------------- /atorch/trainer/megatron/megatron_ckpt_loader.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from pathlib import Path 3 | from typing import Tuple 4 | 5 | import torch 6 | from megatron.training import get_args 7 | 8 | from atorch.trainer.args import AtorchTrainingArgs 9 | from atorch.trainer.base.ckptloader import CkptLoader 10 | 11 | 12 | class MegatronCkptLoader(CkptLoader): 13 | @abstractmethod 14 | def load( # type: ignore[override] 15 | self, 16 | resume_from_ckpt: Path = None, 17 | model=None, 18 | optimizer=None, 19 | scheduler=None, 20 | train_args: AtorchTrainingArgs = None, 21 | ) -> Tuple[int, int]: 22 | pass 23 | 24 | 25 | class MegatronOriginSyncLoader(MegatronCkptLoader): 26 | def load( # type: ignore[override] 27 | self, 28 | resume_from_ckpt: Path = None, 29 | model=None, 30 | optimizer=None, 31 | scheduler=None, 32 | train_args: AtorchTrainingArgs = None, 33 | ): 34 | assert model is not None, "Megatron load model should not be None" 35 | assert optimizer is not None, "Megatron load optimizer should not be None" 36 | assert scheduler is not None, "Megatron load scheduler should not be None" 37 | 38 | if resume_from_ckpt is not None: 39 | megatron_args = get_args() 40 | if isinstance(resume_from_ckpt, Path): 41 | megatron_args.load = str(resume_from_ckpt) 42 | else: # normally a str 43 | megatron_args.load = resume_from_ckpt 44 | 45 | torch.distributed.barrier() 46 | 47 | from megatron.training.checkpointing import load_checkpoint 48 | 49 | iteration, num_floating_point_operations_so_far = load_checkpoint(model, optimizer, scheduler) 50 | 51 | torch.distributed.barrier() 52 | 53 | return iteration, num_floating_point_operations_so_far 54 | -------------------------------------------------------------------------------- /atorch/trainer/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/trainer/models/__init__.py -------------------------------------------------------------------------------- /atorch/trainer/models/atorch_model.py: -------------------------------------------------------------------------------- 1 | class AtorchModel: 2 | def __init__(self): 3 | pass 4 | 5 | def load_model(self, _models): 6 | self._wrapped_model = _models 7 | 8 | def forward_step(self): 9 | pass 10 | -------------------------------------------------------------------------------- /atorch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .prof import AProfiler 2 | from .timer import ThroughputTimer 3 | -------------------------------------------------------------------------------- /atorch/utils/dev_utils.py:
-------------------------------------------------------------------------------- 1 | def raise_not_impl(func): 2 | import functools 3 | 4 | @functools.wraps(func) 5 | def wrapper(*args, **kwargs): 6 | class_name = args[0].__class__.__name__ 7 | method_name = func.__name__ 8 | raise NotImplementedError(f"{class_name} does not implement function {method_name}") 9 | 10 | return wrapper 11 | -------------------------------------------------------------------------------- /atorch/utils/dynamic_profiler/__init__.py: -------------------------------------------------------------------------------- 1 | from ._dynamic_profile import init 2 | 3 | __all__ = ["init"] 4 | -------------------------------------------------------------------------------- /atorch/utils/fa_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from functools import lru_cache 3 | 4 | import torch 5 | from torch.cuda.amp.autocast_mode import _cast, autocast 6 | 7 | 8 | # patch fn to handle autocast 9 | def _cast_fa_fn(fa_fn): 10 | @functools.wraps(fa_fn) 11 | def new_fa_fn(*args, **kwargs): 12 | if torch.is_autocast_enabled(): 13 | cur_dtype = torch.get_autocast_gpu_dtype() 14 | with autocast(enabled=False): 15 | return fa_fn(*_cast(args, cur_dtype), **_cast(kwargs, cur_dtype)) 16 | else: 17 | return fa_fn(*args, **kwargs) 18 | 19 | return new_fa_fn 20 | 21 | 22 | @lru_cache() 23 | def patch_fa_interface_to_autocast(interface): 24 | fn_names = [i for i in dir(interface) if i.startswith("flash_attn_") and i.endswith("_func")] 25 | for fn_name in fn_names: 26 | new_fa_fn = _cast_fa_fn(getattr(interface, fn_name)) 27 | setattr(interface, fn_name, new_fa_fn) 28 | -------------------------------------------------------------------------------- /atorch/utils/gc.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | 4 | class ManualGarbageCollection: 5 | def __init__(self, gc_freq=300, disable_auto_gc=False): 6 | assert gc_freq > 0, "gc_freq must be a positive integer" 7 | self.gc_freq = gc_freq 8 | if disable_auto_gc: 9 | self.disable_auto_gc() 10 | gc.collect() 11 | 12 | def enable_auto_gc(self): 13 | gc.enable() 14 | 15 | def disable_auto_gc(self): 16 | gc.disable() 17 | 18 | def run(self, step_count): 19 | if step_count > 1 and step_count % self.gc_freq == 0: 20 | gc.collect(1) 21 | -------------------------------------------------------------------------------- /atorch/utils/hooks.py: -------------------------------------------------------------------------------- 1 | from enum import auto 2 | from typing import Callable, Dict, List 3 | 4 | 5 | class ATorchHooks(object): 6 | COMPUTE_GPU_UTIL_HOOK = auto() 7 | REPORT_METRICS_HOOK = auto() 8 | ADDITIONAL_TENSORBOARD_HOOK = auto() 9 | 10 | # hooks stored as dict, key for hook_type, value for corresponding hook list. 
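# Illustrative usage (my_report_fn is a hypothetical callable, not part of this module):
#   ATorchHooks.register_hook(ATorchHooks.REPORT_METRICS_HOOK, my_report_fn)
#   ATorchHooks.call_hooks(ATorchHooks.REPORT_METRICS_HOOK, step, metrics)  # invokes every hook registered under that key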
11 | hooks: Dict[auto, List[Callable]] = {} 12 | 13 | @staticmethod 14 | def register_hook(hook_type, hook_func): 15 | if hook_type not in ATorchHooks.hooks: 16 | ATorchHooks.hooks[hook_type] = [hook_func] 17 | elif hook_func not in ATorchHooks.hooks[hook_type]: 18 | ATorchHooks.hooks[hook_type].append(hook_func) 19 | 20 | @staticmethod 21 | def remove_hook(hook_type, hook_func): 22 | if hook_type in ATorchHooks.hooks and hook_func in ATorchHooks.hooks[hook_type]: 23 | ATorchHooks.hooks[hook_type].remove(hook_func) 24 | 25 | @staticmethod 26 | def call_hooks(hook_type, *args, **kwargs): 27 | if hook_type in ATorchHooks.hooks: 28 | for func in ATorchHooks.hooks[hook_type]: 29 | func(*args, **kwargs) 30 | -------------------------------------------------------------------------------- /atorch/utils/inspector/__init__.py: -------------------------------------------------------------------------------- 1 | from .hooks import TensorInspector 2 | -------------------------------------------------------------------------------- /atorch/utils/path_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | def path_is_empty(dir_path: Path) -> bool: 5 | """ 6 | Check whether a Path is empty. 7 | Args: 8 | dir_path: a Path to check 9 | Returns: 10 | True if dir_path doesn't exist as a dir (dir_path may not be a dir, may not exist, or may be None), 11 | or if dir_path exists and has no file inside. 12 | False if dir_path exists as a dir and is not empty. 13 | """ 14 | if dir_path is None: 15 | return True 16 | 17 | if dir_path.exists() and dir_path.is_dir() and next(dir_path.iterdir(), None) is not None: 18 | return False 19 | else: 20 | return True 21 | -------------------------------------------------------------------------------- /atorch/utils/rank_reorder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/utils/rank_reorder/__init__.py -------------------------------------------------------------------------------- /atorch/utils/virtual_optimizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/atorch/utils/virtual_optimizer/__init__.py -------------------------------------------------------------------------------- /dev/docker/README.md: -------------------------------------------------------------------------------- 1 | # atorch image 2 | 3 | To build docker images: 4 | ```bash 5 | # Build pytorch base image under dev/docker/base folder. 6 | sudo docker build -f Dockerfile --net host -t "easydl/pytorch_gpu_base:2.0.1-cuda12.1-cudnn8-devel" . 7 | sudo docker build -f Dockerfile-pt21 --net host -t "easydl/pytorch_gpu_base:2.1.0-cuda12.1-cudnn8-devel" . 8 | 9 | # Build atorch image 10 | sudo docker build -f dev/docker/Dockerfile-ubuntu2004-pt210 --net host -t "reg.docker.alibaba-inc.com/atorch/atorch-open:pt210" . 11 | # To build base image, usually not needed, base Dockerfile is copied from pytorch repo for reference. 12 | make -f docker.Makefile 13 | ``` 14 | 15 | We use a Docker container for development.
The Dockerfile can be found at dlrover/atorch/dev/docker/ 16 | 17 | ```bash 18 | # Pull Docker image based on pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel 19 | sudo docker pull "registry.cn-hangzhou.aliyuncs.com/atorch/atorch-open-20240430:pt210" 20 | # Run Docker container and mount source directory to /v 21 | sudo docker run -it --rm --net=host --shm-size=1G -v ${PWD}:/v -w /v "registry.cn-hangzhou.aliyuncs.com/atorch/atorch-open-20240430:pt210" /bin/bash 22 | ``` 23 | 24 | For development, refer to the following steps: 25 | ```bash 26 | # build proto 27 | sh dev/scripts/build_proto.sh 28 | 29 | # run pre-commit 30 | sh dev/scripts/pre-commit.sh 31 | 32 | # run unittest 33 | PYTHONPATH=#ATORCH_ROOT# pytest atorch/tests 34 | 35 | # build atorch wheel 36 | sh dev/scripts/build.sh 37 | ``` 38 | -------------------------------------------------------------------------------- /dev/docker/base/.condarc: -------------------------------------------------------------------------------- 1 | channels: 2 | - defaults 3 | show_channel_urls: true 4 | default_channels: 5 | - http://mirrors.aliyun.com/anaconda/pkgs/main 6 | - http://mirrors.aliyun.com/anaconda/pkgs/r 7 | - http://mirrors.aliyun.com/anaconda/pkgs/msys2 8 | custom_channels: 9 | conda-forge: http://mirrors.aliyun.com/anaconda/cloud 10 | msys2: http://mirrors.aliyun.com/anaconda/cloud 11 | bioconda: http://mirrors.aliyun.com/anaconda/cloud 12 | menpo: http://mirrors.aliyun.com/anaconda/cloud 13 | pytorch: http://mirrors.aliyun.com/anaconda/cloud 14 | simpleitk: http://mirrors.aliyun.com/anaconda/cloud 15 | -------------------------------------------------------------------------------- /dev/docker/base/pip.conf: -------------------------------------------------------------------------------- 1 | [global] 2 | index-url = http://mirrors.aliyun.com/pypi/simple 3 | [install] 4 | trusted-host = mirrors.aliyun.com -------------------------------------------------------------------------------- /dev/docker/base/requirements.txt: -------------------------------------------------------------------------------- 1 | # Python dependencies required for development 2 | astunparse 3 | expecttest 4 | hypothesis 5 | numpy 6 | psutil 7 | pyyaml 8 | requests 9 | setuptools 10 | types-dataclasses 11 | typing-extensions 12 | sympy 13 | filelock 14 | networkx 15 | jinja2 16 | fsspec 17 | -------------------------------------------------------------------------------- /dev/docker/handle_driver_compat.sh: -------------------------------------------------------------------------------- 1 | MIN_DRIVER_VERSION=450.80.02 2 | CUDA_HOME=/usr/local/cuda 3 | 4 | version_lte(){ 5 | [ "$1" = "`echo -e "$1\n$2" | sort -V | head -n1`" ]; 6 | } 7 | version_lt(){ 8 | [ "$1" = "$2" ] && return 1 || version_lte $1 $2; 9 | } 10 | which nvidia-smi >> /dev/null 2>&1 && \ 11 | version_lt `nvidia-smi --query-gpu=driver_version --format=csv,noheader | awk 'NR==1{print}'` ${MIN_DRIVER_VERSION} && \ 12 | export LD_LIBRARY_PATH=${CUDA_HOME}/compat:$LD_LIBRARY_PATH 13 | 14 | -------------------------------------------------------------------------------- /dev/scripts/build.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | if [ -z "$1" ]; then 3 | version=0.1.0dev 4 | else 5 | version=$1 6 | fi 7 | 8 | # Check if version starts with "release_" 9 | if [[ $version == release_* ]]; then 10 | # version starts with "release", assign version after "release_" 11 | version=${version: 8} 12 | fi 13 | 14 | echo "Building ATorch version $version" 15 |
python dev/scripts/render_setup.py --version $version 16 | python setup.py bdist_wheel 17 | -------------------------------------------------------------------------------- /dev/scripts/build_image_atorch_dev.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | tag=$1 3 | dockerfile=$2 4 | 5 | if [ -z $tag ] 6 | then 7 | read -p "image tag is NULL, please input: " tag 8 | fi 9 | 10 | if [ -z $dockerfile ] 11 | then 12 | read -p "dockerfile path is empty, please input: " dockerfile 13 | fi 14 | 15 | sudo docker build -f $dockerfile --net host -t "easydl/atorch:$tag" . -------------------------------------------------------------------------------- /dev/scripts/build_proto.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | PROTOBUF_VERSION_SRC=$(pip show protobuf | grep Version | cut -d' ' -f2) 7 | GRPCIO_VERSION_SRC=$(pip show grpcio | grep Version | cut -d' ' -f2) 8 | GRPCIO_TOOLS_VERSION_SRC=$(pip show grpcio-tools | grep Version | cut -d' ' -f2) 9 | PROTOS_DIR="atorch/protos" 10 | 11 | CUR_PYTHON_VERSION=$(python3 --version | awk -F " " '{print $NF}'| awk -F. '{print $1 $2}') 12 | if [[ ${CUR_PYTHON_VERSION} == "38" ]];then 13 | PROTOBUF_VERSION_2="3.20.3" 14 | GRPCIO_VERSION_2="1.34.1" 15 | GRPCIO_TOOLS_VERSION_2="1.34.1" 16 | pip install protobuf==$PROTOBUF_VERSION_2 grpcio==$GRPCIO_VERSION_2 grpcio-tools==$GRPCIO_TOOLS_VERSION_2 17 | cp $PROTOS_DIR/*.proto $PROTOS_DIR/protobuf_3_20_3/ 18 | pushd . 19 | cd $PROTOS_DIR/protobuf_3_20_3 20 | python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. ./acceleration.proto 21 | python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. ./coworker.proto 22 | sed -i 's/import acceleration_pb2/from \. import acceleration_pb2/g' acceleration_pb2_grpc.py 23 | sed -i 's/import coworker_pb2/from \. import coworker_pb2/g' coworker_pb2_grpc.py 24 | rm *.proto 25 | popd 26 | fi 27 | 28 | PROTOBUF_VERSION_3="4.25.3" 29 | GRPCIO_VERSION_3="1.62.1" 30 | GRPCIO_TOOLS_VERSION_3="1.58.0" 31 | pip install protobuf==$PROTOBUF_VERSION_3 grpcio==$GRPCIO_VERSION_3 grpcio-tools==$GRPCIO_TOOLS_VERSION_3 32 | cp $PROTOS_DIR/*.proto $PROTOS_DIR/protobuf_4_25_3/ 33 | pushd . 34 | cd $PROTOS_DIR/protobuf_4_25_3 35 | python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. ./acceleration.proto 36 | python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. ./coworker.proto 37 | sed -i 's/import acceleration_pb2/from \. import acceleration_pb2/g' acceleration_pb2_grpc.py 38 | sed -i 's/import coworker_pb2/from \. 
import coworker_pb2/g' coworker_pb2_grpc.py 39 | rm *.proto 40 | popd 41 | 42 | 43 | if [[ ${CUR_PYTHON_VERSION} == "38" ]];then 44 | pip install protobuf==$PROTOBUF_VERSION_SRC grpcio==$GRPCIO_VERSION_SRC grpcio-tools==$GRPCIO_TOOLS_VERSION_SRC 45 | fi 46 | -------------------------------------------------------------------------------- /dev/scripts/import_atorch_after_build.py: -------------------------------------------------------------------------------- 1 | import atorch # type: ignore # noqa: F401 2 | from atorch.auto.accelerate import auto_accelerate # type: ignore # noqa: F401, F403 3 | from atorch.kernels import * # type: ignore # noqa: F401, F403 4 | from atorch.tensor_parallel.manual_tp import TPInfo # type: ignore # noqa: F401 5 | -------------------------------------------------------------------------------- /dev/scripts/pre-commit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | pip install pre-commit==2.21.0 4 | git config --global --add safe.directory '*' 5 | 6 | Config=.pre-commit-config.yaml 7 | 8 | py_files=$(find . -path "./atorch/protos" -prune -o -name "*.py" -print0 | tr '\0' ' ') 9 | pre-commit run -v --files ${py_files} -c ${Config} 10 | 11 | STATUS=$? 12 | 13 | if [ ${STATUS} -ne 0 ] 14 | then 15 | echo "============================== Hello Atorch =================================" 16 | echo "| |" 17 | echo "| Please check above error message. |" 18 | echo "| You can run sh dev/scripts/pre-commit.sh locally |" 19 | echo "| |" 20 | echo "============================== Hello Atorch =================================" 21 | exit ${STATUS} 22 | fi -------------------------------------------------------------------------------- /dev/scripts/render_setup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import, unicode_literals 3 | 4 | from argparse import ArgumentParser 5 | from string import Template 6 | 7 | if __name__ == "__main__": 8 | # Running in atorch root dir 9 | parser = ArgumentParser() 10 | parser.add_argument("--version", required=True) 11 | args = parser.parse_args() 12 | with open("setup.py.tpl", encoding="u8") as fin, open("setup.py", "w", encoding="u8") as fout: 13 | t = Template(fin.read()) 14 | fout.write(t.safe_substitute(version=args.version)) 15 | -------------------------------------------------------------------------------- /dev/scripts/test_whl_import.sh: -------------------------------------------------------------------------------- 1 | pip install --no-deps dist/atorch* 2 | python dev/scripts/import_atorch_after_build.py -------------------------------------------------------------------------------- /docs/README-AGD.md: -------------------------------------------------------------------------------- 1 |

# AGD Optimizer 2 | ### an Auto-switchable Optimizer using Stepwise Gradient Difference as Preconditioning Matrix 3 | 4 | We present PyTorch code for [AGD: an Auto-switchable Optimizer using Stepwise Gradient Difference as Preconditioning Matrix](https://openreview.net/forum?id=A954O4tDmU&noteId=wLS9DFtY0I), NeurIPS 2023. 5 | 6 | AGD employs the gradient difference between the current and previous steps to form the preconditioning matrix, which can dynamically transition between the adaptive and stochastic forms through an automated switching mechanism. Thanks to these dual approaches, AGD attains swifter convergence and superior generalization performance compared to state-of-the-art optimizers. 7 | 8 | 9 | [figure: toy example on the Beale function, see docs/img/agd_beale.gif] 10 | 11 | 12 | ## Usage 13 | 14 | AGD can be a drop-in replacement for AdamW. 15 | 16 | ```python 17 | from atorch.optimizers.agd import AGD 18 | ``` 19 | 20 | ## Hyperparameters 21 | 22 | - `lr`: Empirically set to 1/10 of AdamW's value. 23 | - `delta`: Please refer to the settings in the paper. For Transformer-like models, you can typically keep the default value at 1e-14. 24 | - `clip`: Generally, there's no need to set it, but if you encounter training instability, you can try clip=5. 25 | - Others: Set them based on general empirical guidelines. 26 | 27 | ## AGD's performance on nanoGPT 28 | 29 | Given the popularity of large-scale models, we also tested the effectiveness of AGD on nanoGPT. As expected, AGD converges very quickly, providing up to a 1.5x acceleration compared to AdamW. This can significantly save training time and reduce training costs. 30 | 31 | 32 | [figure: AGD vs. AdamW loss curves on nanoGPT, see docs/img/agd_nanogpt.png] 33 |
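Below is a slightly fuller usage sketch than the one-line import above. It is illustrative only: the toy model, learning rate, and weight decay values are invented, and it assumes `AGD` accepts AdamW-style constructor arguments; check the signature in `atorch/optimizers/agd.py` before copying.

```python
# Hedged sketch: construct AGD exactly where AdamW would otherwise be constructed.
import torch
import torch.nn as nn

from atorch.optimizers.agd import AGD

model = nn.Linear(16, 4)  # toy model, for illustration only
# lr chosen at roughly 1/10 of a typical AdamW value, per the guideline above
optimizer = AGD(model.parameters(), lr=1e-4, weight_decay=0.01)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```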

34 | -------------------------------------------------------------------------------- /docs/developer_guide.md: -------------------------------------------------------------------------------- 1 | # Introduction to developing DlRover/atorch 2 | 3 | This document describes how to contribute to atorch in DLRover; see DlRover/docs/developer_guide.md first. 4 | 5 | ## Submit a PR 6 | 7 | - Fork the DLRover repo to your own namespace. 8 | - `git clone git@github.com:intelligent-machine-learning/dlrover.git` 9 | - `cd dlrover` 10 | - `git remote rename origin upstream` 11 | - `git remote add origin ${YOUR OWNER REPO}` 12 | - `git checkout -b {DEV-BRANCH}` 13 | - `git push -u origin {DEV-BRANCH}` 14 | 15 | Then create a PR from your own GitHub repo. If you have modified DLRover/atorch code in the repo, 16 | you need to run `pre-commit` for code-style checks and run the unit tests 17 | by the following steps. 18 | 19 | - ```docker run -v `pwd`:/dlrover -it easydl/atorch:aci /bin/bash``` 20 | - `cd /dlrover/atorch` 21 | - `bash dev/scripts/pre-commit.sh` 22 | - `exit` 23 | - ```docker run -v `pwd`:/dlrover -it easydl/atorch:iml_pt210 /bin/bash``` 24 | - `pip install pytest dlrover[torch] pandas GPy` 25 | - `pip install accelerate datasets==2.14.6 peft==0.4.0 scikit-learn pymoo==0.5.0` 26 | - `echo -e 'import math\ninf = math.inf\nnan = math.nan\nstring_classes = (str, bytes)' > /opt/conda/lib/python3.8/site-packages/torch/_six.py` 27 | - `PYTHONPATH=. pytest atorch/tests` 28 | - `cd ..` 29 | - `git config --global --add safe.directory /github/workspace` 30 | - `git clean -xdf` 31 | 32 | Otherwise, follow the testing steps in DlRover/docs/developer_guide.md. -------------------------------------------------------------------------------- /docs/feature_required_packages.md: -------------------------------------------------------------------------------- 1 | # Additional Requirements for ATorch Features 2 | 3 | Some ATorch features require more Python packages than those listed in ATorch's requirements. Users need to install the corresponding packages when using these features.
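As a concrete (but hypothetical) illustration of what this means for callers, here is a guarded import of the first feature listed below; the try/except wrapper is not code from the repository, and only the import path and package name come from the list that follows:

```python
# Hypothetical guard around one optional-dependency feature.
try:
    from atorch.modules.moe.grouped_gemm_moe import Grouped_GEMM_MoE
except ImportError as err:  # raised when the extra package `grouped_gemm` is missing
    raise ImportError("Grouped_GEMM_MoE requires the extra package `grouped_gemm`") from err
```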
4 | 5 | ## atorch.modules.moe.grouped_gemm_moe.Grouped_GEMM_MoE 6 | - grouped_gemm 7 | - megablocks (if implementation_type="MegaBlocks") 8 | 9 | ## fp8 10 | - transformer_engine 11 | 12 | ## atorch.rl 13 | - deepspeed 14 | 15 | ## ATorch megatron trainer 16 | - megatron 17 | 18 | ## auto_accelerate fully automatic mode 19 | - pymoo==0.5.0 20 | - GPy 21 | 22 | ## apex fused kernels 23 | - apex 24 | -------------------------------------------------------------------------------- /docs/img/agd_beale.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/docs/img/agd_beale.gif -------------------------------------------------------------------------------- /docs/img/agd_nanogpt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/docs/img/agd_nanogpt.png -------------------------------------------------------------------------------- /docs/img/atorch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/docs/img/atorch.png -------------------------------------------------------------------------------- /docs/img/atorch_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/docs/img/atorch_fig.png -------------------------------------------------------------------------------- /docs/img/edit_illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/docs/img/edit_illustration.png -------------------------------------------------------------------------------- /docs/img/wsam_traj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/docs/img/wsam_traj.png -------------------------------------------------------------------------------- /examples/auto_accelerate/train_gpt2_entry.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_GPUS_PER_NODE=$(nvidia-smi -L | wc -l) 3 | WORLD_SIZE=${WORLD_SIZE:-1} 4 | NUM_GPUS=$((NUM_GPUS_PER_NODE * WORLD_SIZE)) 5 | 6 | python -m atorch.distributed.run --nnodes="$WORLD_SIZE" \ 7 | --nproc_per_node="$NUM_GPUS_PER_NODE" \ 8 | train.py --model_type gpt2 \ 9 | --distributed \ 10 | --hidden_size 64 \ 11 | --head_num 4 \ 12 | --layer_num 4 \ 13 | --seq_length 32 \ 14 | --load_strategy \ 15 | --use_fsdp \ 16 | --use_amp \ 17 | --use_module_replace \ 18 | 2>&1 | tee log_gpt2_"${WORLD_SIZE}"n"${NUM_GPUS}"g.txt 19 | -------------------------------------------------------------------------------- /examples/auto_accelerate/train_llama_entry.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_GPUS_PER_NODE=$(nvidia-smi -L | wc -l) 3 | WORLD_SIZE=${WORLD_SIZE:-1} 4 | NUM_GPUS=$((NUM_GPUS_PER_NODE * WORLD_SIZE)) 5 | 6 | python -m atorch.distributed.run --nnodes="$WORLD_SIZE" \ 7 | 
--nproc_per_node="$NUM_GPUS_PER_NODE" \ 8 | train.py --model_type llama \ 9 | --distributed \ 10 | --hidden_size 64 \ 11 | --head_num 4 \ 12 | --layer_num 4 \ 13 | --seq_length 32 \ 14 | --load_strategy \ 15 | --use_fsdp \ 16 | --use_amp \ 17 | --use_module_replace \ 18 | --use_checkpointing \ 19 | --user_created_dataloader \ 20 | 2>&1 | tee log_llama_"${WORLD_SIZE}"n"${NUM_GPUS}"g.txt -------------------------------------------------------------------------------- /examples/auto_accelerate/train_toy_distributed_entry.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_GPUS_PER_NODE=$(nvidia-smi -L | wc -l) 3 | WORLD_SIZE=${WORLD_SIZE:-1} 4 | NUM_GPUS=$((NUM_GPUS_PER_NODE * WORLD_SIZE)) 5 | 6 | python -m atorch.distributed.run --nnodes="$WORLD_SIZE" \ 7 | --nproc_per_node="$NUM_GPUS_PER_NODE" \ 8 | train.py --model_type toy \ 9 | --distributed \ 10 | 2>&1 | tee log_toy_distributed_"${WORLD_SIZE}"n"${NUM_GPUS}"g.txt 11 | -------------------------------------------------------------------------------- /examples/auto_accelerate/train_toy_fully_automatic_entry.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py --model_type toy \ 3 | 2>&1 | tee log_toy_fully_automatic.txt -------------------------------------------------------------------------------- /examples/llama2/bayes_opt_sg_llama2_entry.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ./dataset_model.sh 3 | pip install GPy 4 | pip install pymoo==0.5.0 5 | 6 | NUM_GPUS_PER_NODE=$(nvidia-smi -L | wc -l) 7 | WORLD_SIZE=${WORLD_SIZE:-1} 8 | NUM_GPUS=$((NUM_GPUS_PER_NODE * WORLD_SIZE)) 9 | PER_DEVICE_TRAIN_BATCH_SIZE=4 10 | TOTAL_TRAIN_BATCH_SIZE=$((NUM_GPUS_PER_NODE * WORLD_SIZE * PER_DEVICE_TRAIN_BATCH_SIZE)) 11 | export BO_SG_MAX_IETR=12 12 | export RANDOM_SAMPLE=4 13 | 14 | 15 | python -m atorch.distributed.run --nnodes="$WORLD_SIZE" \ 16 | --nproc_per_node="$NUM_GPUS_PER_NODE" \ 17 | bayes_opt_sg_llama2.py \ 18 | --dataset_path $DATASET_PATH \ 19 | --config_name $MODEL_NAME_OR_PATH \ 20 | --tokenizer_name $MODEL_NAME_OR_PATH \ 21 | --total_train_batch_size $TOTAL_TRAIN_BATCH_SIZE \ 22 | --block_size 2048 \ 23 | --seed 42 \ 24 | --preprocessing_num_workers 12 \ 25 | --ignore_mismatched_sizes \ 26 | 2>&1 | tee log_llama2_"${WORLD_SIZE}"n"${NUM_GPUS}"g.txt 27 | -------------------------------------------------------------------------------- /examples/llama2/dataset_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HOME=$(echo ~) 4 | 5 | # Dataset path; downloaded by `example_utils.py` if it does not exist 6 | DATASET_PATH=${DATASET_PATH:-$HOME/.cache/wikitext-2-raw-v1} 7 | 8 | # Llama model path; downloaded and converted below if it does not exist 9 | MODEL_SIZE=${MODEL_SIZE-7B} 10 | MODEL_NAME_OR_PATH=${MODEL_NAME_OR_PATH:-$HOME/.cache/Llama-2-`echo $MODEL_SIZE|tr '[:upper:]' '[:lower:]'`-hf} 11 | if ! [[ -d $MODEL_NAME_OR_PATH && \ 12 | -f ${MODEL_NAME_OR_PATH%/}/config.json && \ 13 | -f ${MODEL_NAME_OR_PATH%/}/tokenizer_config.json && \ 14 | -f ${MODEL_NAME_OR_PATH%/}/tokenizer.json && \ 15 | -f ${MODEL_NAME_OR_PATH%/}/tokenizer.model ]]; then 16 | echo "$MODEL_NAME_OR_PATH not cached."
17 | mkdir -p $HOME/.cache/ 18 | pushd $HOME/.cache/ 19 | git clone https://github.com/shawwn/llama-dl.git 20 | pushd llama-dl 21 | sed 's/MODEL_SIZE="7B,13B,30B,65B"/MODEL_SIZE="'$MODEL_SIZE'"/g' llama.sh > llama$MODEL_SIZE.sh 22 | bash llama$MODEL_SIZE.sh 23 | pip install transformers sentencepiece 24 | python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir=. --model_size=$MODEL_SIZE --output_dir=$MODEL_NAME_OR_PATH 25 | popd 26 | popd 27 | fi -------------------------------------------------------------------------------- /examples/llama2/ds_3d_llama2_entry.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | 4 | source ./dataset_model.sh 5 | 6 | WORLD_SIZE=${WORLD_SIZE:-1} 7 | NUM_GPUS_PER_NODE=$(nvidia-smi -L | wc -l) 8 | 9 | PIPELINE_PARALLEL_SIZE=${PIPELINE_PARALLEL_SIZE:-2} 10 | MODEL_PARALLEL_SIZE=${MODEL_PARALLEL_SIZE:-2} 11 | BLOCK_SIZE=${BLOCK_SIZE:-4096} 12 | 13 | # ds config 14 | script_path=$(realpath $BASH_SOURCE) 15 | script_dir=$(dirname $script_path) 16 | DS_CONFIG="$script_dir/ds_config.json" 17 | MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-1} 18 | ACCU_STEPS=${ACCU_STEPS:-8} 19 | cat < $DS_CONFIG 20 | { 21 | "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE, 22 | "gradient_accumulation_steps": $ACCU_STEPS, 23 | "steps_per_print": 50, 24 | "gradient_clipping": 1.0, 25 | "zero_optimization": { 26 | "stage": 1 27 | }, 28 | "zero_allow_untested_optimizer": true, 29 | "fp16": { 30 | "enabled": true, 31 | "loss_scale": 0, 32 | "loss_scale_window": 1000, 33 | "initial_scale_power": 16, 34 | "hysteresis": 2, 35 | "min_loss_scale": 1 36 | }, 37 | "activation_checkpointing": { 38 | "partition_activations": false, 39 | "contiguous_memory_optimization": false 40 | }, 41 | "wall_clock_breakdown": false, 42 | "pipeline": { 43 | "activation_checkpoint_interval": 1 44 | } 45 | } 46 | EOT 47 | 48 | 49 | python -u -m atorch.distributed.run \ 50 | --nnodes=$WORLD_SIZE --nproc_per_node=$NUM_GPUS_PER_NODE ds_3d_llama2.py \ 51 | --pipeline_parallel_size $PIPELINE_PARALLEL_SIZE \ 52 | --model_parallel_size $MODEL_PARALLEL_SIZE \ 53 | --block_size $BLOCK_SIZE \ 54 | --ds_config $DS_CONFIG \ 55 | --model_name_or_path $MODEL_NAME_OR_PATH \ 56 | --dataset_path $DATASET_PATH 57 | -------------------------------------------------------------------------------- /examples/llama2/fsdp_llama2_entry.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | source ./dataset_model.sh 6 | 7 | WORLD_SIZE=${WORLD_SIZE:-1} 8 | NUM_GPUS_PER_NODE=$(nvidia-smi -L | wc -l) 9 | 10 | PER_DEVICE_TRAIN_BATCH_SIZE=${PER_DEVICE_TRAIN_BATCH_SIZE:-4} 11 | BLOCK_SIZE=${BLOCK_SIZE:-4096} 12 | 13 | if [ -z "$USE_LORA" ]; then 14 | LORA_OPT="" 15 | else 16 | LORA_OPT=" 17 | --peft_type lora \ 18 | --lora_r 16 \ 19 | --lora_alpha 16 \ 20 | --lora_target_modules q_proj v_proj k_proj o_proj \ 21 | --lora_dropout 0.05 \ 22 | " 23 | fi 24 | 25 | if [ -z "$USE_FP8" ]; then 26 | FP8_OPT="" 27 | else 28 | FP8_OPT="--fp8" 29 | fi 30 | 31 | python -m atorch.distributed.run \ 32 | --nnodes="$WORLD_SIZE" \ 33 | --nproc_per_node="$NUM_GPUS_PER_NODE" \ 34 | fsdp_llama2.py \ 35 | --block_size $BLOCK_SIZE \ 36 | --model_name_or_path $MODEL_NAME_OR_PATH \ 37 | --dataset_path $DATASET_PATH \ 38 | --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ 39 | --precision bf16_amp \ 40 | --gradient_checkpointing \ 41 | $LORA_OPT $FP8_OPT 
-------------------------------------------------------------------------------- /examples/llama2/llama2_dummy_data_13b.sh: -------------------------------------------------------------------------------- 1 | NUM_GPUS_PER_NODE=$(nvidia-smi -L | wc -l) 2 | 3 | STRATEGY_LIST="--use_fsdp --use_amp --use_module_replace --use_checkpointing" 4 | 5 | HIDDEN_SIZE=5120 6 | INTERMEDIATE_SIZE=13824 7 | HEAD_NUM=40 8 | KEY_VALUE_HEAD_NUM=40 9 | LAYER_NUM=40 10 | SEQ_LENGTH=4096 11 | 12 | BATCHSIZE_PER_GPU=4 13 | TRAIN_STEP=30 14 | 15 | EXTRA_PARAM="--use_meta_init" 16 | 17 | 18 | # Loop through all the arguments 19 | while [[ $# -gt 0 ]]; do 20 | case "$1" in 21 | --batchsize_per_gpu) 22 | BATCHSIZE_PER_GPU="$2" 23 | shift # Removes the current value of $1, making $2 become $1, $3 become $2, and so forth 24 | ;; 25 | --train_step) 26 | TRAIN_STEP="$2" 27 | shift # Removes the current value of $1, making $2 become $1, $3 become $2, and so forth 28 | ;; 29 | --max_checkpoint_module_num) 30 | EXTRA_PARAM="$EXTRA_PARAM --max_checkpoint_module_num $2" 31 | shift # Removes the current value of $1, making $2 become $1, $3 become $2, and so forth 32 | ;; 33 | --use_fp8) 34 | EXTRA_PARAM="$EXTRA_PARAM --use_fp8" 35 | ;; 36 | esac 37 | shift # Move to the next argument in the list 38 | done 39 | 40 | 41 | echo HIDDEN_SIZE=$HIDDEN_SIZE 42 | echo INTERMEDIATE_SIZE=$INTERMEDIATE_SIZE 43 | echo HEAD_NUM=$HEAD_NUM 44 | echo KEY_VALUE_HEAD_NUM=$KEY_VALUE_HEAD_NUM 45 | echo LAYER_NUM=$LAYER_NUM 46 | echo SEQ_LENGTH=$SEQ_LENGTH 47 | echo BATCHSIZE_PER_GPU = $BATCHSIZE_PER_GPU 48 | echo TRAIN_STEP = $TRAIN_STEP 49 | echo EXTRA_PARAM is $EXTRA_PARAM 50 | echo STRATEGY_LIST is $STRATEGY_LIST 51 | 52 | 53 | python -m atorch.distributed.run \ 54 | --nproc_per_node="$NUM_GPUS_PER_NODE" \ 55 | train_llama2_dummy_data.py $STRATEGY_LIST \ 56 | --hidden_size $HIDDEN_SIZE \ 57 | --intermediate_size $INTERMEDIATE_SIZE \ 58 | --head_num $HEAD_NUM \ 59 | --layer_num $LAYER_NUM \ 60 | --seq_length $SEQ_LENGTH \ 61 | --key_value_head_num $KEY_VALUE_HEAD_NUM \ 62 | --max_train_step $TRAIN_STEP \ 63 | --batchsize_per_gpu $BATCHSIZE_PER_GPU $EXTRA_PARAM 64 | -------------------------------------------------------------------------------- /examples/llama2/llama2_dummy_data_70b.sh: -------------------------------------------------------------------------------- 1 | NUM_GPUS_PER_NODE=$(nvidia-smi -L | wc -l) 2 | 3 | STRATEGY_LIST="--use_fsdp --use_amp --use_module_replace --use_checkpointing" 4 | 5 | HIDDEN_SIZE=8192 6 | INTERMEDIATE_SIZE=28672 7 | HEAD_NUM=64 8 | KEY_VALUE_HEAD_NUM=8 9 | LAYER_NUM=80 10 | SEQ_LENGTH=4096 11 | 12 | BATCHSIZE_PER_GPU=4 13 | TRAIN_STEP=30 14 | 15 | EXTRA_PARAM="--use_meta_init" 16 | 17 | 18 | # Loop through all the arguments 19 | while [[ $# -gt 0 ]]; do 20 | case "$1" in 21 | --batchsize_per_gpu) 22 | BATCHSIZE_PER_GPU="$2" 23 | shift # Removes the current value of $1, making $2 become $1, $3 become $2, and so forth 24 | ;; 25 | --train_step) 26 | TRAIN_STEP="$2" 27 | shift # Removes the current value of $1, making $2 become $1, $3 become $2, and so forth 28 | ;; 29 | --max_checkpoint_module_num) 30 | EXTRA_PARAM="$EXTRA_PARAM --max_checkpoint_module_num $2" 31 | shift # Removes the current value of $1, making $2 become $1, $3 become $2, and so forth 32 | ;; 33 | --use_fp8) 34 | EXTRA_PARAM="$EXTRA_PARAM --use_fp8" 35 | ;; 36 | esac 37 | shift # Move to the next argument in the list 38 | done 39 | 40 | 41 | echo HIDDEN_SIZE=$HIDDEN_SIZE 42 | echo INTERMEDIATE_SIZE=$INTERMEDIATE_SIZE 43 | echo 
HEAD_NUM=$HEAD_NUM 44 | echo KEY_VALUE_HEAD_NUM=$KEY_VALUE_HEAD_NUM 45 | echo LAYER_NUM=$LAYER_NUM 46 | echo SEQ_LENGTH=$SEQ_LENGTH 47 | echo BATCHSIZE_PER_GPU = $BATCHSIZE_PER_GPU 48 | echo TRAIN_STEP = $TRAIN_STEP 49 | echo EXTRA_PARAM is $EXTRA_PARAM 50 | echo STRATEGY_LIST is $STRATEGY_LIST 51 | 52 | 53 | python -m atorch.distributed.run \ 54 | --nproc_per_node="$NUM_GPUS_PER_NODE" \ 55 | train_llama2_dummy_data.py $STRATEGY_LIST \ 56 | --hidden_size $HIDDEN_SIZE \ 57 | --intermediate_size $INTERMEDIATE_SIZE \ 58 | --head_num $HEAD_NUM \ 59 | --layer_num $LAYER_NUM \ 60 | --seq_length $SEQ_LENGTH \ 61 | --key_value_head_num $KEY_VALUE_HEAD_NUM \ 62 | --max_train_step $TRAIN_STEP \ 63 | --batchsize_per_gpu $BATCHSIZE_PER_GPU $EXTRA_PARAM 64 | -------------------------------------------------------------------------------- /examples/llama2/llama2_dummy_data_7b.sh: -------------------------------------------------------------------------------- 1 | NUM_GPUS_PER_NODE=$(nvidia-smi -L | wc -l) 2 | 3 | STRATEGY_LIST="--use_fsdp --use_amp --use_module_replace --use_checkpointing" 4 | 5 | HIDDEN_SIZE=4096 6 | INTERMEDIATE_SIZE=11008 7 | HEAD_NUM=32 8 | KEY_VALUE_HEAD_NUM=32 9 | LAYER_NUM=32 10 | SEQ_LENGTH=2048 11 | 12 | BATCHSIZE_PER_GPU=8 13 | TRAIN_STEP=30 14 | 15 | EXTRA_PARAM= 16 | 17 | 18 | # Loop through all the arguments 19 | while [[ $# -gt 0 ]]; do 20 | case "$1" in 21 | --batchsize_per_gpu) 22 | BATCHSIZE_PER_GPU="$2" 23 | shift # Removes the current value of $1, making $2 become $1, $3 become $2, and so forth 24 | ;; 25 | --train_step) 26 | TRAIN_STEP="$2" 27 | shift # Removes the current value of $1, making $2 become $1, $3 become $2, and so forth 28 | ;; 29 | --max_checkpoint_module_num) 30 | EXTRA_PARAM="$EXTRA_PARAM --max_checkpoint_module_num $2" 31 | shift # Removes the current value of $1, making $2 become $1, $3 become $2, and so forth 32 | ;; 33 | --use_fp8) 34 | EXTRA_PARAM="$EXTRA_PARAM --use_fp8" 35 | ;; 36 | --use_meta_init) 37 | EXTRA_PARAM="$EXTRA_PARAM --use_meta_init" 38 | ;; 39 | esac 40 | shift # Move to the next argument in the list 41 | done 42 | 43 | 44 | echo HIDDEN_SIZE=$HIDDEN_SIZE 45 | echo INTERMEDIATE_SIZE=$INTERMEDIATE_SIZE 46 | echo HEAD_NUM=$HEAD_NUM 47 | echo KEY_VALUE_HEAD_NUM=$KEY_VALUE_HEAD_NUM 48 | echo LAYER_NUM=$LAYER_NUM 49 | echo SEQ_LENGTH=$SEQ_LENGTH 50 | echo BATCHSIZE_PER_GPU = $BATCHSIZE_PER_GPU 51 | echo TRAIN_STEP = $TRAIN_STEP 52 | echo EXTRA_PARAM is $EXTRA_PARAM 53 | echo STRATEGY_LIST is $STRATEGY_LIST 54 | 55 | 56 | python -m atorch.distributed.run \ 57 | --nproc_per_node="$NUM_GPUS_PER_NODE" \ 58 | train_llama2_dummy_data.py $STRATEGY_LIST \ 59 | --hidden_size $HIDDEN_SIZE \ 60 | --intermediate_size $INTERMEDIATE_SIZE \ 61 | --head_num $HEAD_NUM \ 62 | --layer_num $LAYER_NUM \ 63 | --seq_length $SEQ_LENGTH \ 64 | --key_value_head_num $KEY_VALUE_HEAD_NUM \ 65 | --max_train_step $TRAIN_STEP \ 66 | --batchsize_per_gpu $BATCHSIZE_PER_GPU $EXTRA_PARAM 67 | -------------------------------------------------------------------------------- /examples/llama2/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets>=2.14.6 2 | peft==0.4.0 3 | modelscope 4 | atorch>=0.1.7 5 | -------------------------------------------------------------------------------- /examples/llama2_7b_ATorchTrainer/README.md: -------------------------------------------------------------------------------- 1 | # Llama2 7B Finetune By ATorchTrainer 2 | 3 | This document presents 2 examples of using ATorchTrainer api 
to finetune the HuggingFace Llama-2-7b-hf model, mainly via FSDP or FSDP combined with LoRA. 4 | 5 | - Note: 6 | - The Llama2 model and the alpaca dataset are used in the examples. The training script will automatically download them for you. Note that downloading may take quite some time. 7 | 8 | 9 | ## ATorchTrainer FSDP 10 | 11 | Fully Sharded Data Parallel (FSDP) is the default training config in ATorchTrainer. This is implemented by calling the auto_accelerate API with the load_strategy argument, where load_strategy specifies the combination of training optimization methods. 12 | 13 | ### Scripts 14 | 15 | - training file [llama2_clm_atorch_trainer.py](llama2_clm_atorch_trainer.py) 16 | 17 | - launch script [llama2_7b_trainer_entry.sh](llama2_7b_trainer_entry.sh) 18 | 19 | ```bash 20 | cd dlrover/atorch/examples/llama2_7b_ATorchTrainer 21 | pip install -r requirements.txt 22 | 23 | WORLD_SIZE=8 bash llama2_7b_trainer_entry.sh output_dir 24 | ``` 25 | 26 | 27 | ## ATorchTrainer FSDP with LoRA 28 | 29 | LoRA is compatible with ATorchTrainer FSDP training; load the PEFT LoRA model first. 30 | 31 | ### Scripts 32 | 33 | - training file [llama2_clm_atorch_trainer.py](llama2_clm_atorch_trainer.py) 34 | 35 | - launch script [llama2_7b_trainer_lora_entry.sh](llama2_7b_trainer_lora_entry.sh) 36 | 37 | ```bash 38 | cd dlrover/atorch/examples/llama2_7b_ATorchTrainer 39 | pip install -r requirements.txt 40 | 41 | WORLD_SIZE=8 bash llama2_7b_trainer_lora_entry.sh output_dir 42 | ``` -------------------------------------------------------------------------------- /examples/llama2_7b_ATorchTrainer/deepspeed_configs/ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": true 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupDecayLR", 16 | "params": { 17 | "total_num_steps": "auto", 18 | "warmup_min_lr": "auto", 19 | "warmup_max_lr": "auto", 20 | "warmup_num_steps": "auto", 21 | "warmup_type": "linear" 22 | } 23 | }, 24 | "zero_optimization": { 25 | "stage": 2, 26 | "allgather_partitions": true, 27 | "allgather_bucket_size": 5e8, 28 | "overlap_comm": false, 29 | "reduce_scatter": true, 30 | "reduce_bucket_size": 5e8, 31 | "contiguous_gradients": true 32 | }, 33 | "steps_per_print": 50, 34 | "gradient_accumulation_steps": "auto", 35 | "gradient_clipping": "auto", 36 | "train_batch_size": "auto", 37 | "train_micro_batch_size_per_gpu": "auto" 38 | } 39 | -------------------------------------------------------------------------------- /examples/llama2_7b_ATorchTrainer/prepare_dataset_and_weight.sh: -------------------------------------------------------------------------------- 1 | cd dataset_and_weight/ 2 | git clone https://github.com/gururise/AlpacaDataCleaned.git 3 | git clone https://github.com/huggingface/evaluate.git 4 | 5 | MODEL_SIZE="7B" 6 | MODEL_NAME_OR_PATH=${MODEL_NAME_OR_PATH:-$HOME/.cache/Llama-2-`echo $MODEL_SIZE|tr '[:upper:]' '[:lower:]'`-hf} 7 | if ! [[ -d $MODEL_NAME_OR_PATH && \ 8 | -f ${MODEL_NAME_OR_PATH%/}/config.json && \ 9 | -f ${MODEL_NAME_OR_PATH%/}/tokenizer_config.json && \ 10 | -f ${MODEL_NAME_OR_PATH%/}/tokenizer.json && \ 11 | -f ${MODEL_NAME_OR_PATH%/}/tokenizer.model ]]; then 12 | echo "$MODEL_NAME_OR_PATH not cached."
13 | mkdir -p $HOME/.cache/ 14 | pushd $HOME/.cache/ 15 | git clone https://github.com/shawwn/llama-dl.git 16 | pushd llama-dl 17 | sed 's/MODEL_SIZE="7B,13B,30B,65B"/MODEL_SIZE="'$MODEL_SIZE'"/g' llama.sh > llama$MODEL_SIZE.sh 18 | bash llama$MODEL_SIZE.sh 19 | pip install transformers sentencepiece 20 | python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir=. --model_size=$MODEL_SIZE --output_dir=$MODEL_NAME_OR_PATH 21 | popd 22 | popd 23 | fi 24 | cd .. -------------------------------------------------------------------------------- /examples/llama2_7b_ATorchTrainer/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.31.0 2 | peft==0.4.0 3 | datasets >= 1.8.0 4 | accelerate==0.21.0 5 | evaluate 6 | scikit-learn 7 | matplotlib 8 | tensorboard 9 | atorch>=1.1.0 -------------------------------------------------------------------------------- /examples/local_sgd/auto_accelerate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/examples/local_sgd/auto_accelerate/__init__.py -------------------------------------------------------------------------------- /examples/local_sgd/auto_accelerate/run_local_sgd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m atorch.distributed.run --nproc_per_node 8 train.py \ 4 | --model_type llama \ 5 | --datasize 500 \ 6 | --distributed \ 7 | --hidden_size 64 \ 8 | --head_num 4 \ 9 | --layer_num 4 \ 10 | --seq_length 32 \ 11 | --load_strategy \ 12 | --use_fsdp \ 13 | --use_amp \ 14 | --use_module_replace \ 15 | --use_local_sgd \ 16 | --local_sgd_sync_interval 5 \ 17 | --local_sgd_warmup_steps 10 \ 18 | --clip_pseudo_grad 10 \ 19 | --gradnorm_weighted \ 20 | --skip_anomaly \ 21 | --skip_anomaly_warmup_steps 10 \ 22 | --outer_optim_class sgd -------------------------------------------------------------------------------- /examples/nanoGPTATorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Andrej Karpathy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
--------------------------------------------------------------------------------
/examples/nanoGPTATorch/README.md:
--------------------------------------------------------------------------------
# nanoGPTATorch

nanoGPT is the simplest, fastest repository for training/finetuning medium-sized GPTs. This example is adapted from nanoGPT to work with atorch; see the [source repo](https://github.com/karpathy/nanoGPT) for details.

## requirements

```
pip install torch numpy transformers datasets tiktoken tqdm
```

Dependencies:

- [pytorch](https://pytorch.org) <3
- [numpy](https://numpy.org/install/) <3
- `transformers` for huggingface transformers <3 (to load GPT-2 checkpoints)
- `datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText)
- `tiktoken` for OpenAI's fast BPE code <3
- `tqdm` for progress bars <3


## Usage

The example uses default config values designed to train a GPT-2 (124M) model on OpenWebText; you can check and change the config values in train_atorch.py (the sketch below shows how overrides work). Several default config values:

- backend: "nccl" # 'nccl', 'gloo', etc.
- device: "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
- dtype: "float16"
- train_type: "fsdp"

### Scripts
The script first downloads the OpenWebText dataset, then launches atorch training.

```
bash train_atorch_entry.sh output_dir
```
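
### Overriding config values

train_atorch.py keeps its defaults as plain module-level variables. Assuming it follows the same pattern as upstream nanoGPT's train.py, it execs configurator.py (included later in this listing), which rebinds any of those globals from `--key=value` arguments. A minimal sketch of the pattern:

```python
# Sketch of the nanoGPT-style config pattern, assuming train_atorch.py
# mirrors upstream train.py. The variables below are the defaults listed
# in this README.
backend = "nccl"     # distributed backend: 'nccl', 'gloo', etc.
device = "cuda"      # 'cpu', 'cuda', 'cuda:0', ...
dtype = "float16"
train_type = "fsdp"

# configurator.py rebinds any of the globals above from argv, e.g.:
#   python train_atorch.py --dtype=bfloat16 --train_type=fsdp
exec(open("configurator.py").read())
```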
--------------------------------------------------------------------------------
/examples/nanoGPTATorch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/examples/nanoGPTATorch/__init__.py
--------------------------------------------------------------------------------
/examples/nanoGPTATorch/assets/gpt2_124M_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/examples/nanoGPTATorch/assets/gpt2_124M_loss.png
--------------------------------------------------------------------------------
/examples/nanoGPTATorch/assets/nanogpt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/examples/nanoGPTATorch/assets/nanogpt.jpg
--------------------------------------------------------------------------------
/examples/nanoGPTATorch/configurator.py:
--------------------------------------------------------------------------------
"""
Poor Man's Configurator. Probably a terrible idea. Example usage:
$ python train.py config/override_file.py --batch_size=32
this will first run config/override_file.py, then override batch_size to 32

The code in this file will be run as follows from e.g. train.py:
>>> exec(open('configurator.py').read())

So it's not a Python module, it's just shuttling this code away from train.py
The code in this script then overrides the globals()

I know people are not going to love this, I just really dislike configuration
complexity and having to prepend config. to every single variable. If someone
comes up with a better simple Python solution I am all ears.
"""

import sys
from ast import literal_eval

for arg in sys.argv[1:]:
    if "=" not in arg:
        # assume it's the name of a config file
        assert not arg.startswith("--")
        config_file = arg
        print(f"Overriding config with {config_file}:")
        with open(config_file) as f:
            print(f.read())
        exec(open(config_file).read())
    else:
        # assume it's a --key=value argument
        assert arg.startswith("--")
        key, val = arg.split("=")
        key = key[2:]
        if key in globals():
            try:
                # attempt to eval it (e.g. if bool, number, etc.)
                attempt = literal_eval(val)
            except (SyntaxError, ValueError):
                # if that goes wrong, just use the string
                attempt = val
            # ensure the types match ok
            assert type(attempt) == type(globals()[key])
            # cross fingers
            print(f"Overriding: {key} = {attempt}")
            globals()[key] = attempt
        else:
            raise ValueError(f"Unknown config key: {key}")
--------------------------------------------------------------------------------
/examples/nanoGPTATorch/openwebtext/readme.md:
--------------------------------------------------------------------------------

## openwebtext dataset

after running `prepare.py` (preprocess) we get:

- train.bin is ~17GB, val.bin ~8.5MB
- train has ~9B tokens (9,035,582,198)
- val has ~4M tokens (4,434,897)

this came from 8,013,769 documents in total.

references:

- OpenAI's WebText dataset is discussed in [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
- [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset
--------------------------------------------------------------------------------
/examples/nanoGPTATorch/prepare_dataset.sh:
--------------------------------------------------------------------------------

pip install -U huggingface_hub
export HF_ENDPOINT=https://hf-mirror.com
huggingface-cli download --repo-type dataset --resume-download openwebtext --local-dir openwebtext

cd openwebtext/subsets/
for tarfile in *.tar; do
    tar -xvf "$tarfile"
done
cd ../../
pip install datasets tiktoken
python openwebtext/prepare.py
--------------------------------------------------------------------------------
/examples/nanoGPTATorch/train_atorch_entry.sh:
--------------------------------------------------------------------------------
export WANDB_DISABLED=true

if [ ! -d openwebtext/subsets/ ]; then
    bash prepare_dataset.sh
fi

if [ ! -d /tmp ]; then
    mkdir /tmp
fi

NUM_GPUS=$(nvidia-smi -L | wc -l)

if [ ! -d $1 ]; then mkdir -p $1; fi;
cp -r $0 ${1}

nvidia-smi >> ${1}/nanoGPT.log
printenv >> ${1}/nanoGPT.log

python -m atorch.distributed.launch \
    --nproc_per_node $NUM_GPUS \
    --master_port 20456 \
    train_atorch.py 2>&1 | tee -a ${1}/nanoGPT.log
--------------------------------------------------------------------------------
/examples/optimizer/README.md:
--------------------------------------------------------------------------------
# A demo of using the AGD and WSAM optimizers

## Usage
```
python main.py [--use-gpu] [--dataset DataSet] [--model Model] [--batch-size BS] [--epochs Epochs] [--scheduler Scheduler] [--base_optimizer Base] [--lr LR] [--weight_decay WD] [--optimizer Optimizer] [--mode Mode] [--rho Rho] [...]
```

- Supported datasets: cifar10 & cifar100
- Supported models: Resnet18, Resnet34, Resnet50
- More parameters can be found in [main.py](./main.py)

## Example
Before running experiments, please set the environment variable:
```
export CUDA_VISIBLE_DEVICES=0
```
Train Resnet18 on Cifar10 using the AGD optimizer:
```
python main.py --use-gpu --dataset cifar10 --model resnet18 --batch-size 128 --epochs 200 --scheduler cosine --base_optimizer agd --lr 0.001 --eps 1e-8 --weight-decay 5e-4
```

Train Resnet18 on Cifar10 using the WSAM optimizer with SGD as the base optimizer:
```
python main.py --use-gpu --dataset cifar10 --model resnet18 --batch-size 128 --epochs 200 --scheduler cosine --base_optimizer sgd --lr 0.1 --weight-decay 5e-4 --optimizer wsam --mode decouple --rho 0.2 --gamma 0.9
```
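
For orientation, here is a minimal sketch of how a WSAM-style optimizer wraps a base optimizer inside a training step. The import path and constructor signature below are hypothetical placeholders, and SAM-family optimizers typically drive the step through a closure so they can re-evaluate the loss at the perturbed weights; check [main.py](./main.py) for the names this demo actually uses.

```python
import torch
import torch.nn as nn

# Hypothetical import path and signature -- see main.py for the real ones.
from atorch.optimizers import WSAM

model = nn.Linear(8, 2)
criterion = nn.CrossEntropyLoss()
inputs, targets = torch.randn(4, 8), torch.randint(0, 2, (4,))

base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=5e-4)
optimizer = WSAM(model, base_optimizer, rho=0.2, gamma=0.9)  # hypothetical signature

def closure():
    # Re-evaluated by the optimizer at the perturbed weights.
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    return loss

optimizer.step(closure)
```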
--------------------------------------------------------------------------------
/examples/optimizer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelligent-machine-learning/atorch/6651b69ee690d885380ba6b9f78fcdcb1a08e4dc/examples/optimizer/__init__.py
--------------------------------------------------------------------------------
/examples/optimizer/train_agd_entry.sh:
--------------------------------------------------------------------------------
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0

DATASET=cifar10
MODEL=resnet18
BATCH_SIZE=128
EPOCHS=200
SCHEDULER=cosine
BASE_OPTIMIZER=agd  # AGD as the base optimizer (see the README example)
LR=0.001
WEIGHT_DECAY=5e-4

python main.py --use-gpu \
    --dataset $DATASET \
    --model $MODEL \
    --batch-size $BATCH_SIZE \
    --epochs $EPOCHS \
    --scheduler $SCHEDULER \
    --base_optimizer $BASE_OPTIMIZER \
    --lr $LR \
    --eps 1e-8 \
    --weight-decay $WEIGHT_DECAY \
    2>&1 | tee log_agd.txt
--------------------------------------------------------------------------------
/examples/optimizer/train_wsam_entry.sh:
--------------------------------------------------------------------------------
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0

DATASET=cifar10
MODEL=resnet18
BATCH_SIZE=128
EPOCHS=200
SCHEDULER=cosine
BASE_OPTIMIZER=sgd
LR=0.1
WEIGHT_DECAY=5e-4

python main.py --use-gpu \
    --dataset $DATASET \
    --model $MODEL \
    --batch-size $BATCH_SIZE \
    --epochs $EPOCHS \
    --scheduler $SCHEDULER \
    --base_optimizer $BASE_OPTIMIZER \
    --lr $LR \
    --weight-decay $WEIGHT_DECAY \
    --optimizer wsam \
    --mode decouple \
    --rho 0.2 \
    --gamma 0.9 \
    2>&1 | tee log_wsam.txt
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
[pytest]
markers =
    core24: tests to run in pytorch24 + python 3.10
    fp8: tests for fp8
--------------------------------------------------------------------------------