├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ └── bug_report.md └── workflows │ └── main.yml ├── .gitignore ├── .gitlab-ci.yml ├── AWESOME-JITTOR-LIST.md ├── CHANGELOG.md ├── Dockerfile ├── LICENSE.txt ├── MANIFEST.in ├── README.cn.md ├── README.md ├── README.src.md ├── doc ├── Makefile ├── build_doc.sh ├── logo.png └── source │ ├── Jittor性能测试与对比方法.md │ ├── Jittor显存以及内存优化方法.md │ ├── Jittor调试技巧.md │ ├── README.cn.md │ ├── conf.py │ ├── index.rst │ ├── jittor.attention.md │ ├── jittor.console.md │ ├── jittor.contrib.md │ ├── jittor.dataset.md │ ├── jittor.distributions.md │ ├── jittor.init.md │ ├── jittor.linalg.md │ ├── jittor.loss3d.md │ ├── jittor.md │ ├── jittor.models.md │ ├── jittor.mpi.md │ ├── jittor.nn.md │ ├── jittor.optim.md │ ├── jittor.transform.md │ └── todo.md ├── python ├── jittor │ ├── __init__.py │ ├── __init__.pyi │ ├── attention.py │ ├── ccl │ │ ├── __init__.py │ │ ├── ccl_2d.py │ │ ├── ccl_3d.py │ │ └── ccl_link.py │ ├── compatibility │ │ ├── __init__.py │ │ ├── autograd.py │ │ ├── compiler.py │ │ ├── cuda.py │ │ ├── distributed.py │ │ ├── distributions.py │ │ ├── fft │ │ │ └── __init__.py │ │ ├── fx.py │ │ ├── gradscaler.py │ │ ├── gradscaler_old.py │ │ ├── misc.py │ │ ├── nn │ │ │ ├── __init__.py │ │ │ ├── init.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ └── rnn.py │ │ ├── optim.py │ │ ├── src │ │ │ ├── jtorch_core.cc │ │ │ └── jtorch_core.h │ │ ├── test │ │ │ ├── test_conflict_func.py │ │ │ ├── test_function.py │ │ │ ├── test_misc.py │ │ │ └── test_tutorial.py │ │ ├── tutorial │ │ │ ├── auto_grad1.py │ │ │ ├── auto_grad2.py │ │ │ ├── auto_grad3.py │ │ │ ├── auto_grad4.py │ │ │ ├── auto_grad5_optim.py │ │ │ ├── auto_grad6_module.py │ │ │ ├── auto_grad7_dynet.py │ │ │ └── quickstart.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── _pytree.py │ │ │ ├── checkpoint.py │ │ │ ├── data.py │ │ │ ├── dtype.py │ │ │ ├── hooks.py │ │ │ └── pip_publish.py │ │ └── vision │ │ │ ├── _internally_replaced_utils.py │ │ │ ├── datasets │ │ │ ├── 
__init__.py │ │ │ ├── mnist.py │ │ │ ├── utils.py │ │ │ └── vision.py │ │ │ ├── transforms.py │ │ │ └── utils.py │ ├── compile_extern.py │ ├── compiler.py │ ├── contrib.py │ ├── dataset │ │ ├── __init__.py │ │ ├── cifar.py │ │ ├── dataset.py │ │ ├── mnist.py │ │ ├── sampler.py │ │ ├── utils.py │ │ └── voc.py │ ├── demo │ │ └── simple_cgan.py │ ├── depthwise_conv.py │ ├── distributions.py │ ├── einops │ │ ├── __init__.py │ │ ├── _backends.py │ │ ├── einops.py │ │ ├── experimental │ │ │ ├── __init__.py │ │ │ └── indexing.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── _einmix.py │ │ │ └── jittor.py │ │ └── parsing.py │ ├── extern │ │ ├── acl │ │ │ ├── acl_compiler.py │ │ │ ├── acl_error_code.cc │ │ │ ├── acl_jittor.cc │ │ │ ├── acl_jittor.h │ │ │ ├── acl_op_exec.cc │ │ │ ├── aclnn │ │ │ │ ├── aclnn.cc │ │ │ │ └── aclnn.h │ │ │ ├── aclops │ │ │ │ ├── __init__.py │ │ │ │ ├── aclops.h │ │ │ │ ├── base_op.h │ │ │ │ ├── base_op_acl.cc │ │ │ │ ├── binary_op_acl.cc │ │ │ │ ├── binary_op_acl.h │ │ │ │ ├── bmm_op.py │ │ │ │ ├── bmm_op_acl.cc │ │ │ │ ├── bmm_op_acl.h │ │ │ │ ├── concat_op.py │ │ │ │ ├── concat_op_acl.cc │ │ │ │ ├── concat_op_acl.h │ │ │ │ ├── conv_op.py │ │ │ │ ├── conv_op_acl.cc │ │ │ │ ├── conv_op_acl.h │ │ │ │ ├── cumsum_op.py │ │ │ │ ├── cumsum_op_acl.cc │ │ │ │ ├── cumsum_op_acl.h │ │ │ │ ├── dropout_op.py │ │ │ │ ├── dropout_op_acl.cc │ │ │ │ ├── dropout_op_acl.h │ │ │ │ ├── embedding_op.py │ │ │ │ ├── embedding_op_acl.cc │ │ │ │ ├── embedding_op_acl.h │ │ │ │ ├── expand_op_acl.cc │ │ │ │ ├── expand_op_acl.h │ │ │ │ ├── flashattention_op.py │ │ │ │ ├── flashattention_op_acl.cc │ │ │ │ ├── flashattention_op_acl.h │ │ │ │ ├── flip_op.py │ │ │ │ ├── flip_op_acl.cc │ │ │ │ ├── flip_op_acl.h │ │ │ │ ├── floor_op.py │ │ │ │ ├── floor_op_acl.cc │ │ │ │ ├── floor_op_acl.h │ │ │ │ ├── gather_scatter_op.py │ │ │ │ ├── gather_scatter_op_acl.cc │ │ │ │ ├── gather_scatter_op_acl.h │ │ │ │ ├── getitem_op.py │ │ │ │ ├── getitem_op_acl.cc │ │ │ │ ├── getitem_op_acl.h 
│ │ │ │ ├── index_op.py │ │ │ │ ├── index_op_acl.cc │ │ │ │ ├── index_op_acl.h │ │ │ │ ├── matmul_op.py │ │ │ │ ├── matmul_op_acl.cc │ │ │ │ ├── matmul_op_acl.h │ │ │ │ ├── nantonum_op.py │ │ │ │ ├── nantonum_op_acl.cc │ │ │ │ ├── nantonum_op_acl.h │ │ │ │ ├── norms_op.py │ │ │ │ ├── norms_op_acl.cc │ │ │ │ ├── norms_op_acl.h │ │ │ │ ├── pool_op.py │ │ │ │ ├── pool_op_acl.cc │ │ │ │ ├── pool_op_acl.h │ │ │ │ ├── random_op_acl.cc │ │ │ │ ├── random_op_acl.h │ │ │ │ ├── reduce_op_acl.cc │ │ │ │ ├── reduce_op_acl.h │ │ │ │ ├── relu_op.py │ │ │ │ ├── relu_op_acl.cc │ │ │ │ ├── relu_op_acl.h │ │ │ │ ├── rope_op.py │ │ │ │ ├── rope_op_acl.cc │ │ │ │ ├── rope_op_acl.h │ │ │ │ ├── setitem_op.py │ │ │ │ ├── setitem_op_acl.cc │ │ │ │ ├── setitem_op_acl.h │ │ │ │ ├── sigmoid_op.py │ │ │ │ ├── sigmoid_op_acl.cc │ │ │ │ ├── sigmoid_op_acl.h │ │ │ │ ├── silu_op.py │ │ │ │ ├── silu_op_acl.cc │ │ │ │ ├── silu_op_acl.h │ │ │ │ ├── softmax_op.py │ │ │ │ ├── softmax_op_acl.cc │ │ │ │ ├── softmax_op_acl.h │ │ │ │ ├── stack_op.py │ │ │ │ ├── stack_op_acl.cc │ │ │ │ ├── stack_op_acl.h │ │ │ │ ├── ternary_op_acl.cc │ │ │ │ ├── ternary_op_acl.h │ │ │ │ ├── transpose_op.py │ │ │ │ ├── transpose_op_acl.cc │ │ │ │ ├── transpose_op_acl.h │ │ │ │ ├── triu_op.py │ │ │ │ ├── triu_op_acl.cc │ │ │ │ ├── triu_op_acl.h │ │ │ │ ├── unary_op_acl.cc │ │ │ │ ├── unary_op_acl.h │ │ │ │ ├── utils.cc │ │ │ │ ├── utils.h │ │ │ │ ├── where_op.py │ │ │ │ ├── where_op_acl.cc │ │ │ │ └── where_op_acl.h │ │ │ └── hccl │ │ │ │ ├── inc │ │ │ │ └── hccl_wrapper.h │ │ │ │ ├── ops │ │ │ │ ├── hccl_all_gather_op.cc │ │ │ │ ├── hccl_all_gather_op.h │ │ │ │ ├── hccl_all_reduce_op.cc │ │ │ │ ├── hccl_all_reduce_op.h │ │ │ │ ├── hccl_broadcast_op.cc │ │ │ │ ├── hccl_broadcast_op.h │ │ │ │ ├── hccl_reduce_op.cc │ │ │ │ └── hccl_reduce_op.h │ │ │ │ └── src │ │ │ │ └── hccl_wrapper.cc │ │ ├── corex │ │ │ └── corex_compiler.py │ │ ├── cuda │ │ │ ├── cub │ │ │ │ ├── inc │ │ │ │ │ └── cub_test.h │ │ │ │ └── ops │ │ │ │ │ ├── 
cub_arg_reduce_op.cc │ │ │ │ │ ├── cub_arg_reduce_op.h │ │ │ │ │ ├── cub_argsort_op.cc │ │ │ │ │ ├── cub_argsort_op.h │ │ │ │ │ ├── cub_cumsum_op.cc │ │ │ │ │ ├── cub_cumsum_op.h │ │ │ │ │ ├── cub_test_op.cc │ │ │ │ │ ├── cub_test_op.h │ │ │ │ │ ├── cub_where_op.cc │ │ │ │ │ └── cub_where_op.h │ │ │ ├── cublas │ │ │ │ ├── inc │ │ │ │ │ └── cublas_wrapper.h │ │ │ │ ├── ops │ │ │ │ │ ├── cublas_acc_matmul_op.cc │ │ │ │ │ ├── cublas_acc_matmul_op.h │ │ │ │ │ ├── cublas_batched_matmul_op.cc │ │ │ │ │ ├── cublas_batched_matmul_op.h │ │ │ │ │ ├── cublas_matmul_op.cc │ │ │ │ │ ├── cublas_matmul_op.h │ │ │ │ │ ├── cublas_test_op.cc │ │ │ │ │ └── cublas_test_op.h │ │ │ │ └── src │ │ │ │ │ ├── cublas_matmul_test.cc │ │ │ │ │ ├── cublas_wrapper.cc │ │ │ │ │ └── helper_cublas.cc │ │ │ ├── cudnn │ │ │ │ ├── inc │ │ │ │ │ ├── cudnn_rnn_descriptor.h │ │ │ │ │ └── cudnn_wrapper.h │ │ │ │ ├── ops │ │ │ │ │ ├── cudnn_conv3d_backward_w_op.cc │ │ │ │ │ ├── cudnn_conv3d_backward_w_op.h │ │ │ │ │ ├── cudnn_conv3d_backward_x_op.cc │ │ │ │ │ ├── cudnn_conv3d_backward_x_op.h │ │ │ │ │ ├── cudnn_conv3d_op.cc │ │ │ │ │ ├── cudnn_conv3d_op.h │ │ │ │ │ ├── cudnn_conv_backward_w_op.cc │ │ │ │ │ ├── cudnn_conv_backward_w_op.h │ │ │ │ │ ├── cudnn_conv_backward_x_op.cc │ │ │ │ │ ├── cudnn_conv_backward_x_op.h │ │ │ │ │ ├── cudnn_conv_op.cc │ │ │ │ │ ├── cudnn_conv_op.h │ │ │ │ │ ├── cudnn_rnn_backward_x_op.cc │ │ │ │ │ ├── cudnn_rnn_backward_x_op.h │ │ │ │ │ ├── cudnn_rnn_op.cc │ │ │ │ │ ├── cudnn_rnn_op.h │ │ │ │ │ ├── cudnn_test_op.cc │ │ │ │ │ └── cudnn_test_op.h │ │ │ │ └── src │ │ │ │ │ ├── cudnn_conv_test.cc │ │ │ │ │ ├── cudnn_rnn_descriptor.cc │ │ │ │ │ ├── cudnn_wrapper.cc │ │ │ │ │ └── helper_cudnn.cc │ │ │ ├── cufft │ │ │ │ ├── inc │ │ │ │ │ ├── cufft_utils.h │ │ │ │ │ └── cufft_wrapper.h │ │ │ │ ├── ops │ │ │ │ │ ├── cufft_fft_op.cc │ │ │ │ │ └── cufft_fft_op.h │ │ │ │ └── src │ │ │ │ │ └── cufft_wrapper.cc │ │ │ ├── curand │ │ │ │ ├── inc │ │ │ │ │ └── curand_wrapper.h │ │ │ │ ├── ops 
│ │ │ │ │ ├── curand_random_op.cc │ │ │ │ │ └── curand_random_op.h │ │ │ │ └── src │ │ │ │ │ ├── curand_wrapper.cc │ │ │ │ │ └── helper_curand.cc │ │ │ ├── cusparse │ │ │ │ ├── inc │ │ │ │ │ └── cusparse_wrapper.h │ │ │ │ ├── ops │ │ │ │ │ ├── cusparse_spmmcoo_op.cc │ │ │ │ │ ├── cusparse_spmmcoo_op.h │ │ │ │ │ ├── cusparse_spmmcsr_op.cc │ │ │ │ │ └── cusparse_spmmcsr_op.h │ │ │ │ └── src │ │ │ │ │ ├── cusparse_wrapper.cc │ │ │ │ │ └── helper_cusparse.cc │ │ │ ├── cutt │ │ │ │ └── ops │ │ │ │ │ ├── cutt_test_op.cc │ │ │ │ │ ├── cutt_test_op.h │ │ │ │ │ ├── cutt_transpose_op.cc │ │ │ │ │ ├── cutt_transpose_op.h │ │ │ │ │ ├── cutt_wrapper.cc │ │ │ │ │ └── cutt_wrapper.h │ │ │ ├── inc │ │ │ │ ├── fp16_dev.h │ │ │ │ ├── fp16_emu.h │ │ │ │ ├── helper_cuda.h │ │ │ │ ├── helper_functions.h │ │ │ │ ├── helper_image.h │ │ │ │ ├── helper_string.h │ │ │ │ └── helper_timer.h │ │ │ ├── nccl │ │ │ │ ├── inc │ │ │ │ │ └── nccl_wrapper.h │ │ │ │ ├── ops │ │ │ │ │ ├── nccl_all_gather_op.cc │ │ │ │ │ ├── nccl_all_gather_op.h │ │ │ │ │ ├── nccl_all_reduce_op.cc │ │ │ │ │ ├── nccl_all_reduce_op.h │ │ │ │ │ ├── nccl_broadcast_op.cc │ │ │ │ │ ├── nccl_broadcast_op.h │ │ │ │ │ ├── nccl_reduce_op.cc │ │ │ │ │ ├── nccl_reduce_op.h │ │ │ │ │ ├── nccl_test_op.cc │ │ │ │ │ └── nccl_test_op.h │ │ │ │ └── src │ │ │ │ │ └── nccl_wrapper.cc │ │ │ └── src │ │ │ │ ├── fp16_emu.cc │ │ │ │ └── helper_cuda.cc │ │ ├── llvm │ │ │ └── jt_alignment_from_assumptions.cc │ │ ├── mkl │ │ │ └── ops │ │ │ │ ├── cpu_cnn_inference_f32.cpp │ │ │ │ ├── mkl_conv_backward_w_op.cc │ │ │ │ ├── mkl_conv_backward_w_op.h │ │ │ │ ├── mkl_conv_backward_x_op.cc │ │ │ │ ├── mkl_conv_backward_x_op.h │ │ │ │ ├── mkl_conv_op.cc │ │ │ │ ├── mkl_conv_op.h │ │ │ │ ├── mkl_matmul_op.cc │ │ │ │ ├── mkl_matmul_op.h │ │ │ │ ├── mkl_test_op.cc │ │ │ │ └── mkl_test_op.h │ │ ├── mpi │ │ │ ├── inc │ │ │ │ └── mpi_wrapper.h │ │ │ ├── ops │ │ │ │ ├── mpi_all_reduce_op.cc │ │ │ │ ├── mpi_all_reduce_op.h │ │ │ │ ├── mpi_broadcast_op.cc │ │ │ │ 
├── mpi_broadcast_op.h │ │ │ │ ├── mpi_reduce_op.cc │ │ │ │ ├── mpi_reduce_op.h │ │ │ │ ├── mpi_test_op.cc │ │ │ │ └── mpi_test_op.h │ │ │ └── src │ │ │ │ └── mpi_wrapper.cc │ │ └── rocm │ │ │ ├── rocm_cache.tar.gz │ │ │ ├── rocm_compiler.py │ │ │ ├── rocm_config.cc │ │ │ ├── rocm_config.h │ │ │ ├── rocm_jittor.h │ │ │ └── rocm_wrapper.h │ ├── gradfunctional │ │ ├── __init__.py │ │ └── functional.py │ ├── init.py │ ├── init_cupy.py │ ├── linalg.py │ ├── loss3d │ │ ├── __init__.py │ │ ├── chamfer.py │ │ └── emd.py │ ├── lr_scheduler.py │ ├── math_util │ │ ├── __init__.py │ │ ├── gamma.py │ │ ├── igamma.py │ │ └── src │ │ │ ├── gamma_grad.h │ │ │ └── igamma.h │ ├── misc.py │ ├── models │ │ ├── __init__.py │ │ ├── alexnet.py │ │ ├── densenet.py │ │ ├── googlenet.py │ │ ├── inception.py │ │ ├── mnasnet.py │ │ ├── mobilenet.py │ │ ├── res2net.py │ │ ├── resnet.py │ │ ├── shufflenetv2.py │ │ ├── squeezenet.py │ │ └── vgg.py │ ├── nn.py │ ├── notebook │ │ ├── 60分钟快速入门Jittor │ │ │ ├── README.md │ │ │ ├── mnist.png │ │ │ ├── 计图入门教程 0 --- 介绍与安装.ipynb │ │ │ ├── 计图入门教程 1 --- 基本概念.ipynb │ │ │ ├── 计图入门教程 2 --- 如何训练一个简单线性回归.ipynb │ │ │ └── 计图入门教程 3 --- 尝试解决一个实际问题.ipynb │ │ ├── ConditionGAN.src.md │ │ ├── LSGAN.src.md │ │ ├── __main__.py │ │ ├── basics.src.md │ │ ├── custom_op.src.md │ │ ├── example.src.md │ │ ├── figs │ │ │ └── mop.svg │ │ ├── md_to_ipynb.py │ │ ├── meta_op.src.md │ │ └── profiler.src.md │ ├── optim.py │ ├── other │ │ └── code_softmax.py │ ├── pool.py │ ├── pyjt_compiler.py │ ├── script │ │ ├── Dockerfile_cuda11 │ │ ├── build_aarch64_mkl.sh │ │ ├── converter_server.sh │ │ ├── inference_perf.py │ │ ├── install.sh │ │ ├── install_llvm.sh │ │ ├── install_mkl.sh │ │ ├── make_doc.py │ │ ├── tmpi │ │ └── update.sh │ ├── sparse.py │ ├── src │ │ ├── async_queue.h │ │ ├── common.h │ │ ├── core.h │ │ ├── event_queue.cc │ │ ├── event_queue.h │ │ ├── executor.cc │ │ ├── executor.h │ │ ├── fused_op.cc │ │ ├── fused_op.h │ │ ├── fuser.h │ │ ├── grad.cc │ │ ├── grad.h │ │ ├── 
graph.cc │ │ ├── graph.h │ │ ├── init.cc │ │ ├── init.h │ │ ├── jit_compiler.cc │ │ ├── jit_compiler.h │ │ ├── jit_key.cc │ │ ├── jit_key.h │ │ ├── lock.cc │ │ ├── lock.h │ │ ├── mem │ │ │ ├── allocator.cc │ │ │ ├── allocator.h │ │ │ ├── allocator │ │ │ │ ├── aligned_allocator.cc │ │ │ │ ├── aligned_allocator.h │ │ │ │ ├── cuda_device_allocator.cc │ │ │ │ ├── cuda_device_allocator.h │ │ │ │ ├── cuda_dual_allocator.cc │ │ │ │ ├── cuda_dual_allocator.h │ │ │ │ ├── cuda_host_allocator.cc │ │ │ │ ├── cuda_host_allocator.h │ │ │ │ ├── cuda_managed_allocator.cc │ │ │ │ ├── cuda_managed_allocator.h │ │ │ │ ├── foreign_allocator.cc │ │ │ │ ├── foreign_allocator.h │ │ │ │ ├── nfef_allocator.cc │ │ │ │ ├── nfef_allocator.h │ │ │ │ ├── sfrl_allocator.cc │ │ │ │ ├── sfrl_allocator.h │ │ │ │ ├── stat_allocator.cc │ │ │ │ ├── stat_allocator.h │ │ │ │ ├── temp_allocator.cc │ │ │ │ └── temp_allocator.h │ │ │ ├── mem_info.cc │ │ │ ├── mem_info.h │ │ │ ├── swap.cc │ │ │ └── swap.h │ │ ├── memory_profiler.cc │ │ ├── memory_profiler.h │ │ ├── misc │ │ │ ├── cpu_atomic.cc │ │ │ ├── cpu_atomic.h │ │ │ ├── cpu_math.cc │ │ │ ├── cpu_math.h │ │ │ ├── cstr.h │ │ │ ├── cuda_atomic.h │ │ │ ├── cuda_flags.cc │ │ │ ├── cuda_flags.h │ │ │ ├── cuda_limits.h │ │ │ ├── deleter.h │ │ │ ├── fast_shared_ptr.h │ │ │ ├── hash.h │ │ │ ├── intrin.h │ │ │ ├── miniz.cc │ │ │ ├── miniz.h │ │ │ ├── nan_checker.cc │ │ │ ├── nan_checker.cu │ │ │ ├── nan_checker.h │ │ │ ├── nano_string.cc │ │ │ ├── nano_string.h │ │ │ ├── nano_vector.h │ │ │ ├── ring_buffer.cc │ │ │ ├── ring_buffer.h │ │ │ ├── stack_vector.h │ │ │ └── string_view_map.h │ │ ├── node.h │ │ ├── numpy_func.h │ │ ├── op.cc │ │ ├── op.h │ │ ├── op_compiler.cc │ │ ├── op_compiler.h │ │ ├── ops │ │ │ ├── arg_reduce_op.cc │ │ │ ├── arg_reduce_op.h │ │ │ ├── argsort_op.cc │ │ │ ├── argsort_op.h │ │ │ ├── array_op.cc │ │ │ ├── array_op.h │ │ │ ├── binary_op.cc │ │ │ ├── binary_op.h │ │ │ ├── broadcast_to_op.cc │ │ │ ├── broadcast_to_op.h │ │ │ ├── 
candidate_op.cc │ │ │ ├── candidate_op.h │ │ │ ├── clone_op.cc │ │ │ ├── clone_op.h │ │ │ ├── code_op.cc │ │ │ ├── code_op.h │ │ │ ├── copy_op.cc │ │ │ ├── copy_op.h │ │ │ ├── empty_op.cc │ │ │ ├── empty_op.h │ │ │ ├── fetch_op.cc │ │ │ ├── fetch_op.h │ │ │ ├── fuse_transpose_op.cc │ │ │ ├── fuse_transpose_op.h │ │ │ ├── getitem_op.cc │ │ │ ├── getitem_op.h │ │ │ ├── index_op.cc │ │ │ ├── index_op.h │ │ │ ├── numpy_code_op.cc │ │ │ ├── numpy_code_op.h │ │ │ ├── op_register.cc │ │ │ ├── op_register.h │ │ │ ├── op_utils.cc │ │ │ ├── random_op.cc │ │ │ ├── random_op.h │ │ │ ├── reduce_op.cc │ │ │ ├── reduce_op.h │ │ │ ├── reindex_op.cc │ │ │ ├── reindex_op.h │ │ │ ├── reindex_reduce_op.cc │ │ │ ├── reindex_reduce_op.h │ │ │ ├── reshape_op.cc │ │ │ ├── reshape_op.h │ │ │ ├── safe_clip_op.cc │ │ │ ├── safe_clip_op.h │ │ │ ├── setitem_op.cc │ │ │ ├── setitem_op.h │ │ │ ├── tape_op.cc │ │ │ ├── tape_op.h │ │ │ ├── ternary_op.cc │ │ │ ├── ternary_op.h │ │ │ ├── transpose_op.cc │ │ │ ├── transpose_op.h │ │ │ ├── unary_op.cc │ │ │ ├── unary_op.h │ │ │ ├── where_op.cc │ │ │ └── where_op.h │ │ ├── opt │ │ │ ├── expr.cc │ │ │ ├── expr.h │ │ │ ├── gopt │ │ │ │ └── setitem_gopt.cc │ │ │ ├── jit_searcher.cc │ │ │ ├── jit_searcher.h │ │ │ ├── kernel_ir.cc │ │ │ ├── kernel_ir.h │ │ │ ├── pass │ │ │ │ ├── assume_aligned_pass.cc │ │ │ │ ├── assume_aligned_pass.h │ │ │ │ ├── atomic_tuner_pass.h │ │ │ │ ├── check_cache_pass.cc │ │ │ │ ├── check_cache_pass.h │ │ │ │ ├── compile_shapes_pass.cc │ │ │ │ ├── compile_shapes_pass.h │ │ │ │ ├── const_var_pass.cc │ │ │ │ ├── const_var_pass.h │ │ │ │ ├── expand_empty_block_pass.cc │ │ │ │ ├── expand_empty_block_pass.h │ │ │ │ ├── fake_main_pass.cc │ │ │ │ ├── fake_main_pass.h │ │ │ │ ├── float_atomic_fix_pass.cc │ │ │ │ ├── float_atomic_fix_pass.h │ │ │ │ ├── insert_profile_loop_pass.cc │ │ │ │ ├── insert_profile_loop_pass.h │ │ │ │ ├── loop_to_func_pass.cc │ │ │ │ ├── loop_to_func_pass.h │ │ │ │ ├── loop_var_analyze_pass.cc │ │ │ │ ├── 
loop_var_analyze_pass.h │ │ │ │ ├── mark_raw_pass.cc │ │ │ │ ├── mark_raw_pass.h │ │ │ │ ├── merge_loop_pass.cc │ │ │ │ ├── merge_loop_pass.h │ │ │ │ ├── merge_loop_var_pass.cc │ │ │ │ ├── merge_loop_var_pass.h │ │ │ │ ├── parallel_pass.h │ │ │ │ ├── pass.cc │ │ │ │ ├── pass.h │ │ │ │ ├── remove_intermediate_pass.cc │ │ │ │ ├── remove_intermediate_pass.h │ │ │ │ ├── remove_loop_pass.cc │ │ │ │ ├── remove_loop_pass.h │ │ │ │ ├── rename_loop_index_pass.cc │ │ │ │ ├── rename_loop_index_pass.h │ │ │ │ ├── reorder_loop_pass.cc │ │ │ │ ├── reorder_loop_pass.h │ │ │ │ ├── replace_for_num_pass.cc │ │ │ │ ├── replace_for_num_pass.h │ │ │ │ ├── restride_pass.cc │ │ │ │ ├── restride_pass.h │ │ │ │ ├── shared_reduce_pass.h │ │ │ │ ├── solve_conflict_define_pass.cc │ │ │ │ ├── solve_conflict_define_pass.h │ │ │ │ ├── split_loop_pass.cc │ │ │ │ ├── split_loop_pass.h │ │ │ │ ├── unroll_pass.cc │ │ │ │ ├── unroll_pass.h │ │ │ │ ├── use_movnt_pass.cc │ │ │ │ ├── use_movnt_pass.h │ │ │ │ ├── vectorize_pass.cc │ │ │ │ └── vectorize_pass.h │ │ │ ├── pass_manager.cc │ │ │ ├── pass_manager.h │ │ │ ├── tuner │ │ │ │ ├── broadcast_tuner.cc │ │ │ │ ├── broadcast_tuner.h │ │ │ │ ├── conv_tuner.cc │ │ │ │ ├── conv_tuner.h │ │ │ │ ├── matmul_tuner.cc │ │ │ │ ├── matmul_tuner.h │ │ │ │ ├── reduce_tuner.cc │ │ │ │ ├── reduce_tuner.h │ │ │ │ ├── reorder_tuner.cc │ │ │ │ ├── reorder_tuner.h │ │ │ │ ├── tuner.cc │ │ │ │ └── tuner.h │ │ │ ├── tuner_manager.cc │ │ │ ├── tuner_manager.h │ │ │ ├── var_relay.cc │ │ │ └── var_relay.h │ │ ├── parallel_compiler.cc │ │ ├── parallel_compiler.h │ │ ├── profiler │ │ │ ├── cache_info.cc │ │ │ ├── cache_info.h │ │ │ ├── memory_checker.cc │ │ │ ├── memory_checker.h │ │ │ ├── profiler.cc │ │ │ ├── profiler.h │ │ │ ├── profiler_guard.h │ │ │ ├── replacement.cc │ │ │ ├── replacement.h │ │ │ ├── simple_profiler.h │ │ │ └── vtop.cc │ │ ├── pybind │ │ │ ├── core.cc │ │ │ ├── py_var_tracer.cc │ │ │ ├── py_var_tracer.h │ │ │ └── py_var_tracer_interface.h │ │ ├── pyjt │ 
│ │ ├── numpy.cc │ │ │ ├── numpy.h │ │ │ ├── py_arg_printer.cc │ │ │ ├── py_arg_printer.h │ │ │ ├── py_array_op.cc │ │ │ ├── py_caller.cc │ │ │ ├── py_caller.h │ │ │ ├── py_converter.h │ │ │ ├── py_obj_holder.h │ │ │ ├── py_ring_buffer.cc │ │ │ ├── py_ring_buffer.h │ │ │ └── pyjt_console.h │ │ ├── test │ │ │ ├── test_expr.cc │ │ │ ├── test_fast_shared_ptr.cc │ │ │ ├── test_jit_key.cc │ │ │ ├── test_kernel_ir.cc │ │ │ ├── test_nano_vector.cc │ │ │ ├── test_op_compiler.cc │ │ │ ├── test_op_relay.cc │ │ │ ├── test_setitem_op.cc │ │ │ └── test_sfrl_allocator.cc │ │ ├── type │ │ │ ├── common_op_type.cc │ │ │ ├── fp16_compute.h │ │ │ └── fp16_op_type.cc │ │ ├── types.h │ │ ├── utils │ │ │ ├── cache_compile.cc │ │ │ ├── cache_compile.h │ │ │ ├── cross_platform.h │ │ │ ├── flags.cc │ │ │ ├── flags.h │ │ │ ├── jit_utils.cc │ │ │ ├── log.cc │ │ │ ├── log.h │ │ │ ├── mwsr_list.cc │ │ │ ├── mwsr_list.h │ │ │ ├── seh.h │ │ │ ├── str_utils.cc │ │ │ ├── str_utils.h │ │ │ ├── tracer.cc │ │ │ ├── tracer.h │ │ │ └── vdp │ │ ├── var.cc │ │ ├── var.h │ │ ├── var_holder.cc │ │ ├── var_holder.h │ │ ├── var_slices.cc │ │ └── var_slices.h │ ├── test │ │ ├── __main__.py │ │ ├── misc │ │ │ └── superglue.py │ │ ├── perf │ │ │ └── perf.py │ │ ├── system │ │ │ ├── test_all.sh │ │ │ ├── test_cuda10.0_ubuntu16.04.sh │ │ │ ├── test_cuda10.0_ubuntu18.04.sh │ │ │ ├── test_cuda11.1_ubuntu16.04.sh │ │ │ ├── test_cuda11.1_ubuntu18.04.sh │ │ │ ├── test_cuda11.1_ubuntu20.04.sh │ │ │ └── test_nocuda_ubuntu18.04.sh │ │ ├── test.h │ │ ├── test_acl.py │ │ ├── test_aclop.py │ │ ├── test_adamw.py │ │ ├── test_affine_grid.py │ │ ├── test_allocator.py │ │ ├── test_allocator2.py │ │ ├── test_arg_pool_op.py │ │ ├── test_arg_reduce_op.py │ │ ├── test_argsort_op.py │ │ ├── test_array.py │ │ ├── test_asm_tuner.py │ │ ├── test_atomic_tuner.py │ │ ├── test_attention.py │ │ ├── test_auto_diff.py │ │ ├── test_batchnorm.py │ │ ├── test_benchmark.py │ │ ├── test_bf16.py │ │ ├── test_bicubic.py │ │ ├── test_binary_op.py │ 
│ ├── test_bmm.py │ │ ├── test_broadcast_to_op.py │ │ ├── test_broadcast_tuner.py │ │ ├── test_cache.py │ │ ├── test_candidate.py │ │ ├── test_clone.py │ │ ├── test_code_op.py │ │ ├── test_compile_options.py │ │ ├── test_complex.py │ │ ├── test_concat_op.py │ │ ├── test_console.py │ │ ├── test_contrib.py │ │ ├── test_conv_transpose.py │ │ ├── test_conv_tuner.py │ │ ├── test_core.py │ │ ├── test_cub_cumsum.py │ │ ├── test_cublas_test_op.py │ │ ├── test_cuda.py │ │ ├── test_cudnn_op.py │ │ ├── test_cumprod_op.py │ │ ├── test_cusparse_op.py │ │ ├── test_custom_op.py │ │ ├── test_cutt.py │ │ ├── test_cutt_transpose_op.py │ │ ├── test_dataset.py │ │ ├── test_default_var.py │ │ ├── test_densenet.py │ │ ├── test_depthwise_conv.py │ │ ├── test_digamma.py │ │ ├── test_distributions.py │ │ ├── test_dtype_info.py │ │ ├── test_einops.py │ │ ├── test_einsum.py │ │ ├── test_emnist.py │ │ ├── test_error_msg.py │ │ ├── test_example.py │ │ ├── test_example_accumulate_grad.py │ │ ├── test_fetcher.py │ │ ├── test_fft_op.py │ │ ├── test_flags.py │ │ ├── test_fold.py │ │ ├── test_fp16.py │ │ ├── test_function.py │ │ ├── test_fused_op.py │ │ ├── test_fuser.py │ │ ├── test_gamma_distribution.py │ │ ├── test_getitem_simple.py │ │ ├── test_grad.py │ │ ├── test_group_conv_tuner.py │ │ ├── test_histc.py │ │ ├── test_hook.py │ │ ├── test_image_folder.py │ │ ├── test_inception.py │ │ ├── test_index_op.py │ │ ├── test_init.py │ │ ├── test_interpolation.py │ │ ├── test_jit_tests.py │ │ ├── test_jtune.py │ │ ├── test_knn.py │ │ ├── test_lazy_execution.py │ │ ├── test_linalg.py │ │ ├── test_load_pth.py │ │ ├── test_lock.py │ │ ├── test_log.py │ │ ├── test_longest_dis_fuse.py │ │ ├── test_loss.py │ │ ├── test_loss3d.py │ │ ├── test_lr_scheduler.py │ │ ├── test_lstm.py │ │ ├── test_matmul.py │ │ ├── test_matmul_tuner.py │ │ ├── test_mem.py │ │ ├── test_memory_profiler.py │ │ ├── test_merge_loop_var_pass.py │ │ ├── test_merge_single_array_op.py │ │ ├── test_misc_issue.py │ │ ├── test_misc_op.py │ │ 
├── test_mkl_conv_op.py │ │ ├── test_mkl_test_op.py │ │ ├── test_models.py │ │ ├── test_mpi.py │ │ ├── test_mpi_batchnorm.py │ │ ├── test_mpi_in_py.py │ │ ├── test_mpi_op.py │ │ ├── test_nano_string.py │ │ ├── test_nano_vector.py │ │ ├── test_nccl.py │ │ ├── test_nccl_ops.py │ │ ├── test_new_fused_op.py │ │ ├── test_node.py │ │ ├── test_notebooks.py │ │ ├── test_numpy_code_op.py │ │ ├── test_op_compiler.py │ │ ├── test_opt_state_dict.py │ │ ├── test_optimizer.py │ │ ├── test_optimizer_save_load.py │ │ ├── test_pad.py │ │ ├── test_parallel_pass.py │ │ ├── test_param_list.py │ │ ├── test_profiler.py │ │ ├── test_pytorch_converter.py │ │ ├── test_random_op.py │ │ ├── test_reduce_op.py │ │ ├── test_reduce_tuner.py │ │ ├── test_reindex_op.py │ │ ├── test_reindex_reduce_op.py │ │ ├── test_relu.py │ │ ├── test_reorder_tuner.py │ │ ├── test_repeat.py │ │ ├── test_reshape.py │ │ ├── test_resize_and_crop.py │ │ ├── test_resnet.py │ │ ├── test_ring_buffer.py │ │ ├── test_ring_buffer2.py │ │ ├── test_rnn.py │ │ ├── test_rocm.py │ │ ├── test_sampler.py │ │ ├── test_search_sorted.py │ │ ├── test_searchsorted_op.py │ │ ├── test_setitem.py │ │ ├── test_single_process_scope.py │ │ ├── test_slice.py │ │ ├── test_sparse.py │ │ ├── test_stop_fuse.py │ │ ├── test_superglue.py │ │ ├── test_ternary_op.py │ │ ├── test_trace_var.py │ │ ├── test_tracer.py │ │ ├── test_transform.py │ │ ├── test_transpose_op.py │ │ ├── test_unary_op.py │ │ ├── test_unique.py │ │ ├── test_utils.py │ │ ├── test_var.py │ │ ├── test_vgg.py │ │ ├── test_weightnorm.py │ │ └── test_where_op.py │ ├── transform │ │ ├── __init__.py │ │ └── function_pil.py │ ├── utils │ │ ├── asm_tuner.py │ │ ├── bench_klo.py │ │ ├── converter_server.py │ │ ├── data.gz │ │ ├── dlink_compiler.py │ │ ├── dumpdef.py │ │ ├── gen_pyi.py │ │ ├── jtune.py │ │ ├── local_doc_builder.py │ │ ├── nvtx.py │ │ ├── polish.py │ │ ├── polish_centos.py │ │ ├── publish.py │ │ ├── pytorch_converter.py │ │ └── tracer.py │ ├── vcompiler │ │ ├── __init__.py │ 
│ ├── vcompiler.cc │ │ ├── vcompiler.h │ │ └── vcompiler.py │ ├── version │ └── weightnorm.py └── jittor_utils │ ├── __init__.py │ ├── auto_diff.py │ ├── auto_git_tag.py │ ├── class │ ├── motd │ ├── setup.py │ └── setup_env.py │ ├── clean_cache.py │ ├── config.py │ ├── github_release.sh │ ├── install_cuda.py │ ├── install_msvc.py │ ├── load_pytorch.py │ ├── load_pytorch_old.py │ ├── lock.py │ ├── misc.py │ ├── pack_offline.py │ ├── pip_publish.py │ ├── query_cuda_cc.py │ ├── ring_buffer.py │ ├── save_pytorch.py │ ├── student_queue.py │ └── translator.py └── setup.py /.dockerignore: -------------------------------------------------------------------------------- 1 | Dockerfile 2 | **/publish.py 3 | my 4 | .git 5 | .refresh 6 | __pycache__ 7 | .ipynb_checkpoints/ 8 | .vscode/ 9 | __res/ 10 | perf.data 11 | perf.data.old 12 | *.swp 13 | *.ipynb 14 | *.pdf 15 | *.zip 16 | *.tgz 17 | test.py 18 | extern/mkl/mkldnn_lnx*/* 19 | data/ 20 | build/ 21 | venv/ 22 | *.md 23 | !*.src.md 24 | !README.md 25 | !README.cn.md 26 | python/jittor.egg-info 27 | dist/ 28 | !doc/source/* 29 | __data__ -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Describe the bug 11 | A clear and concise description of what the bug is. 使用中文也可以。 12 | 13 | ## Full Log 14 | 15 | Provide a full log of Jittor execution, Jittor will log environment information which help us to locate your bugs. Provide a screenshot is also acceptable. 16 | 17 | ## Minimal Reproduce 18 | 19 | Reproduce this error with a file or several lines of code. 20 | If it is not possible, leave it blank. 21 | 22 | ## Expected behavior 23 | A clear and concise description of what you expected to happen. 
If you are submitting an issue for the first time, please read [our guidelines]
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /doc/build_doc.sh: -------------------------------------------------------------------------------- 1 | # sudo python3.7 -m pip install \ 2 | # recommonmark \ 3 | # sphinx sphinx-autobuild sphinx_rtd_theme \ 4 | # sphinx-autobuild \ 5 | # --timeout 100 6 | 7 | 8 | bpath=$(readlink -f "${BASH_SOURCE[0]}") 9 | bpath=$(dirname "${bpath}") 10 | 11 | jittor_path=$(readlink -f "${bpath}/..") 12 | 13 | echo "[doc path] $bpath" 14 | echo "[jittor path] $jittor_path" 15 | 16 | export PYTHONPATH=$jittor_path/python 17 | cd $bpath 18 | sphinx-autobuild -b html source build -H 0.0.0.0 -p 8890 19 | 20 | -------------------------------------------------------------------------------- /doc/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jittor/jittor/444154c2a1e63ee4a0af07831a0c54e2ebb7a561/doc/logo.png -------------------------------------------------------------------------------- /doc/source/README.cn.md: -------------------------------------------------------------------------------- 1 | ../../README.cn.md -------------------------------------------------------------------------------- /doc/source/jittor.attention.md: -------------------------------------------------------------------------------- 1 | jittor.attention 2 | ===================== 3 | 4 | 这里是Jittor的 注意力 模块的API文档,您可以通过`from jittor import attention`来获取该模块。 5 | 6 | ```eval_rst 7 | .. 
automodule:: jittor.attention 8 | :members: 9 | :undoc-members: 10 | ``` 11 | -------------------------------------------------------------------------------- /doc/source/jittor.contrib.md: -------------------------------------------------------------------------------- 1 | jittor.contrib 2 | ===================== 3 | 4 | 这里是Jittor的贡献代码模块模块的API文档,此模块的代码可能还没有完全成熟,我们将在后续迭代开发中继续完善,您可以通过`from jittor import contrib`来获取该模块。 5 | 6 | ```eval_rst 7 | .. automodule:: jittor.contrib 8 | :members: 9 | :undoc-members: 10 | ``` 11 | -------------------------------------------------------------------------------- /doc/source/jittor.dataset.md: -------------------------------------------------------------------------------- 1 | jittor.dataset 2 | ===================== 3 | 4 | 这里是Jittor的数据集模块的API文档,您可以通过`from jittor import dataset`来获取该模块。 5 | 6 | ```eval_rst 7 | .. automodule:: jittor.dataset 8 | :imported-members: 9 | :members: 10 | :undoc-members: 11 | ``` 12 | -------------------------------------------------------------------------------- /doc/source/jittor.distributions.md: -------------------------------------------------------------------------------- 1 | jittor.distributions 2 | ===================== 3 | 4 | 这里是Jittor的随机分布模块的API文档,您可以通过`from jittor import distributions`来获取该模块。 5 | 6 | ```eval_rst 7 | .. automodule:: jittor.distributions 8 | :members: 9 | :undoc-members: 10 | ``` 11 | -------------------------------------------------------------------------------- /doc/source/jittor.init.md: -------------------------------------------------------------------------------- 1 | jittor.init 2 | ===================== 3 | 4 | 这里是Jittor的参数初始化模块的API文档,您可以通过`from jittor import init`来获取该模块。 5 | 6 | ```eval_rst 7 | .. 
automodule:: jittor.init 8 | :members: 9 | :undoc-members: 10 | ``` 11 | -------------------------------------------------------------------------------- /doc/source/jittor.linalg.md: -------------------------------------------------------------------------------- 1 | jittor.linalg 2 | ===================== 3 | 4 | 这里是Jittor的线性代数函数的API文档,您可以通过`from jittor import linalg`来获取该模块。 5 | 6 | ## 基本函数简介 7 | #### 基本线性代数运算API 8 | - linalg.inv(a) 9 | 10 | 对a进行求逆运算 11 | 12 | - linalg.pinv(a) 13 | 14 | 对a进行广义求逆运算。该运算不要求原矩阵a可逆。 15 | 16 | - linalg.slogdet(a) 17 | 18 | 对a求取slogdet。会返回值以及符号。 19 | 20 | - linalg.det(a) 21 | 22 | 对a求行列式。 23 | 24 | - linalg.solve(a,b) 25 | 26 | 求解线性方程Ax=b的解。 27 | 28 | #### 分解API 29 | - linalg.cholesky(a) 30 | 31 | 对a进行cholesky分解。 32 | 33 | - linalg.qr(a) 34 | 35 | 对a进行qr分解。 36 | 37 | - linalg.svd 38 | 39 | 对a进行奇异值分解。 40 | #### 特征值API 41 | - linalg.eig(a) 42 | 43 | 求取a的特征值以及特征向量。 44 | 45 | - linalg.eigh(a) 46 | 47 | 针对埃尔米特矩阵或者对称矩阵求特征值以及特征向量。 48 | 49 | 50 | 目前的linalg库支持 51 | 52 | ```eval_rst 53 | .. automodule:: jittor.linalg 54 | :members: 55 | :undoc-members: 56 | ``` 57 | 58 | -------------------------------------------------------------------------------- /doc/source/jittor.loss3d.md: -------------------------------------------------------------------------------- 1 | jittor.loss3d 2 | ===================== 3 | 4 | 这里是Jittor的 3d 损失函数 模块的API文档,您可以通过`from jittor import loss3d`来获取该模块。 5 | 6 | ```eval_rst 7 | .. automodule:: jittor.loss3d 8 | :members: chamfer_loss, ChamferLoss, earth_mover_distance, EarthMoverDistance 9 | :undoc-members: 10 | ``` 11 | -------------------------------------------------------------------------------- /doc/source/jittor.md: -------------------------------------------------------------------------------- 1 | jittor 2 | ===================== 3 | 4 | ## jittor 5 | 6 | 这里是Jittor主模块的API文档,您可以通过`import jittor`来获取该模块。 7 | 8 | ```eval_rst 9 | .. 
automodule:: jittor 10 | :members: 11 | :undoc-members: 12 | ``` 13 | 14 | ## jittor.core 15 | 16 | 以下为Jittor的内核API,内核API可以通过`jittor.core.XXX`或者`jittor.XXX`直接访问。 17 | 18 | 19 | ```eval_rst 20 | .. automodule:: jittor_core 21 | :imported-members: 22 | :members: 23 | :undoc-members: 24 | ``` 25 | 26 | ## jittor.ops 27 | 28 | 这里是Jittor的基础算子模块的API文档,该API可以通过`jittor.ops.XXX`或者`jittor.XXX`直接访问。 29 | 30 | ```eval_rst 31 | .. automodule:: jittor_core.ops 32 | :members: 33 | :undoc-members: 34 | ``` 35 | 36 | ## jittor.Var 37 | 38 | 这里是Jittor的基础变量类的API文档。该API可以通过`my_jittor_var.XXX`直接访问。 39 | 40 | ```eval_rst 41 | .. automodule:: jittor_core.Var 42 | :members: 43 | :undoc-members: 44 | ``` 45 | 46 | ## jittor.Misc 47 | 48 | 这里是Jittor的杂项函数模块(jittor.misc)的API文档,该API可以通过`jittor.misc.XXX`或者`jittor.XXX`直接访问。 49 | 50 | ```eval_rst 51 | .. automodule:: jittor.misc 52 | :members: 53 | :undoc-members: 54 | ``` -------------------------------------------------------------------------------- /doc/source/jittor.models.md: -------------------------------------------------------------------------------- 1 | jittor.models 2 | ===================== 3 | 4 | 这里是Jittor的骨干网络模块的API文档,您可以通过`from jittor import models`来获取该模块。 5 | 6 | ```eval_rst 7 | 8 | .. automodule:: jittor.models 9 | :members: 10 | :imported-members: 11 | :undoc-members: 12 | :exclude-members: ResNet,ShuffleNetV2,SqueezeNet,VGG 13 | ``` 14 | 15 | -------------------------------------------------------------------------------- /doc/source/jittor.nn.md: -------------------------------------------------------------------------------- 1 | jittor.nn 2 | ===================== 3 | 4 | 这里是Jittor的神经网络模块的API文档,您可以通过`from jittor import nn`来获取该模块。 5 | 6 | ```eval_rst 7 | .. automodule:: jittor.nn 8 | :members: 9 | :undoc-members: 10 | 11 | .. 
automodule:: jittor.nn 12 | :imported-members: 13 | :members: Pool, pool, AdaptiveAvgPool2d, Pool3d, AdaptiveMaxPool2d, AdaptiveAvgPool3d, AdaptiveMaxPool3d, pool3d, AvgPool2d, AvgPool3d, avg_pool2d, MaxPool2d, MaxPool3d, max_pool2d, max_pool3d, MaxUnpool2d, MaxUnpool3d 14 | :undoc-members: 15 | 16 | .. autoclass:: jittor.nn.ReLU 17 | :members: 18 | .. autoclass:: jittor.nn.ReLU6 19 | :members: 20 | .. autoclass:: jittor.nn.LeakyReLU 21 | :members: 22 | .. autoclass:: jittor.nn.Softmax 23 | :members: 24 | ``` 25 | -------------------------------------------------------------------------------- /doc/source/jittor.optim.md: -------------------------------------------------------------------------------- 1 | jittor.optim 2 | ===================== 3 | 4 | 这里是Jittor的优化器模块的API文档,您可以通过`from jittor import optim`来获取该模块。 5 | 6 | ```eval_rst 7 | .. automodule:: jittor.optim 8 | :members: 9 | :undoc-members: 10 | ``` 11 | 12 | 以下是Jittor的学习率调度模块的API文档,学习率调度模块需要配合优化器使用,您可以通过`from jittor import lr_scheduler`来获取该模块。 13 | 14 | ```eval_rst 15 | .. automodule:: jittor.lr_scheduler 16 | :members: 17 | :undoc-members: 18 | ``` -------------------------------------------------------------------------------- /doc/source/jittor.transform.md: -------------------------------------------------------------------------------- 1 | jittor.transform 2 | ===================== 3 | 4 | 这里是Jittor的数据变换模块的API文档,您可以通过`from jittor import transform`来获取该模块。 5 | 6 | ```eval_rst 7 | .. automodule:: jittor.transform 8 | :members: 9 | :undoc-members: 10 | ``` 11 | -------------------------------------------------------------------------------- /doc/source/todo.md: -------------------------------------------------------------------------------- 1 | TODO 2 | ===================== 3 | 4 | ## 文档相关 5 | 6 | * 文档语法规范 7 | * 文档加上教程链接 8 | * MPI接口文档 9 | * 文档自动更新 10 | * 首页到文档的链接 11 | * 模型库的文档(GAN,segmentation,detection...) 
12 | * 文档补全,重要的类加上使用example 13 | -------------------------------------------------------------------------------- /python/jittor/ccl/__init__.py: -------------------------------------------------------------------------------- 1 | from .ccl_2d import ccl_2d 2 | from .ccl_3d import ccl_3d 3 | from .ccl_link import ccl_link -------------------------------------------------------------------------------- /python/jittor/compatibility/distributed.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from enum import Enum 3 | import jittor as jt 4 | 5 | 6 | class DistributedDataParallel: 7 | def __new__(cls, model): 8 | return model 9 | 10 | def is_initialized(): 11 | return True 12 | 13 | def get_rank(group=None): 14 | return 0 15 | 16 | def get_world_size(group=None): 17 | return 1 18 | 19 | def get_backend(group=None): 20 | return "nccl" 21 | 22 | def new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend=None, pg_options=None): 23 | return 1 24 | 25 | def barrier(): 26 | pass 27 | 28 | def is_available(): 29 | return True 30 | 31 | def is_built(): 32 | return True 33 | 34 | class ReduceOp: 35 | SUM = 0 36 | 37 | class GroupMember: 38 | WORLD = 0 39 | 40 | class ProcessGroup: 41 | pass 42 | 43 | class Join: 44 | pass 45 | 46 | dist_backend = Enum("dist_backend", ("GLOO", "MPI", "NCCL")) 47 | _backend = dist_backend.NCCL 48 | 49 | def is_mpi_available(): 50 | return jt.in_mpi 51 | 52 | def DistributedDataParallel(model, *args, **kw): 53 | return model 54 | -------------------------------------------------------------------------------- /python/jittor/compatibility/distributions.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 3 | class RelaxedBernoulli: 4 | def __init__(self, temperature, probs=None, logits=None): 5 | self.temperature = temperature 6 | self.probs = probs 7 | self.logits = logits 8 | 9 | def rsample(self): 10 | noise = 
jt.rand_like(self.logits) 11 | eps = 1e-20 12 | noise = jt.clamp(noise, eps, 1.0 - eps) 13 | logit_noise = jt.log(noise) - jt.log(1 - noise) 14 | sample = (self.logits + logit_noise) / self.temperature 15 | return jt.sigmoid(sample) 16 | -------------------------------------------------------------------------------- /python/jittor/compatibility/fft/__init__.py: -------------------------------------------------------------------------------- 1 | #TODO: Implement FFT and IFFT 2 | fftn = None 3 | fftshift = None 4 | ifftn = None 5 | ifftshift = None -------------------------------------------------------------------------------- /python/jittor/compatibility/fx.py: -------------------------------------------------------------------------------- 1 | class Proxy: 2 | pass -------------------------------------------------------------------------------- /python/jittor/compatibility/misc.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | def _jit_set_profiling_mode(x): pass 4 | def _jit_set_profiling_executor(x): pass 5 | def _jit_override_can_fuse_on_cpu(x): pass 6 | def _jit_override_can_fuse_on_gpu(x): pass 7 | 8 | def script(func): 9 | return func 10 | 11 | inf = math.inf 12 | nan = math.nan -------------------------------------------------------------------------------- /python/jittor/compatibility/nn/init.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 3 | for k,v in jt.nn.init.__dict__.items(): 4 | if callable(v): 5 | globals()[k] = v 6 | 7 | 8 | normal = gauss 9 | normal_ = gauss_ 10 | xavier_normal = xavier_gauss 11 | xavier_normal_ = xavier_gauss_ 12 | zeros_ = zero_ 13 | 14 | 15 | jt.Var.normal_ = normal_ 16 | 17 | -------------------------------------------------------------------------------- /python/jittor/compatibility/nn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import rnn -------------------------------------------------------------------------------- /python/jittor/compatibility/nn/utils/rnn.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 3 | PackedSequence = None 4 | 5 | def pad_sequence(sequences,batch_first=False,padding_value=0.0): 6 | max_f = max([len(s) for s in sequences]) 7 | # max_f = 512 8 | b = len(sequences) 9 | if batch_first: 10 | ret = sequences[0].new_full([b,max_f,]+list(sequences[0].shape[1:]),padding_value) 11 | for i,s in enumerate(sequences): 12 | ret[i,:len(s)] = s 13 | else: 14 | ret = sequences[0].new_full([max_f,b,]+list(sequences[0].shape[1:]),padding_value) 15 | for i,s in enumerate(sequences): 16 | ret[:len(s),i] = s 17 | # print(ret.shape) 18 | # ret = ret[:,:406] 19 | return ret 20 | -------------------------------------------------------------------------------- /python/jittor/compatibility/src/jtorch_core.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "common.h" 3 | #include "var_holder.h" 4 | #include "misc/fast_shared_ptr.h" 5 | 6 | namespace jittor { 7 | 8 | // @pyjt(device) 9 | // @attrs(heaptype) 10 | struct Device { 11 | string name; 12 | 13 | // @pyjt(__init__) 14 | Device(const string& name, int ordinal=0); 15 | // @pyjt(__get__type, __str__) 16 | inline string get_type() {return name;} 17 | // @pyjt(__get__index) 18 | inline int index() {return 0;} 19 | }; 20 | 21 | // @pyjt(backward) 22 | void backward(VarHolder* x); 23 | 24 | // @pyjt(grad_set) 25 | void grad_set(VarHolder* x, Maybe v); 26 | // @pyjt(grad_get) 27 | Maybe grad_get(VarHolder* x); 28 | // @pyjt(grad_del) 29 | void grad_del(VarHolder* x); 30 | 31 | // @pyjt(retain_grad_set) 32 | inline void retain_grad_set(VarHolder* x, bool v) { 33 | x->var->flags.set(NodeFlags::_th_require_grad, v); 34 | } 35 | // @pyjt(retain_grad_get) 36 | inline bool retain_grad_get(VarHolder* x) { 37 | return 
x->var->flags.get(NodeFlags::_th_require_grad); 38 | } 39 | 40 | } -------------------------------------------------------------------------------- /python/jittor/compatibility/test/test_conflict_func.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import torch 4 | import jittor as jt 5 | 6 | class TestConflictFunc(unittest.TestCase): 7 | def test_max(self): 8 | a = torch.Tensor([1,4,2]) 9 | assert a.max() == 4 10 | v, k = a.max(dim=0) 11 | assert v==4 and k==1 12 | 13 | def test_argsort(self): 14 | a = torch.Tensor([1,4,2]) 15 | k = a.argsort() 16 | assert jt.all_equal(k, [0,2,1]) 17 | 18 | with jt.flag_scope(th_mode=0): 19 | k, v = a.argsort() 20 | assert jt.all_equal(k, [0,2,1]) 21 | 22 | 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /python/jittor/compatibility/test/test_misc.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import torch 4 | 5 | class TestMisc(unittest.TestCase): 6 | def test_update_grad(self): 7 | class Net(torch.nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | self.a = torch.nn.Parameter(torch.Tensor([1.0, 2.0])) 11 | net = Net() 12 | assert(net.a.requires_grad) 13 | net.load_state_dict({"a": torch.Tensor([3.0, 4.0])}) 14 | assert(net.a.requires_grad) 15 | 16 | def test_reshape(self): 17 | a = torch.ones(3,3) 18 | a.requires_grad = True 19 | b = torch.reshape(a, [9]) 20 | assert b.requires_grad == True 21 | 22 | 23 | if __name__ == "__main__": 24 | unittest.main() 25 | -------------------------------------------------------------------------------- /python/jittor/compatibility/utils/__init__.py: -------------------------------------------------------------------------------- 1 | cpp_extension = None 2 | _flatten_dense_tensors = None 3 | _unflatten_dense_tensors = None 
4 | 5 | tensorboard = None -------------------------------------------------------------------------------- /python/jittor/compatibility/utils/_pytree.py: -------------------------------------------------------------------------------- 1 | #TODO: Implement this 2 | _register_pytree_node = None 3 | _dict_flatten = None -------------------------------------------------------------------------------- /python/jittor/compatibility/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | detach_variable = None 2 | 3 | 4 | def checkpoint( 5 | *args, 6 | **kwargs 7 | ): 8 | pass 9 | -------------------------------------------------------------------------------- /python/jittor/compatibility/utils/dtype.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union 2 | Dtype = Union[Callable, str] 3 | 4 | def get_string_dtype(dtype): 5 | if callable(dtype): 6 | dtype = dtype.__name__ 7 | if not isinstance(dtype, str): 8 | raise ValueError(f"dtype is expected to be str, python type function, or jittor type function, but got {dtype}.") 9 | return dtype -------------------------------------------------------------------------------- /python/jittor/compatibility/utils/hooks.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jittor/jittor/444154c2a1e63ee4a0af07831a0c54e2ebb7a561/python/jittor/compatibility/utils/hooks.py -------------------------------------------------------------------------------- /python/jittor/compatibility/utils/pip_publish.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import shutil 4 | import sys 5 | 6 | home_path = os.path.join(os.path.dirname(__file__), "..", "..", "..") 7 | home_path = os.path.abspath(home_path) 8 | 9 | def callback(func, path, exc_info): 10 | print(f"remove \"{path}\" failed.") 11 | 12 | 
def rmtree(path): 13 | if os.path.isdir(path): 14 | print(f"remove \"{path}\" recursive.") 15 | shutil.rmtree(path, onerror=callback) 16 | 17 | def remove_tmpfile(): 18 | dist_file = home_path+"/dist" 19 | egg_file = glob.glob(home_path+"/**/*egg-info") 20 | rmtree(dist_file) 21 | for e in egg_file: 22 | rmtree(e) 23 | 24 | def run_cmd(cmd): 25 | print("[CMD]", cmd) 26 | assert os.system(cmd)==0 27 | 28 | os.chdir(home_path) 29 | remove_tmpfile() 30 | 31 | run_cmd(f"{sys.executable} ./setup.py sdist") 32 | run_cmd(f"{sys.executable} -m twine upload dist/*") 33 | 34 | remove_tmpfile() -------------------------------------------------------------------------------- /python/jittor/compatibility/vision/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST 2 | 3 | __all__ = ( 4 | "EMNIST", 5 | "FashionMNIST", 6 | "QMNIST", 7 | "MNIST", 8 | "KMNIST", 9 | ) -------------------------------------------------------------------------------- /python/jittor/compatibility/vision/transforms.py: -------------------------------------------------------------------------------- 1 | from jittor.transform import * -------------------------------------------------------------------------------- /python/jittor/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .dataset import Dataset, ImageFolder, dataset_root, TensorDataset, VarDataset, DataLoader 3 | from .mnist import MNIST 4 | from .cifar import CIFAR10, CIFAR100 5 | from .voc import VOC 6 | from .sampler import * -------------------------------------------------------------------------------- /python/jittor/einops/__init__.py: -------------------------------------------------------------------------------- 1 | class EinopsError(RuntimeError): 2 | """ Runtime error thrown by einops """ 3 | pass 4 | 5 | 6 | __all__ = ['rearrange', 'reduce', 'repeat', 
'parse_shape', 'asnumpy', 'EinopsError'] 7 | 8 | from jittor.einops.einops import rearrange, reduce, repeat, parse_shape, asnumpy 9 | -------------------------------------------------------------------------------- /python/jittor/einops/experimental/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jittor/jittor/444154c2a1e63ee4a0af07831a0c54e2ebb7a561/python/jittor/einops/experimental/__init__.py -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jittor/jittor/444154c2a1e63ee4a0af07831a0c54e2ebb7a561/python/jittor/extern/acl/aclops/__init__.py -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/binary_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | 5 | namespace jittor 6 | { 7 | struct BinaryOpRunner : public BaseOpRunner 8 | { 9 | BinaryOpRunner(); 10 | 11 | protected: 12 | void executeOp(std::unordered_map::iterator &it) override; 13 | }; 14 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/bmm_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | 5 | namespace jittor 6 | { 7 | class BatchMatMulOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void setupInputDesc() override; 12 | void executeOp(std::unordered_map::iterator &it) override; 13 | 14 | public: 15 | BatchMatMulOpRunner(); 16 | }; 17 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/concat_op_acl.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Declares ConcatOpRunner and SplitWithSizeOpRunner; each overrides executeOp(). NOTE(review): std::unordered_map lacks template arguments in this copy and would not compile as shown. 5 | namespace jittor 6 | { 7 | class ConcatOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | ConcatOpRunner(); 15 | }; 16 | 17 | class SplitWithSizeOpRunner : public BaseOpRunner 18 | { 19 | 20 | protected: 21 | void executeOp(std::unordered_map::iterator &it) override; 22 | 23 | public: 24 | SplitWithSizeOpRunner(); 25 | }; 26 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/conv_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Conv2dBackwardOpRunner additionally overrides setupOutputDesc(); the forward runner does not. 5 | namespace jittor 6 | { 7 | class Conv2dOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | Conv2dOpRunner(); 15 | }; 16 | 17 | class Conv2dBackwardOpRunner : public BaseOpRunner 18 | { 19 | 20 | protected: 21 | void executeOp(std::unordered_map::iterator &it) override; 22 | void setupOutputDesc() override; 23 | 24 | public: 25 | Conv2dBackwardOpRunner(); 26 | }; 27 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/cumsum_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | 5 | namespace jittor 6 | { 7 | class CumsumOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | CumsumOpRunner(); 15 | }; 16 | 17 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/dropout_op_acl.h:
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Declares the forward and backward dropout runners. NOTE(review): std::unordered_map lacks template arguments in this copy and would not compile as shown. 5 | namespace jittor 6 | { 7 | class DropoutOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | DropoutOpRunner(); 15 | }; 16 | 17 | class DropoutBackwardOpRunner : public BaseOpRunner 18 | { 19 | 20 | protected: 21 | void executeOp(std::unordered_map::iterator &it) override; 22 | 23 | public: 24 | DropoutBackwardOpRunner(); 25 | }; 26 | 27 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/embedding_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Declares the forward and backward embedding runners. 5 | namespace jittor 6 | { 7 | class EmbeddingOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | public: 13 | EmbeddingOpRunner(); 14 | }; 15 | 16 | class EmbeddingBackwardOpRunner : public BaseOpRunner 17 | { 18 | 19 | protected: 20 | void executeOp(std::unordered_map::iterator &it) override; 21 | public: 22 | EmbeddingBackwardOpRunner(); 23 | }; 24 | 25 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/expand_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Declared as a struct: the constructor is public by default, executeOp() is protected. 5 | namespace jittor 6 | { 7 | struct ExpandOpRunner : public BaseOpRunner 8 | { 9 | ExpandOpRunner(); 10 | 11 | protected: 12 | void executeOp(std::unordered_map::iterator &it) override; 13 | }; 14 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/flashattention_op_acl.h:
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Declares the forward and backward flash-attention runners. NOTE(review): std::unordered_map lacks template arguments in this copy and would not compile as shown. 5 | namespace jittor 6 | { 7 | class FlashAttentionOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | FlashAttentionOpRunner(); 15 | }; 16 | 17 | class FlashAttentionBackwardOpRunner : public BaseOpRunner 18 | { 19 | 20 | protected: 21 | void executeOp(std::unordered_map::iterator &it) override; 22 | 23 | public: 24 | FlashAttentionBackwardOpRunner(); 25 | }; 26 | 27 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/flip_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | 5 | namespace jittor 6 | { 7 | class FlipOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | FlipOpRunner(); 15 | }; 16 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/floor_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | 5 | namespace jittor 6 | { 7 | class FloorOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | FloorOpRunner(); 15 | }; 16 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/gather_scatter_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Declares GatherOpRunner and ScatterOpRunner. 5 | namespace jittor 6 | { 7 | class GatherOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected:
11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | GatherOpRunner(); 15 | }; 16 | 17 | class ScatterOpRunner : public BaseOpRunner 18 | { 19 | 20 | protected: 21 | void executeOp(std::unordered_map::iterator &it) override; 22 | 23 | public: 24 | ScatterOpRunner(); 25 | }; 26 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/index_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // NOTE(review): this header is named index_op_acl.h but declares RangeOpRunner. 5 | namespace jittor 6 | { 7 | class RangeOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | RangeOpRunner(); 15 | }; 16 | 17 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/matmul_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // MatMulOpRunner overrides setupInputDesc() in addition to executeOp(). 5 | namespace jittor 6 | { 7 | class MatMulOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void setupInputDesc() override; 12 | void executeOp(std::unordered_map::iterator &it) override; 13 | 14 | public: 15 | MatMulOpRunner(); 16 | }; 17 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/nantonum_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | 5 | namespace jittor 6 | { 7 | class NanToNumOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | NanToNumOpRunner(); 15 | }; 16 | 17 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/norms_op_acl.h:
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Declares BatchNormOpRunner, BatchNormBackwardOpRunner and LayerNormOpRunner (no LayerNorm backward runner here). NOTE(review): std::unordered_map lacks template arguments in this copy. 5 | namespace jittor 6 | { 7 | class BatchNormOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | public: 13 | BatchNormOpRunner(); 14 | }; 15 | 16 | class BatchNormBackwardOpRunner : public BaseOpRunner 17 | { 18 | 19 | protected: 20 | void executeOp(std::unordered_map::iterator &it) override; 21 | public: 22 | BatchNormBackwardOpRunner(); 23 | }; 24 | 25 | class LayerNormOpRunner : public BaseOpRunner 26 | { 27 | 28 | protected: 29 | void executeOp(std::unordered_map::iterator &it) override; 30 | public: 31 | LayerNormOpRunner(); 32 | }; 33 | 34 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/pool_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Declares forward and backward runners for max pooling and average pooling. 5 | namespace jittor 6 | { 7 | class MaxpoolOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | MaxpoolOpRunner(); 15 | }; 16 | 17 | class AvgpoolOpRunner : public BaseOpRunner 18 | { 19 | 20 | protected: 21 | void executeOp(std::unordered_map::iterator &it) override; 22 | 23 | public: 24 | AvgpoolOpRunner(); 25 | }; 26 | 27 | class MaxpoolBackwardOpRunner : public BaseOpRunner 28 | { 29 | 30 | protected: 31 | void executeOp(std::unordered_map::iterator &it) override; 32 | 33 | public: 34 | MaxpoolBackwardOpRunner(); 35 | }; 36 | 37 | class AvgpoolBackwardOpRunner : public BaseOpRunner 38 | { 39 | 40 | protected: 41 | void executeOp(std::unordered_map::iterator &it) override; 42 | 43 | public: 44 | AvgpoolBackwardOpRunner(); 45 | }; 46 | } -------------------------------------------------------------------------------- 
/python/jittor/extern/acl/aclops/random_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // NOTE(review): RandomOpRunner(const string &) is not explicit, so it allows implicit conversion from string; std::unordered_map lacks template arguments in this copy. 5 | namespace jittor 6 | { 7 | class RandomOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | string name; // specific to the random op; set by the RandomOpRunner(const string &) constructor, presumably 12 | void executeOp(std::unordered_map::iterator &it) override; 13 | 14 | public: 15 | RandomOpRunner(); 16 | RandomOpRunner(const string &name); 17 | }; 18 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/reduce_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // NOTE(review): attr, dim and keepdims have no in-class initializers; confirm they are set before use. 5 | namespace jittor 6 | { 7 | struct ReduceOpRunner : public BaseOpRunner 8 | { 9 | int op_idx; // Specific to reduce operations 10 | 11 | ReduceOpRunner(); 12 | 13 | protected: 14 | ReduceAttr *attr; 15 | aclIntArray *dim; 16 | bool keepdims; 17 | 18 | void setupOutputDesc() override; 19 | void executeOp(std::unordered_map::iterator &it) override; 20 | }; 21 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/relu_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Declares the forward and backward LeakyReLU runners. 5 | namespace jittor 6 | { 7 | class LeakyReLUOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | LeakyReLUOpRunner(); 15 | }; 16 | 17 | class LeakyReLUBackwardOpRunner : public BaseOpRunner 18 | { 19 | 20 | protected: 21 | void executeOp(std::unordered_map::iterator &it) override; 22 | 23 | public: 24 | LeakyReLUBackwardOpRunner(); 25 | }; 26 | 27 | } -------------------------------------------------------------------------------- 
/python/jittor/extern/acl/aclops/rope_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Declares RotaryPosEmbOpRunner. NOTE(review): std::unordered_map lacks template arguments in this copy and would not compile as shown. 5 | namespace jittor 6 | { 7 | class RotaryPosEmbOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | RotaryPosEmbOpRunner(); 15 | }; 16 | 17 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/setitem_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Declares InplaceMaskedScatterOpRunner and IndexPutImplOpRunner. 5 | namespace jittor 6 | { 7 | class InplaceMaskedScatterOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | InplaceMaskedScatterOpRunner(); 15 | }; 16 | 17 | class IndexPutImplOpRunner : public BaseOpRunner 18 | { 19 | 20 | protected: 21 | void executeOp(std::unordered_map::iterator &it) override; 22 | 23 | public: 24 | IndexPutImplOpRunner(); 25 | }; 26 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/sigmoid_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Declares the forward and backward sigmoid runners. 5 | namespace jittor 6 | { 7 | class SigmoidOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | SigmoidOpRunner(); 15 | }; 16 | 17 | class SigmoidBackwardOpRunner : public BaseOpRunner 18 | { 19 | 20 | protected: 21 | void executeOp(std::unordered_map::iterator &it) override; 22 | 23 | public: 24 | SigmoidBackwardOpRunner(); 25 | }; 26 | 27 | } -------------------------------------------------------------------------------- 
/python/jittor/extern/acl/aclops/silu_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Declares the forward and backward SiLU runners. NOTE(review): std::unordered_map lacks template arguments in this copy and would not compile as shown. 5 | namespace jittor 6 | { 7 | class SiLUOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | SiLUOpRunner(); 15 | }; 16 | 17 | class SiLUBackwardOpRunner : public BaseOpRunner 18 | { 19 | 20 | protected: 21 | void executeOp(std::unordered_map::iterator &it) override; 22 | 23 | public: 24 | SiLUBackwardOpRunner(); 25 | }; 26 | 27 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/softmax_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | // Declares the forward and backward softmax runners. 5 | namespace jittor 6 | { 7 | class SoftmaxOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | SoftmaxOpRunner(); 15 | }; 16 | 17 | class SoftmaxBackwardOpRunner : public BaseOpRunner 18 | { 19 | 20 | protected: 21 | void executeOp(std::unordered_map::iterator &it) override; 22 | 23 | public: 24 | SoftmaxBackwardOpRunner(); 25 | }; 26 | 27 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/stack_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | 5 | namespace jittor 6 | { 7 | class StackOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | StackOpRunner(); 15 | }; 16 | 17 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/ternary_op_acl.h:
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | 5 | namespace jittor 6 | { /* TernaryOpRunner is declared with struct, so its constructor is public by default while executeOp is protected. NOTE(review): the unordered_map template arguments look stripped in this dump - confirm against the real header. */ 7 | struct TernaryOpRunner : public BaseOpRunner 8 | { 9 | TernaryOpRunner(); 10 | 11 | protected: 12 | void executeOp(std::unordered_map::iterator &it) override; 13 | }; 14 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/transpose_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | 5 | namespace jittor 6 | { /* BaseOpRunner subclass for the Transpose ACL op (declarations only). */ 7 | class TransposeOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | TransposeOpRunner(); 15 | }; 16 | 17 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/triu_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | 5 | namespace jittor 6 | { /* BaseOpRunner subclass for the Triu ACL op (declarations only). */ 7 | class TriuOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | public: 13 | TriuOpRunner(); 14 | }; 15 | 16 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/unary_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | 5 | namespace jittor 6 | { /* UnaryOpRunner is declared with struct, so its constructor is public by default while executeOp is protected. */ 7 | struct UnaryOpRunner : public BaseOpRunner 8 | { 9 | UnaryOpRunner(); 10 | 11 | protected: 12 | void executeOp(std::unordered_map::iterator &it) override; 13 | }; 14 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/utils.h:
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "misc/nano_string.h" 9 | /* NOTE(review): the include targets on lines 2-7 and the template arguments of unordered_map/vector below are missing in this dump (angle-bracket text appears stripped); confirm against the real header. */ 10 | namespace jittor 11 | { /* get_dtype converts a NanoString dtype to an aclDataType; only declared here. */ 12 | aclDataType get_dtype(NanoString s); 13 | 14 | extern std::unordered_map op_idx_map; 15 | /* Both tensor helpers return int (presumably a status code - confirm) and hand back the created tensor through the aclTensor** out-parameter; use_nchw defaults to false. NOTE(review): CreateFakeTransAclTensor takes shape by non-const reference, unlike CreateAclTensor, so it may modify the caller's vector - confirm. */ int CreateAclTensor(const std::vector &shape, void *deviceAddr, int64_t size, 16 | aclDataType dataType, aclTensor **tensor, bool use_nchw = false); 17 | 18 | int CreateFakeTransAclTensor(std::vector &shape, void *deviceAddr, int64_t size, 19 | aclDataType dataType, aclTensor **tensor, bool use_nchw = false); 20 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/aclops/where_op_acl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utils.h" 3 | #include "base_op.h" 4 | 5 | namespace jittor 6 | { /* BaseOpRunner subclass for the Where ACL op (declarations only). */ 7 | class WhereOpRunner : public BaseOpRunner 8 | { 9 | 10 | protected: 11 | void executeOp(std::unordered_map::iterator &it) override; 12 | 13 | public: 14 | WhereOpRunner(); 15 | }; 16 | /* BaseOpRunner subclass for the Nonzero ACL op (declarations only). */ 17 | class NonzeroOpRunner : public BaseOpRunner 18 | { 19 | 20 | protected: 21 | void executeOp(std::unordered_map::iterator &it) override; 22 | 23 | public: 24 | NonzeroOpRunner(); 25 | }; 26 | } -------------------------------------------------------------------------------- /python/jittor/extern/acl/hccl/inc/hccl_wrapper.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2025 Jittor. 3 | // All Rights Reserved. 4 | // Maintainers: 5 | // Jiapeng Zhang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package.
9 | // *************************************************************** 10 | 11 | #pragma once 12 | #include "mpi_wrapper.h" 13 | /* ACLCHECK / HCCLCHECK: when the return code is not ACL_SUCCESS / HCCL_SUCCESS, log it with LOGe (HCCLCHECK also logs HcclGetErrorString) and execute a bare return, so both macros are only usable inside functions returning void. NOTE(review): ACLCHECK's final while(0) line ends with a backslash, which continues the macro onto the following blank line; harmless now but fragile if that line ever gets content (HCCLCHECK has no such trailing backslash). */ 14 | #define ACLCHECK(ret) do {\ 15 | if(ret != ACL_SUCCESS)\ 16 | {\ 17 | LOGe << "retcode: " << ret;\ 18 | return;\ 19 | }\ 20 | } while(0)\ 21 | 22 | #define HCCLCHECK(ret) do {\ 23 | if(ret != HCCL_SUCCESS)\ 24 | {\ 25 | LOGe << HcclGetErrorString(ret) << " retcode: " << ret;\ 26 | return;\ 27 | }\ 28 | } while(0) 29 | 30 | #include 31 | /* NOTE(review): the include target on line 30 is missing in this dump. */ 32 | namespace jittor { 33 | /* HCCL globals declared via EXTERN_LIB; their definitions are not in this header. */ 34 | EXTERN_LIB HcclRootInfo root_info; 35 | EXTERN_LIB HcclComm comm; 36 | EXTERN_LIB uint32_t hccl_device_id; 37 | 38 | } // jittor 39 | -------------------------------------------------------------------------------- /python/jittor/extern/acl/hccl/ops/hccl_all_gather_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2025 Jittor. 3 | // All Rights Reserved. 4 | // Maintainers: 5 | // Jiapeng Zhang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package.
9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | namespace jittor { 14 | /* Op node named "hccl_all_gather": one input x, one output y; declares infer_shape, grad and jit_run. */ 15 | struct HcclAllGatherOp : Op { 16 | Var* x, * y; 17 | 18 | HcclAllGatherOp(Var* x); 19 | void infer_shape() override; 20 | 21 | const char* name() const override { return "hccl_all_gather"; } 22 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 23 | DECLARE_jit_run; 24 | }; 25 | 26 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/acl/hccl/ops/hccl_all_reduce_op.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "op.h" 3 | 4 | namespace jittor { 5 | /* Op node named "hccl_all_reduce": input x, output y; reduce_op selects the reduction and defaults to "sum". */ 6 | struct HcclAllReduceOp : Op { 7 | Var* x, * y; 8 | string reduce_op; 9 | 10 | HcclAllReduceOp(Var* x, string reduce_op="sum"); 11 | void infer_shape() override; 12 | 13 | const char* name() const override { return "hccl_all_reduce"; } 14 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 15 | DECLARE_jit_run; 16 | }; 17 | 18 | } // jittor 19 | -------------------------------------------------------------------------------- /python/jittor/extern/acl/hccl/ops/hccl_broadcast_op.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "op.h" 3 | 4 | namespace jittor { 5 | /* Op node named "hccl_broadcast": input x, output y; root defaults to 0. */ 6 | struct HcclBroadcastOp : Op { 7 | Var* x, * y; 8 | int root; 9 | 10 | HcclBroadcastOp(Var* x, int root=0); 11 | void infer_shape() override; 12 | 13 | const char* name() const override { return "hccl_broadcast"; } 14 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 15 | DECLARE_jit_run; 16 | }; 17 | 18 | } // jittor 19 | -------------------------------------------------------------------------------- /python/jittor/extern/acl/hccl/ops/hccl_reduce_op.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "op.h" 3 |
4 | namespace jittor { 5 | /* Op node named "hccl_reduce": input x, output y; reduce_op defaults to "sum" and root to 0. */ 6 | struct HcclReduceOp : Op { 7 | Var* x, * y; 8 | string reduce_op; 9 | int root; 10 | 11 | HcclReduceOp(Var* x, string reduce_op="sum", int root=0); 12 | void infer_shape() override; 13 | 14 | const char* name() const override { return "hccl_reduce"; } 15 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 16 | DECLARE_jit_run; 17 | }; 18 | 19 | } // jittor 20 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cub/ops/cub_arg_reduce_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | 14 | namespace jittor { 15 | /* Op node named "cub_arg_reduce": inputs x and offsets, two outputs y and y_key (the @attrs(multiple_outputs) marker below is kept as-is); op is a NanoString and keepdims a flag, both passed to the constructor. */ 16 | struct CubArgReduceOp : Op { 17 | Var* x, * offsets, * y, * y_key; 18 | NanoString op; 19 | bool keepdims; 20 | // @attrs(multiple_outputs) 21 | CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdims); 22 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 23 | void infer_shape() override; 24 | 25 | const char* name() const override { return "cub_arg_reduce"; } 26 | DECLARE_jit_run; 27 | }; 28 | 29 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cub/ops/cub_argsort_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com> 5 | // Dun Liang .
6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | 14 | namespace jittor { 15 | /* Op node named "cub_argsort": inputs x, indexes and offsets, two outputs y and y_key; descending defaults to false. The constructor's dtype argument (default ns_int32) is not stored as a member - presumably it is consumed when creating outputs; confirm in the .cc. */ 16 | struct CubArgsortOp : Op { 17 | Var* x, * indexes, * offsets, * y, * y_key; 18 | bool descending; 19 | // @attrs(multiple_outputs) 20 | CubArgsortOp(Var* x, Var* indexes, Var* offsets, bool descending=false, NanoString dtype=ns_int32); 21 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 22 | void infer_shape() override; 23 | 24 | const char* name() const override { return "cub_argsort"; } 25 | DECLARE_jit_run; 26 | }; 27 | 28 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cub/ops/cub_cumsum_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package.
9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | 14 | namespace jittor { 15 | /* Op node named "cub_cumsum": input x, output y; reverse defaults to false. */ 16 | struct CubCumsumOp : Op { 17 | Var* x, * y; 18 | bool reverse; 19 | 20 | CubCumsumOp(Var* x, bool reverse=false); 21 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 22 | 23 | void infer_shape() override; 24 | const char* name() const override { return "cub_cumsum"; } 25 | DECLARE_jit_run; 26 | }; 27 | 28 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cub/ops/cub_test_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "op.h" 9 | 10 | namespace jittor { 11 | /* Test op named "cub_test": holds a single output var and a command string; declares no infer_shape or grad. */ 12 | struct CubTestOp : Op { 13 | Var* output; 14 | string cmd; 15 | 16 | CubTestOp(string cmd); 17 | 18 | const char* name() const override { return "cub_test"; } 19 | DECLARE_jit_run; 20 | }; 21 | 22 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package.
9 | // *************************************************************** 10 | #pragma once 11 | #include 12 | #include 13 | /* NOTE(review): the two system include targets above are missing in this dump (angle-bracket text appears stripped). */ 14 | #include "utils/log.h" 15 | #include "helper_cuda.h" 16 | #include "fp16_emu.h" 17 | #include "common.h" 18 | #include "misc/nano_string.h" 19 | 20 | namespace jittor { 21 | /* Global cuBLAS handle, declared via EXTERN_LIB; defined in a separate source file. */ 22 | EXTERN_LIB cublasHandle_t cublas_handle; 23 | /* get_dtype maps float32/float64/float16 (plus bfloat16 when IS_ROCM is not defined) to the matching cudaDataType. Any other dtype is reported through LOGf; the trailing return CUDA_R_32F only matters if LOGf does not abort (presumably it is fatal - confirm). NOTE(review): the message has no separator between "not support type" and the printed dtype. */ 24 | static inline cudaDataType get_dtype(NanoString dtype) { 25 | if (dtype == ns_float32) return CUDA_R_32F; 26 | if (dtype == ns_float64) return CUDA_R_64F; 27 | if (dtype == ns_float16) return CUDA_R_16F; 28 | #ifndef IS_ROCM 29 | if (dtype == ns_bfloat16) return CUDA_R_16BF; 30 | #endif 31 | LOGf << "not support type" << dtype; 32 | return CUDA_R_32F; 33 | } 34 | 35 | } // jittor 36 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cublas/ops/cublas_acc_matmul_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package.
9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | namespace jittor { 14 | /* Op node named "cublas_acc_matmul": inputs a and b, output c, transpose flags; stride_a/stride_b default to -1 and offset_a/offset_b to 0. No grad override is declared for this op. */ 15 | struct CublasAccMatmulOp : Op { 16 | Var* a, * b, * c; 17 | bool trans_a, trans_b; 18 | int stride_a, stride_b; 19 | int offset_a, offset_b; 20 | CublasAccMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b, int stride_a=-1, int stride_b=-1, int offset_a=0, int offset_b=0); 21 | 22 | const char* name() const override { return "cublas_acc_matmul"; } 23 | void infer_shape() override; 24 | DECLARE_jit_run; 25 | }; 26 | 27 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Meng-Hao Guo 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package.
9 | // *************************************************************** 10 | 11 | 12 | // cublas_batched_matmul_op.h 13 | #pragma once 14 | #include "op.h" 15 | #include "ops/op_register.h" 16 | #include "var.h" 17 | 18 | namespace jittor { 19 | /* Op node named "cublas_batched_matmul": inputs a and b, output c, transpose flags; declares a grad override. */ 20 | struct CublasBatchedMatmulOp : Op { 21 | Var* a, * b, * c; 22 | bool trans_a, trans_b; 23 | CublasBatchedMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b); 24 | 25 | const char* name() const override { return "cublas_batched_matmul"; } 26 | void infer_shape() override; 27 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 28 | DECLARE_jit_run; 29 | }; 30 | 31 | } // jittor 32 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | namespace jittor { 14 | /* Op node named "cublas_matmul": inputs a and b, output c, transpose flags; no grad override is declared here. */ 15 | struct CublasMatmulOp : Op { 16 | Var* a, * b, * c; 17 | bool trans_a, trans_b; 18 | CublasMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b); 19 | 20 | const char* name() const override { return "cublas_matmul"; } 21 | void infer_shape() override; 22 | DECLARE_jit_run; 23 | }; 24 | 25 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cublas/ops/cublas_test_op.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor.
All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #include 8 | 9 | #include "var.h" 10 | #include "cublas_test_op.h" 11 | /* Test entry point; only declared here, defined elsewhere. */ 12 | int cublas_test_entry(int); 13 | 14 | namespace jittor { 15 | /* Non-JIT build: the constructor creates the output via create_output(1, ns_float32) and jit_prepare sets T to float32. JIT cpu build: jit_run asserts cublas_test_entry(size_mult) returns 0, then writes 123 into the first output element. */ 16 | #ifndef JIT 17 | CublasTestOp::CublasTestOp(int size_mult) : size_mult(size_mult) { 18 | output = create_output(1, ns_float32); 19 | } 20 | 21 | void CublasTestOp::jit_prepare(JK& jk) { 22 | jk << "«T:float32"; 23 | } 24 | 25 | #else // JIT 26 | #ifdef JIT_cpu 27 | void CublasTestOp::jit_run() { 28 | ASSERT(cublas_test_entry(size_mult)==0); 29 | /* NOTE(review): ptr has no template argument here and the include target on line 7 is empty; both appear stripped in this dump - confirm against the real source. */ output->ptr()[0] = 123; 30 | } 31 | #endif 32 | #endif // JIT 33 | 34 | } // jittor 35 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cublas/ops/cublas_test_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package.
6 | // *************************************************************** 7 | #pragma once 8 | #include "op.h" 9 | 10 | namespace jittor { 11 | /* Test op named "cublas_test": a single output var and an int size_mult passed to the constructor. */ 12 | struct CublasTestOp : Op { 13 | Var* output; 14 | int size_mult; 15 | 16 | CublasTestOp(int size_mult); 17 | 18 | const char* name() const override { return "cublas_test"; } 19 | DECLARE_jit_run; 20 | }; 21 | 22 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cublas/src/cublas_wrapper.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #include "cublas_wrapper.h" 11 | #include "misc/cuda_flags.h" 12 | 13 | namespace jittor { 14 | 15 | cublasHandle_t cublas_handle; 16 | /* Namespace-scope object (init) of this type: its constructor creates cublas_handle and its destructor destroys it, both skipped when get_device_count() returns 0. */ 17 | struct cublas_initer { 18 | 19 | inline cublas_initer() { 20 | if (!get_device_count()) return; 21 | checkCudaErrors(cublasCreate(&cublas_handle)); 22 | LOGv << "cublasCreate finished" << (void*)cublas_handle; 23 | } 24 | 25 | inline ~cublas_initer() { 26 | if (!get_device_count()) return; 27 | LOGv << "cublasDestroy:" << (void*)cublas_handle; 28 | checkCudaErrors(cublasDestroy(cublas_handle)); 29 | LOGv << "cublasDestroy finished"; 30 | } 31 | 32 | } init; 33 | 34 | } // jittor 35 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved.
3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "op.h" 9 | 10 | namespace jittor { 11 | /* Op node named "cudnn_conv3d": inputs x and w, output y; per-axis (d, h, w) stride, padding and dilation, plus groups. Dilations and groups default to 1 and xformat defaults to "ncdhw". */ 12 | struct CudnnConv3dOp : Op { 13 | Var* x, * w, * y; 14 | int strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups; 15 | string xformat; 16 | CudnnConv3dOp(Var* x, Var* w, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd=1, int dilationh=1, int dilationw=1, int groups=1, string xformat="ncdhw"); 17 | 18 | const char* name() const override { return "cudnn_conv3d"; } 19 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 20 | void infer_shape() override; 21 | DECLARE_jit_run; 22 | }; 23 | 24 | } // jittor 25 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Dun Liang 5 | // Guowei Yang <471184555@qq.com> 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package.
9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | namespace jittor { 14 | /* Op node named "cudnn_conv_backward_w": members x, dy and dw, kernel size kh/kw, stride/padding/dilation, groups (default 1) and layout strings (defaults "abcd", "oihw", "abcd"). NOTE(review): the constructor's second parameter is named y while the member is dy - presumably the same Var; confirm in the .cc. */ 15 | struct CudnnConvBackwardWOp : Op { 16 | Var* x, * dy, * dw; 17 | int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; 18 | string xformat, wformat, yformat; 19 | 20 | CudnnConvBackwardWOp(Var* x, Var* y, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); 21 | 22 | const char* name() const override { return "cudnn_conv_backward_w"; } 23 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 24 | void infer_shape() override; 25 | DECLARE_jit_run; 26 | }; 27 | 28 | } // jittor 29 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Dun Liang 5 | // Guowei Yang <471184555@qq.com> 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package.
9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | namespace jittor { 14 | /* Op node named "cudnn_conv_backward_x": members w, dy and dx, target size xh/xw, stride/padding/dilation, groups (default 1) and layout strings (defaults "abcd", "oihw", "abcd"). NOTE(review): the constructor parameters y, height and width differ in name from the members dy, xh and xw - presumably mapped one-to-one in the .cc; confirm. */ 15 | struct CudnnConvBackwardXOp : Op { 16 | Var* w, * dy, * dx; 17 | int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; 18 | string xformat, wformat, yformat; 19 | 20 | CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); 21 | 22 | const char* name() const override { return "cudnn_conv_backward_x"; } 23 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 24 | void infer_shape() override; 25 | DECLARE_jit_run; 26 | }; 27 | 28 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package.
6 | // *************************************************************** 7 | #pragma once 8 | #include "op.h" 9 | 10 | namespace jittor { 11 | /* Op node named "cudnn_conv": inputs x and w, output y; stride/padding per axis, dilations and groups default to 1; xformat defaults to "abcd", wformat to "oihw" and yformat to an empty string. */ 12 | struct CudnnConvOp : Op { 13 | Var* x, * w, * y; 14 | int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; 15 | string xformat, wformat, yformat; 16 | /* CudnnConvOp: xformat abcd represents nchw */ 17 | CudnnConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh=1, int dilationw=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat=""); 18 | 19 | const char* name() const override { return "cudnn_conv"; } 20 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 21 | void infer_shape() override; 22 | DECLARE_jit_run; 23 | }; 24 | 25 | } // jittor 26 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cudnn/ops/cudnn_test_op.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package.
6 | // *************************************************************** 7 | #include 8 | 9 | #include "var.h" 10 | #include "cudnn_test_op.h" 11 | #include "utils/str_utils.h" 12 | 13 | int cudnn_test_entry( int argc, char** argv ); 14 | 15 | namespace jittor { 16 | /* Non-JIT build: the constructor moves cmd into the member and creates the output via create_output(1, ns_float32); jit_prepare sets T to float32. */ 17 | #ifndef JIT 18 | CudnnTestOp::CudnnTestOp(string cmd) : cmd(move(cmd)) { 19 | output = create_output(1, ns_float32); 20 | } 21 | 22 | void CudnnTestOp::jit_prepare(JK& jk) { 23 | jk << "«T:float32"; 24 | } 25 | 26 | #else // JIT 27 | #ifdef JIT_cpu 28 | void CudnnTestOp::jit_run() { 29 | /* Split cmd on spaces; an empty cmd yields no arguments. NOTE(review): pseudo-lines 32-35 are mangled in this dump (text between angle brackets was lost, and numbering jumps from 32 to 36), so the argument loop and the following statements cannot be read here - consult the real source. */ auto args = split(cmd, " "); 30 | if (!cmd.size()) args.clear(); 31 | vector v(args.size()); 32 | for (uint i=0; iptr()[0] = 123; 36 | } 37 | #endif 38 | #endif // JIT 39 | 40 | } // jittor 41 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cudnn/ops/cudnn_test_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "op.h" 9 | 10 | namespace jittor { 11 | /* Test op named "cudnn_test": a single output var and a command string. */ 12 | struct CudnnTestOp : Op { 13 | Var* output; 14 | string cmd; 15 | CudnnTestOp(string cmd); 16 | 17 | const char* name() const override { return "cudnn_test"; } 18 | DECLARE_jit_run; 19 | }; 20 | 21 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cudnn/src/cudnn_wrapper.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang .
4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #include "cudnn_wrapper.h" 8 | #include "misc/cuda_flags.h" 9 | 10 | namespace jittor { 11 | 12 | cudnnHandle_t cudnn_handle; 13 | int max_cache_size = 100; 14 | float max_workspace_ratio = 0.25; 15 | /* Setters for the two globals above. NOTE(review): set_max_workspace_ratio takes a float64 but stores it into a float, so the value is narrowed. */ 16 | void set_algorithm_cache_size(int size) { 17 | max_cache_size = size; 18 | } 19 | 20 | void set_max_workspace_ratio(float64 ratio) { 21 | max_workspace_ratio = ratio; 22 | } 23 | /* Namespace-scope object (init) of this type: its constructor creates cudnn_handle and its destructor destroys it, both skipped when get_device_count() returns 0. */ 24 | struct cudnn_initer { 25 | 26 | inline cudnn_initer() { 27 | if (!get_device_count()) return; 28 | checkCudaErrors(cudnnCreate(&cudnn_handle)); 29 | LOGv << "cudnnCreate finished"; 30 | } 31 | 32 | inline ~cudnn_initer() { 33 | if (!get_device_count()) return; 34 | checkCudaErrors(cudnnDestroy(cudnn_handle)); 35 | LOGv << "cudnnDestroy finished"; 36 | } 37 | 38 | } init; 39 | 40 | } // jittor 41 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cudnn/src/helper_cudnn.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "utils/log.h" 3 | #include "helper_cuda.h" 4 | /* Returns the cuDNN error string for a cudnnStatus_t. NOTE(review): the include target on line 1 is missing in this dump. */ 5 | const char *_cudaGetErrorEnum(cudnnStatus_t error) { 6 | return cudnnGetErrorString(error); 7 | } -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cufft/inc/cufft_wrapper.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com>. 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package.
9 | // *************************************************************** 10 | #pragma once 11 | #include 12 | #include 13 | #include "cufft_utils.h" 14 | 15 | #include "utils/log.h" 16 | #include "helper_cuda.h" 17 | #include "fp16_emu.h" 18 | #include "common.h" 19 | 20 | namespace jittor { 21 | /* Handle cache declared via EXTERN_LIB and defined in a separate source file. NOTE(review): the unordered_map template arguments and the include targets on lines 11-12 are missing in this dump. */ 22 | EXTERN_LIB unordered_map cufft_handle_cache; 23 | 24 | } // jittor 25 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cufft/ops/cufft_fft_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com>. 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | namespace jittor { 14 | /* Op node named "cufft_fft": input x, output y, an inverse flag (default false) and a NanoString type member that the constructor does not take as a parameter. */ 15 | //TODO: support FFT2D only now. 16 | struct CufftFftOp : Op { 17 | bool inverse; 18 | Var* x, * y; 19 | NanoString type; 20 | CufftFftOp(Var* x, bool inverse=false); 21 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 22 | 23 | const char* name() const override { return "cufft_fft"; } 24 | DECLARE_jit_run; 25 | }; 26 | 27 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cufft/src/cufft_wrapper.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com>. 5 | // Dun Liang .
6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #include "cufft_wrapper.h" 11 | #include "misc/cuda_flags.h" 12 | 13 | namespace jittor { 14 | 15 | unordered_map cufft_handle_cache; 16 | /* Namespace-scope object (init) of this type. The constructor creates nothing and only logs; cache entries are presumably filled elsewhere - confirm. The destructor calls cufftDestroy on every cached handle and then clears the cache. Both are skipped when get_device_count() returns 0. NOTE(review): the "cufftCreate finished" log message does not correspond to any cufftCreate call in this constructor. */ 17 | struct cufft_initer { 18 | 19 | inline cufft_initer() { 20 | if (!get_device_count()) return; 21 | LOGv << "cufftCreate finished"; 22 | } 23 | 24 | inline ~cufft_initer() { 25 | if (!get_device_count()) return; 26 | for (auto it = cufft_handle_cache.begin(); it != cufft_handle_cache.end(); it++) { 27 | CUFFT_CALL(cufftDestroy(it->second)); 28 | } 29 | cufft_handle_cache.clear(); 30 | LOGv << "cufftDestroy finished"; 31 | } 32 | 33 | } init; 34 | 35 | } // jittor 36 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/curand/inc/curand_wrapper.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package.
9 | // *************************************************************** 10 | #pragma once 11 | #include 12 | #include 13 | 14 | #include "helper_cuda.h" 15 | #include "fp16_emu.h" 16 | #include "common.h" 17 | 18 | namespace jittor { 19 | 20 | EXTERN_LIB curandGenerator_t gen; 21 | 22 | } // jittor 23 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/curand/ops/curand_random_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com>. 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | namespace jittor { 14 | 15 | struct CurandRandomOp : Op { 16 | Var* output; 17 | NanoString type; 18 | CurandRandomOp(NanoVector shape, NanoString dtype=ns_float32, NanoString type=ns_uniform); 19 | 20 | const char* name() const override { return "curand_random"; } 21 | DECLARE_jit_run; 22 | }; 23 | 24 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/curand/src/curand_wrapper.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 
9 | // *************************************************************** 10 | #include "curand_wrapper.h" 11 | #include "init.h" 12 | #include "misc/cuda_flags.h" 13 | 14 | namespace jittor { 15 | 16 | curandGenerator_t gen; 17 | 18 | struct curand_initer { 19 | 20 | inline curand_initer() { 21 | if (!get_device_count()) return; 22 | checkCudaErrors( curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT) ); 23 | add_set_seed_callback([](int seed) { 24 | checkCudaErrors( curandSetPseudoRandomGeneratorSeed(gen, seed) ); 25 | }); 26 | LOGv << "curandCreate finished"; 27 | } 28 | 29 | inline ~curand_initer() { 30 | if (!get_device_count()) return; 31 | checkCudaErrors( curandDestroyGenerator(gen) ); 32 | LOGv << "curandDestroy finished"; 33 | } 34 | 35 | } init_; 36 | 37 | } // jittor 38 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cusparse/ops/cusparse_spmmcoo_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Shizhan Lu <578752274@qq.com>. 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #pragma once 8 | #include "op.h" 9 | #include "cusparse.h" 10 | namespace jittor { 11 | 12 | struct CusparseSpmmcooOp : Op { 13 | Var* x; 14 | Var* outputVar; 15 | Var* row_indices; 16 | Var* col_indices; 17 | Var* value; 18 | Var* output; 19 | int A_row; 20 | int A_col; 21 | bool trans_A; 22 | bool trans_B; 23 | CusparseSpmmcooOp(Var* outputVar_, Var* x_, Var* row_indices_,Var* col_indices_,Var* value_,int A_row,int A_col,bool trans_A,bool trans_B); 24 | const char* name() const override { return "cusparse_spmmcoo"; } 25 | DECLARE_jit_run; 26 | }; 27 | 28 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cusparse/ops/cusparse_spmmcsr_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Shizhan Lu <578752274@qq.com>. 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #pragma once 8 | #include "op.h" 9 | #include "cusparse.h" 10 | namespace jittor { 11 | 12 | struct CusparseSpmmcsrOp : Op { 13 | Var* x; 14 | Var* outputVar; 15 | Var* col_indices; 16 | Var* row_offset; 17 | Var* value; 18 | Var* output; 19 | int A_row; 20 | int A_col; 21 | bool trans_A; 22 | bool trans_B; 23 | CusparseSpmmcsrOp(Var* outputVar_, Var* x_, Var* col_indices_,Var* value_,Var* row_offset_,int A_row,int A_col,bool trans_A,bool trans_B); 24 | const char* name() const override { return "cusparse_spmmcsr"; } 25 | DECLARE_jit_run; 26 | }; 27 | 28 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cusparse/src/cusparse_wrapper.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Shizhan Lu <578752274@qq.com>. 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #include "cusparse_wrapper.h" 8 | #include "misc/cuda_flags.h" 9 | 10 | namespace jittor { 11 | 12 | cusparseHandle_t cusparse_handle; 13 | 14 | struct cusparse_initer { 15 | 16 | inline cusparse_initer() { 17 | if (!get_device_count()) return; 18 | checkCudaErrors(cusparseCreate(&cusparse_handle)); 19 | LOGv << "cusparseCreate finished" << (void*)cusparse_handle; 20 | } 21 | 22 | inline ~cusparse_initer() { 23 | if (!get_device_count()) return; 24 | LOGv << "cusparseDestroy:" << (void*)cusparse_handle; 25 | checkCudaErrors(cusparseDestroy(cusparse_handle)); 26 | LOGv << "cusparseDestroy finished"; 27 | } 28 | 29 | } init; 30 | 31 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cutt/ops/cutt_test_op.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2019 Dun Liang . All Rights Reserved. 3 | // This file is subject to the terms and conditions defined in 4 | // file 'LICENSE.txt', which is part of this source code package. 
5 | // *************************************************************** 6 | #include "var.h" 7 | #include "cutt_test_op.h" 8 | #include "utils/str_utils.h" 9 | 10 | #ifdef JIT 11 | #include "cutt.h" 12 | #endif 13 | 14 | namespace jittor { 15 | 16 | #ifndef JIT 17 | CuttTestOp::CuttTestOp(string cmd) : cmd(cmd) { 18 | flags.set(NodeFlags::_cpu, 0); 19 | flags.set(NodeFlags::_cuda, 1); 20 | output = create_output(1, ns_float32); 21 | } 22 | 23 | void CuttTestOp::jit_prepare(JK& jk) { 24 | jk << "«T:float32"; 25 | } 26 | 27 | #else // JIT 28 | #ifdef JIT_cuda 29 | 30 | void CuttTestOp::jit_run() { 31 | auto args = split(cmd, " "); 32 | if (!cmd.size()) args.clear(); 33 | vector v(args.size()); 34 | for (uint i=0; iptr()[0] = 123; 37 | 38 | } 39 | #endif 40 | #endif // JIT 41 | 42 | } // jittor 43 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cutt/ops/cutt_test_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2019 3 | // Guoye Yang <498731903@qq.com> 4 | // Dun Liang . 5 | // All Rights Reserved. 6 | // This file is subject to the terms and conditions defined in 7 | // file 'LICENSE.txt', which is part of this source code package. 
8 | // *************************************************************** 9 | #pragma once 10 | #include "op.h" 11 | 12 | namespace jittor { 13 | 14 | struct CuttTestOp : Op { 15 | Var* output; 16 | string cmd; 17 | 18 | CuttTestOp(string cmd); 19 | 20 | const char* name() const override { return "cutt_test"; } 21 | DECLARE_jit_run; 22 | }; 23 | 24 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cutt/ops/cutt_transpose_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2019 3 | // Guoye Yang <498731903@qq.com> 4 | // Dun Liang . 5 | // All Rights Reserved. 6 | // This file is subject to the terms and conditions defined in 7 | // file 'LICENSE.txt', which is part of this source code package. 8 | // *************************************************************** 9 | #pragma once 10 | #include "op.h" 11 | 12 | namespace jittor { 13 | 14 | struct CuttTransposeOp : Op { 15 | Var* x, * y; 16 | NanoVector axes; 17 | CuttTransposeOp(Var* x, NanoVector axes=NanoVector()); 18 | 19 | const char* name() const override { return "cutt_transpose"; } 20 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 21 | void infer_shape() override; 22 | DECLARE_jit_run; 23 | }; 24 | 25 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cutt/ops/cutt_wrapper.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2019 3 | // Dun Liang 4 | // Guowei Yang <471184555@qq.com> 5 | // All Rights Reserved. 6 | // This file is subject to the terms and conditions defined in 7 | // file 'LICENSE.txt', which is part of this source code package. 
8 | // *************************************************************** 9 | #include "cutt_wrapper.h" 10 | 11 | 12 | namespace jittor { 13 | 14 | void jt_alloc(void** p, size_t len, size_t& allocation) { 15 | *p = exe.allocator->alloc(len, allocation); 16 | } 17 | 18 | void jt_free(void* p, size_t len, size_t& allocation) { 19 | exe.allocator->free(p, len, allocation); 20 | } 21 | 22 | struct cutt_initer { 23 | 24 | inline cutt_initer() { 25 | custom_cuda_malloc = jt_alloc; 26 | custom_cuda_free = jt_free; 27 | LOGv << "cuttCreate finished"; 28 | } 29 | 30 | inline ~cutt_initer() { 31 | LOGv << "cuttDestroy finished"; 32 | } 33 | 34 | } cutt_init; 35 | 36 | } // jittor 37 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/cutt/ops/cutt_wrapper.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2019 3 | // Dun Liang 4 | // Guowei Yang <471184555@qq.com> 5 | // All Rights Reserved. 6 | // This file is subject to the terms and conditions defined in 7 | // file 'LICENSE.txt', which is part of this source code package. 8 | // *************************************************************** 9 | #pragma once 10 | #include "executor.h" 11 | #include "CudaUtils.h" 12 | 13 | void jt_alloc(void** p, size_t len, size_t& allocation); 14 | 15 | void jt_free(void* p, size_t len, size_t& allocation); -------------------------------------------------------------------------------- /python/jittor/extern/cuda/inc/fp16_dev.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2014 NVIDIA Corporation. All rights reserved. 3 | * 4 | * Please refer to the NVIDIA end user license agreement (EULA) associated 5 | * with this source code for terms and conditions that govern your use of 6 | * this software. 
Any use, reproduction, disclosure, or distribution of 7 | * this software and related documentation outside the terms of the EULA 8 | * is strictly prohibited. 9 | * 10 | */ 11 | 12 | #if !defined(_FP16_DEV_H_) 13 | #define _FP16_DEV_H_ 14 | 15 | #include "fp16_emu.h" 16 | 17 | template 18 | void gpu_float2half_rn(int size, const value_type *buffIn, half1 *buffOut); 19 | 20 | #endif // _FP16_DEV_H_ 21 | 22 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/nccl/inc/nccl_wrapper.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. 3 | // All Rights Reserved. 4 | // Maintainers: 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #pragma once 11 | #include "mpi_wrapper.h" 12 | 13 | #include 14 | #include 15 | #include "utils/log.h" 16 | #include "helper_cuda.h" 17 | 18 | namespace jittor { 19 | 20 | EXTERN_LIB ncclComm_t comm; 21 | EXTERN_LIB ncclUniqueId id; 22 | EXTERN_LIB int nccl_device_id; 23 | 24 | } // jittor 25 | -------------------------------------------------------------------------------- /python/jittor/extern/cuda/nccl/ops/nccl_all_gather_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Guoye Yang <498731903@qq.com>. 4 | // Dun Liang . 5 | // All Rights Reserved. 6 | // This file is subject to the terms and conditions defined in 7 | // file 'LICENSE.txt', which is part of this source code package. 
8 | // *************************************************************** 9 | #pragma once 10 | #include "op.h" 11 | 12 | namespace jittor { 13 | 14 | struct NcclAllGatherOp : Op { 15 | Var* x, * y; 16 | 17 | NcclAllGatherOp(Var* x); 18 | void infer_shape() override; 19 | 20 | const char* name() const override { return "nccl_all_gather"; } 21 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 22 | DECLARE_jit_run; 23 | }; 24 | 25 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/nccl/ops/nccl_all_reduce_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Guoye Yang <498731903@qq.com>. 4 | // Dun Liang . 5 | // All Rights Reserved. 6 | // This file is subject to the terms and conditions defined in 7 | // file 'LICENSE.txt', which is part of this source code package. 8 | // *************************************************************** 9 | #pragma once 10 | #include "op.h" 11 | 12 | namespace jittor { 13 | 14 | struct NcclAllReduceOp : Op { 15 | Var* x, * y; 16 | 17 | NcclAllReduceOp(Var* x); 18 | void infer_shape() override; 19 | 20 | const char* name() const override { return "nccl_all_reduce"; } 21 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 22 | DECLARE_jit_run; 23 | }; 24 | 25 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/nccl/ops/nccl_broadcast_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Guoye Yang <498731903@qq.com>. 4 | // Dun Liang . 5 | // All Rights Reserved. 
6 | // This file is subject to the terms and conditions defined in 7 | // file 'LICENSE.txt', which is part of this source code package. 8 | // *************************************************************** 9 | #pragma once 10 | #include "op.h" 11 | 12 | namespace jittor { 13 | 14 | struct NcclBroadcastOp : Op { 15 | Var* x, * y; 16 | int root; 17 | 18 | NcclBroadcastOp(Var* x, int root=0); 19 | void infer_shape() override; 20 | 21 | const char* name() const override { return "nccl_broadcast"; } 22 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 23 | DECLARE_jit_run; 24 | }; 25 | 26 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/nccl/ops/nccl_reduce_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Guoye Yang <498731903@qq.com>. 4 | // Dun Liang . 5 | // All Rights Reserved. 6 | // This file is subject to the terms and conditions defined in 7 | // file 'LICENSE.txt', which is part of this source code package. 8 | // *************************************************************** 9 | #pragma once 10 | #include "op.h" 11 | 12 | namespace jittor { 13 | 14 | struct NcclReduceOp : Op { 15 | Var* x, * y; 16 | int root; 17 | 18 | NcclReduceOp(Var* x, int root=0); 19 | void infer_shape() override; 20 | 21 | const char* name() const override { return "nccl_reduce"; } 22 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 23 | DECLARE_jit_run; 24 | }; 25 | 26 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/cuda/nccl/ops/nccl_test_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. 
3 | // All Rights Reserved. 4 | // Maintainers: 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | namespace jittor { 14 | 15 | struct NcclTestOp : Op { 16 | Var* output; 17 | string cmd; 18 | 19 | NcclTestOp(string cmd); 20 | 21 | const char* name() const override { return "nccl_test"; } 22 | DECLARE_jit_run; 23 | }; 24 | 25 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/mkl/ops/mkl_conv_backward_w_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guowei Yang <471184555@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 
9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | namespace jittor { 14 | 15 | struct MklConvBackwardWOp : Op { 16 | Var* x, * dy, * dw; 17 | int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; 18 | string xformat, wformat, yformat; 19 | 20 | MklConvBackwardWOp(Var* x, Var* y, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); 21 | 22 | const char* name() const override { return "mkl_conv_backward_w"; } 23 | void infer_shape() override; 24 | DECLARE_jit_run; 25 | }; 26 | 27 | } // jittor 28 | -------------------------------------------------------------------------------- /python/jittor/extern/mkl/ops/mkl_conv_backward_x_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guowei Yang <471184555@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 
9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | namespace jittor { 14 | 15 | struct MklConvBackwardXOp : Op { 16 | Var* w, * dy, * dx; 17 | int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; 18 | string xformat, wformat, yformat; 19 | 20 | MklConvBackwardXOp(Var* w, Var* y, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); 21 | 22 | const char* name() const override { return "mkl_conv_backward_x"; } 23 | void infer_shape() override; 24 | DECLARE_jit_run; 25 | }; 26 | 27 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/mkl/ops/mkl_conv_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guowei Yang <471184555@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 
9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | namespace jittor { 14 | 15 | struct MklConvOp : Op { 16 | Var* x, * w, * y; 17 | int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; 18 | string xformat, wformat, yformat; 19 | /* MklConvOp: xformat abcd represents nchw */ 20 | MklConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh=1, int dilationw=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat=""); 21 | 22 | const char* name() const override { return "mkl_conv"; } 23 | void infer_shape() override; 24 | DECLARE_jit_run; 25 | }; 26 | 27 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/mkl/ops/mkl_matmul_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 
9 | // *************************************************************** 10 | #pragma once 11 | #include "op.h" 12 | 13 | namespace jittor { 14 | 15 | struct MklMatmulOp : Op { 16 | Var* a, * b, * c; 17 | bool trans_a, trans_b; 18 | MklMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b); 19 | 20 | const char* name() const override { return "mkl_matmul"; } 21 | void infer_shape() override; 22 | DECLARE_jit_run; 23 | }; 24 | 25 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/mkl/ops/mkl_test_op.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #include 8 | 9 | #include "var.h" 10 | #include "mkl_test_op.h" 11 | 12 | int mkl_test_entry(); 13 | 14 | namespace jittor { 15 | 16 | #ifndef JIT 17 | MklTestOp::MklTestOp() { 18 | output = create_output(1, ns_float32); 19 | } 20 | 21 | void MklTestOp::jit_prepare(JK& jk) { 22 | jk << "«T:float32"; 23 | } 24 | 25 | #else // JIT 26 | #ifdef JIT_cpu 27 | void MklTestOp::jit_run() { 28 | ASSERT(mkl_test_entry()==0); 29 | output->ptr()[0] = 123; 30 | } 31 | #endif 32 | #endif // JIT 33 | 34 | } // jittor 35 | -------------------------------------------------------------------------------- /python/jittor/extern/mkl/ops/mkl_test_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 
4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "op.h" 9 | 10 | namespace jittor { 11 | 12 | struct MklTestOp : Op { 13 | Var* output; 14 | MklTestOp(); 15 | 16 | const char* name() const override { return "mkl_test"; } 17 | DECLARE_jit_run; 18 | }; 19 | 20 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/mpi/ops/mpi_all_reduce_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Guowei Yang <471184555@qq.com>. 4 | // Dun Liang . 5 | // All Rights Reserved. 6 | // This file is subject to the terms and conditions defined in 7 | // file 'LICENSE.txt', which is part of this source code package. 8 | // *************************************************************** 9 | #pragma once 10 | #include "op.h" 11 | 12 | namespace jittor { 13 | 14 | struct MpiAllReduceOp : Op { 15 | Var* x, * y; 16 | NanoString op; 17 | 18 | /** 19 | 20 | Mpi All Reduce Operator uses the operator [op] to reduce variable [x] in all MPI nodes and broadcast to all MPI nodes. 21 | 22 | Args: 23 | 24 | * x: variable to be all reduced. 25 | * op: 'sum' or 'add' means sum all [x], 'mean' means average all [x]. Default: 'add'. 
26 | */ 27 | MpiAllReduceOp(Var* x, NanoString op=ns_add); 28 | void infer_shape() override; 29 | 30 | const char* name() const override { return "mpi_all_reduce"; } 31 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 32 | DECLARE_jit_run; 33 | }; 34 | 35 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/mpi/ops/mpi_broadcast_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Guowei Yang <471184555@qq.com>. 4 | // Dun Liang . 5 | // All Rights Reserved. 6 | // This file is subject to the terms and conditions defined in 7 | // file 'LICENSE.txt', which is part of this source code package. 8 | // *************************************************************** 9 | #pragma once 10 | #include "op.h" 11 | 12 | namespace jittor { 13 | 14 | struct MpiBroadcastOp : Op { 15 | Var* x, * y; 16 | int root; 17 | 18 | /** 19 | 20 | Mpi Broadcast Operator broadcasts variable [x] in [root] MPI nodes to all MPI nodes. 21 | 22 | Args: 23 | 24 | * x: variable to be broadcasted. 25 | * root: ID of MPI node to be broadcasted. Default: 0. 26 | */ 27 | MpiBroadcastOp(Var* x, int root=0); 28 | void infer_shape() override; 29 | 30 | const char* name() const override { return "mpi_broadcast"; } 31 | VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; 32 | DECLARE_jit_run; 33 | }; 34 | 35 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/mpi/ops/mpi_test_op.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2019 Dun Liang . All Rights Reserved. 
3 | // This file is subject to the terms and conditions defined in 4 | // file 'LICENSE.txt', which is part of this source code package. 5 | // *************************************************************** 6 | #include "mpi_wrapper.h" 7 | 8 | #include "var.h" 9 | #include "mpi_test_op.h" 10 | #include "utils/str_utils.h" 11 | 12 | namespace jittor { 13 | 14 | #ifndef JIT 15 | MpiTestOp::MpiTestOp(string cmd) : cmd(cmd) { 16 | output = create_output(1, ns_float32); 17 | } 18 | 19 | void MpiTestOp::jit_prepare(JK& jk) { 20 | jk << "«T:float32"; 21 | } 22 | 23 | #else // JIT 24 | 25 | void MpiTestOp::jit_run() { 26 | output->ptr()[0] = 123; 27 | 28 | int world_size = mpi_world_size; 29 | 30 | int world_rank = mpi_world_rank; 31 | 32 | char processor_name[MPI_MAX_PROCESSOR_NAME]; 33 | int name_len; 34 | MPI_CHECK(MPI_Get_processor_name(processor_name, &name_len)); 35 | 36 | printf("Hello world from processor %s, rank %d out of %d processors\\n",processor_name, world_rank, world_size); 37 | 38 | } 39 | 40 | #endif // JIT 41 | 42 | } // jittor 43 | -------------------------------------------------------------------------------- /python/jittor/extern/mpi/ops/mpi_test_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Dun Liang . 4 | // All Rights Reserved. 5 | // This file is subject to the terms and conditions defined in 6 | // file 'LICENSE.txt', which is part of this source code package. 
7 | // *************************************************************** 8 | #pragma once 9 | #include "op.h" 10 | 11 | namespace jittor { 12 | 13 | struct MpiTestOp : Op { 14 | Var* output; 15 | string cmd; 16 | 17 | MpiTestOp(string cmd); 18 | 19 | const char* name() const override { return "mpi_test"; } 20 | DECLARE_jit_run; 21 | }; 22 | 23 | } // jittor -------------------------------------------------------------------------------- /python/jittor/extern/rocm/rocm_cache.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jittor/jittor/444154c2a1e63ee4a0af07831a0c54e2ebb7a561/python/jittor/extern/rocm/rocm_cache.tar.gz -------------------------------------------------------------------------------- /python/jittor/extern/rocm/rocm_config.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2021 Jittor. All Rights Reserved. 3 | // Maintainers: Zheng-Ning Liu . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "common.h" 9 | 10 | namespace jittor { 11 | 12 | void rocm_config(const string& name, string& src); 13 | 14 | } 15 | 16 | 17 | -------------------------------------------------------------------------------- /python/jittor/extern/rocm/rocm_jittor.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2021 Jittor. All Rights Reserved. 3 | // Maintainers: Zheng-Ning Liu . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #pragma once 8 | #include "common.h" 9 | 10 | namespace jittor { 11 | 12 | void rocm_jittor_op_compiler(string& filename, string& src, bool is_rocm, string& extra_flags); 13 | 14 | } 15 | -------------------------------------------------------------------------------- /python/jittor/gradfunctional/__init__.py: -------------------------------------------------------------------------------- 1 | from .functional import jvp, vjp 2 | 3 | -------------------------------------------------------------------------------- /python/jittor/loss3d/__init__.py: -------------------------------------------------------------------------------- 1 | from .chamfer import chamfer_loss, ChamferLoss 2 | from .emd import earth_mover_distance, EarthMoverDistance 3 | -------------------------------------------------------------------------------- /python/jittor/math_util/__init__.py: -------------------------------------------------------------------------------- 1 | from .gamma import digamma, lgamma 2 | from .igamma import igamma 3 | -------------------------------------------------------------------------------- /python/jittor/math_util/igamma.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import jittor as jt 5 | from jittor import nn 6 | 7 | f = open(os.path.join(os.path.realpath(os.path.dirname(__file__)), "src", "igamma.h"), "r") 8 | cuda_header = f.read() 9 | f.close() 10 | 11 | def igamma(alpha, x): 12 | cuda_src = ''' 13 | @alias(x, in0) 14 | @alias(px ,out0) 15 | int batch_size = x_stride0 == 1 ? 
1 : x_shape0; 16 | int batch_shape = x_shape0 * x_stride0 / batch_size; 17 | float alpha = data["alpha"]; 18 | igamma_kernel<<>>(x_p, px_p, alpha, batch_shape); 19 | ''' 20 | out = jt.code(x.shape, x.dtype, [x], cuda_header=cuda_header, cuda_src=cuda_src, data={"alpha": alpha}) 21 | return out 22 | -------------------------------------------------------------------------------- /python/jittor/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import resnet 2 | from .resnet import * 3 | from . import vgg 4 | from .vgg import * 5 | from . import alexnet 6 | from .alexnet import * 7 | from . import squeezenet 8 | from .squeezenet import * 9 | from . import inception 10 | from .inception import * 11 | from . import googlenet 12 | from .googlenet import * 13 | from . import mobilenet 14 | from .mobilenet import * 15 | from . import mnasnet 16 | from .mnasnet import * 17 | from . import shufflenetv2 18 | from .shufflenetv2 import * 19 | from .res2net import res2net50, res2net101 20 | from . 
import densenet 21 | from .densenet import * 22 | -------------------------------------------------------------------------------- /python/jittor/notebook/60分钟快速入门Jittor/README.md: -------------------------------------------------------------------------------- 1 | # 计图零基础入门教程(60分钟) 2 | 3 | ``` 4 | git clone https://github.com/Jittor/LearnJittorBasicIn60Min.git 5 | cd LearnJittorBasicIn60Min 6 | jupyter notebook 7 | ``` 8 | 9 | 在线浏览地址: 10 | 11 | 特别感谢教程作者:llt 12 | -------------------------------------------------------------------------------- /python/jittor/notebook/60分钟快速入门Jittor/mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jittor/jittor/444154c2a1e63ee4a0af07831a0c54e2ebb7a561/python/jittor/notebook/60分钟快速入门Jittor/mnist.png -------------------------------------------------------------------------------- /python/jittor/notebook/__main__.py: -------------------------------------------------------------------------------- 1 | from .md_to_ipynb import dirname, notebook_dir 2 | import os 3 | import sys 4 | import shutil 5 | from distutils.dir_util import copy_tree 6 | 7 | copy_tree(dirname, notebook_dir) 8 | os.chdir(notebook_dir) 9 | os.system(f"{sys.executable} -m jupyter notebook") -------------------------------------------------------------------------------- /python/jittor/notebook/profiler.src.md: -------------------------------------------------------------------------------- 1 | # Profiler: Profiling your model 2 | 3 | # 性能分析器:分析您的模型 4 | 5 | > NOTE: This tutorial is still working in progress 6 | 7 | In this tutorial, we will show: 8 | 1. how to profiling your model and check the elapsed time of each operation 9 | 2. profiling the cache hit rate 10 | 11 | > 注意:本教程仍在持续更新中 12 | 13 | 在本教程中,我们将展示: 14 | 15 | 1. 如何分析模型并检查每个运算的耗时 16 | 2. 
分析缓存命中率 17 | 18 | ```python 19 | import jittor as jt 20 | ``` -------------------------------------------------------------------------------- /python/jittor/script/build_aarch64_mkl.sh: -------------------------------------------------------------------------------- 1 | # wget https://github.com/oneapi-src/oneDNN/archive/refs/tags/v2.2.zip 2 | # extract zip 3 | # cd to root folder 4 | 5 | mkdir -p build 6 | cd build 7 | make clean 8 | export CC=aarch64-linux-gnu-gcc-8 9 | export CXX=aarch64-linux-gnu-g++-8 10 | cmake .. \ 11 | -DCMAKE_SYSTEM_NAME=Linux \ 12 | -DCMAKE_SYSTEM_PROCESSOR=AARCH64 \ 13 | -DCMAKE_LIBRARY_PATH=/usr/aarch64-linux-gnu/lib \ 14 | -DCMAKE_BUILD_TYPE=Release 15 | # -DCMAKE_SHARED_LINKER_FLAGS=' -lm ' \ 16 | make -j8 17 | 18 | name=dnnl_lnx_2.2.0_cpu_gomp_aarch64 19 | mkdir -p $name 20 | cp -r ../include ./$name/ 21 | mkdir -p ./$name/lib 22 | cp ./src/libmkldnn.so ./$name/lib/libmkldnn.so 23 | cp -r ../examples ./$name/ 24 | cp ./include/oneapi/dnnl/* ./$name/include/oneapi/dnnl/ 25 | 26 | tar -acvf $name.tgz ./$name/ 27 | 28 | rsync -avPu $name.tgz jittor-web:Documents/jittor-blog/assets/ 29 | ssh jittor-web Documents/jittor-blog.git/hooks/post-update 30 | echo "https://cg.cs.tsinghua.edu.cn/jittor/assets/$name.tgz" 31 | md5sum $name.tgz -------------------------------------------------------------------------------- /python/jittor/script/converter_server.sh: -------------------------------------------------------------------------------- 1 | cat > /tmp/converter_server.dockerfile <<\EOF 2 | FROM jittor/jittor 3 | 4 | RUN python3.7 -m pip install flask 5 | RUN apt update && apt install git -y 6 | EOF 7 | 8 | docker build --tag jittor/converter_server -f /tmp/converter_server.dockerfile . 
9 | 10 | # docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 jittor/converter_server bash -c "python3.7 -m pip install -U git+https://github.com/Jittor/jittor.git && python3.7 -m jittor.utils.converter_server" 11 | while true; do 12 | timeout --foreground 24h docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:58187:5000 -v /etc/letsencrypt/:/https jittor/converter_server bash -c "python3.7 -m pip install -U jittor && python3.7 -m jittor.test.test_core && FLASK_APP=/usr/local/lib/python3.7/dist-packages/jittor/utils/converter_server python3.7 -m flask run --cert=/https/live/randonl.me/fullchain.pem --key=/https/live/randonl.me/privkey.pem --host=0.0.0.0" 13 | sleep 10 14 | done -------------------------------------------------------------------------------- /python/jittor/script/install_mkl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -xe 3 | if [ "$cache_path" = "" ]; then 4 | bpath=$(dirname "${BASH_SOURCE[0]}") 5 | cd $bpath 6 | cd ../extern/mkl 7 | else 8 | cd $cache_path 9 | fi 10 | filename="mkldnn_lnx_1.0.2_cpu_gomp.tgz" 11 | dirname="mkldnn_lnx_1.0.2_cpu_gomp" 12 | if [ ! -f $filename ]; then 13 | wget https://github.com/intel/mkl-dnn/releases/download/v1.0.2/$filename 14 | fi 15 | if [ ! -d $dirname ]; then 16 | tar zxvf $filename 17 | fi 18 | 19 | if [ ! -f $dirname/examples/test ]; then 20 | echo "compile mkldnn example and test" 21 | cd $dirname/examples 22 | g++ -std=c++14 cpu_cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test 23 | fi -------------------------------------------------------------------------------- /python/jittor/script/update.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | bpath=$(dirname "${BASH_SOURCE[0]}") 3 | cd $bpath 4 | cd .. 
5 | pwd 6 | git fetch --all 7 | git reset --hard origin/master 8 | python3.7 -c "import jittor" -------------------------------------------------------------------------------- /python/jittor/src/core.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | 9 | #include "var.h" 10 | #include "op.h" 11 | #include "var_holder.h" 12 | 13 | namespace jittor { 14 | 15 | // @pyjt(number_of_hold_vars) 16 | inline static uint64 get_number_of_hold_vars() { 17 | return hold_vars.size(); 18 | } 19 | 20 | // @pyjt(number_of_lived_vars) 21 | inline static int64 get_number_of_lived_vars() { 22 | return Var::number_of_lived_vars; 23 | } 24 | 25 | // @pyjt(number_of_lived_ops) 26 | inline static int64 get_number_of_lived_ops() { 27 | return Op::number_of_lived_ops; 28 | } 29 | 30 | // @pyjt(print_trace) 31 | inline static void __print_trace() { 32 | print_trace(); 33 | } 34 | 35 | // @pyjt(grad) 36 | vector _grad(VarHolder* loss, const vector& targets, bool retain_graph=true); 37 | 38 | } // jittor 39 | -------------------------------------------------------------------------------- /python/jittor/src/executor.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Dun Liang . 5 | // Guoye Yang <498731903@qq.com> 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 
9 | // *************************************************************** 10 | #pragma once 11 | #include "common.h" 12 | #include "mem/allocator.h" 13 | #ifdef HAS_CUDA 14 | #include 15 | #include "helper_cuda.h" 16 | #endif 17 | 18 | namespace jittor { 19 | 20 | struct Executor { 21 | Allocator* allocator; 22 | Allocator* temp_allocator; 23 | bool last_is_cuda = false; 24 | void run_sync(vector vars, bool device_sync, bool weak_sync=true); 25 | 26 | inline Allocation alloc_temp(size_t size) { 27 | return Allocation(temp_allocator, size); 28 | } 29 | }; 30 | 31 | EXTERN_LIB Executor exe; 32 | 33 | void load_fused_op(FusedOp& fused_op, vector& fuse_ops, vector& ops, int ll, int rr, int64 tt); 34 | 35 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/fuser.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guowei Yang <471184555@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #pragma once 11 | #include "common.h" 12 | 13 | namespace jittor { 14 | 15 | void count_fuse(int64_t tt, int start_var_num, const vector& ops, const vector& vars, vector &father, vector &var_fused); 16 | 17 | } // jittor 18 | -------------------------------------------------------------------------------- /python/jittor/src/grad.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 
4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #include "ops/tape_op.h" 8 | #include "common.h" 9 | 10 | namespace jittor { 11 | 12 | vector grad(Var* loss, vector targets, bool retain_graph=true); 13 | 14 | // @pyjt(tape_together) 15 | void tape_together( 16 | const vector& taped_inputs, 17 | const vector& taped_outputs, 18 | GradCallback&& grad_callback 19 | ); 20 | 21 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/init.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include 9 | #include "common.h" 10 | 11 | namespace jittor { 12 | 13 | typedef void (*set_seed_callback)(int); 14 | 15 | void init(); 16 | 17 | /** 18 | Sets the seed of jittor random number generator. Also see @jittor.set_global_seed. 19 | 20 | ---------------- 21 | 22 | * [in] seed: a python number. 23 | 24 | */ 25 | // @pyjt(set_seed, seed) 26 | void set_seed(int seed); 27 | 28 | /** 29 | Returns the seed of jittor random number generator. 
30 | */ 31 | // @pyjt(get_seed) 32 | int get_seed(); 33 | 34 | void add_set_seed_callback(set_seed_callback callback); 35 | 36 | extern 37 | std::default_random_engine* get_random_engine(); 38 | 39 | // things need to be clean before python exit 40 | // @pyjt(cleanup) 41 | void cleanup(); 42 | 43 | // @pyjt(jt_init_subprocess) 44 | void jt_init_subprocess(); 45 | 46 | } // jittor 47 | -------------------------------------------------------------------------------- /python/jittor/src/jit_compiler.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #include "common.h" 8 | #include "op_compiler.h" 9 | 10 | namespace jittor { 11 | namespace jit_compiler { 12 | 13 | jit_op_entry_t compile( 14 | const string& jit_key, 15 | const string& src, 16 | const bool is_cuda_op = false, 17 | const string& extra_flags=""); 18 | 19 | } // jit_compiler 20 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/lock.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Wenyang Zhou <576825820@qq.com> 5 | // Dun Liang 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 
9 | // *************************************************************** 10 | #pragma once 11 | #include "common.h" 12 | 13 | namespace jittor { 14 | 15 | // @pyjt(set_lock_path) 16 | void set_lock_path(string path); 17 | 18 | void lock(); 19 | 20 | void unlock(); 21 | 22 | EXTERN_LIB int _has_lock; 23 | 24 | struct lock_guard { 25 | int has_lock = 0; 26 | inline lock_guard() { 27 | if (_has_lock) return; 28 | has_lock = 1; 29 | lock(); 30 | } 31 | inline ~lock_guard() { 32 | if (!has_lock) return; 33 | unlock(); 34 | } 35 | }; 36 | 37 | } // jittor 38 | -------------------------------------------------------------------------------- /python/jittor/src/mem/allocator/aligned_allocator.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "mem/allocator.h" 9 | 10 | namespace jittor { 11 | 12 | struct AlignedAllocator : Allocator { 13 | uint64 flags() const override { return _aligned; } 14 | const char* name() const override; 15 | void* alloc(size_t size, size_t& allocation) override; 16 | void free(void* mem_ptr, size_t size, const size_t& allocation) override; 17 | }; 18 | 19 | EXTERN_LIB AlignedAllocator aligned_allocator; 20 | 21 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/mem/allocator/cuda_device_allocator.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 
4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #ifdef HAS_CUDA 9 | #include "mem/allocator.h" 10 | 11 | namespace jittor { 12 | 13 | struct CudaDeviceAllocator : Allocator { 14 | uint64 flags() const override { return _cuda; } 15 | const char* name() const override; 16 | void* alloc(size_t size, size_t& allocation) override; 17 | void free(void* mem_ptr, size_t size, const size_t& allocation) override; 18 | }; 19 | 20 | EXTERN_LIB CudaDeviceAllocator cuda_device_allocator; 21 | 22 | } 23 | 24 | #endif 25 | -------------------------------------------------------------------------------- /python/jittor/src/mem/allocator/cuda_dual_allocator.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #ifdef HAS_CUDA 8 | #include "misc/cuda_flags.h" 9 | #include "mem/allocator/cuda_dual_allocator.h" 10 | #include "mem/allocator/cuda_host_allocator.h" 11 | #include "mem/allocator/cuda_device_allocator.h" 12 | #include "event_queue.h" 13 | 14 | namespace jittor { 15 | 16 | SFRLAllocator cuda_dual_host_allocator(&cuda_host_allocator, 0.3, 1<<28); 17 | SFRLAllocator cuda_dual_device_allocator(&cuda_device_allocator, 0.3, 1<<28); 18 | CudaDualAllocator cuda_dual_allocator; 19 | DelayFree delay_free; 20 | 21 | namespace cuda_dual_local { 22 | 23 | list allocations; 24 | 25 | static void free_caller() { 26 | allocations.pop_front(); 27 | } 28 | 29 | } 30 | 31 | void to_free_allocation(CUDA_HOST_FUNC_ARGS) { 32 | using namespace cuda_dual_local; 33 | event_queue.push(free_caller); 34 | } 35 | 36 | } 37 | 38 | #endif 39 | -------------------------------------------------------------------------------- /python/jittor/src/mem/allocator/cuda_host_allocator.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #ifdef HAS_CUDA 8 | #include 9 | #include "helper_cuda.h" 10 | #include "mem/allocator/cuda_host_allocator.h" 11 | 12 | namespace jittor { 13 | 14 | CudaHostAllocator cuda_host_allocator; 15 | EXTERN_LIB bool no_cuda_error_when_free; 16 | 17 | const char* CudaHostAllocator::name() const {return "cuda_host";} 18 | 19 | void* CudaHostAllocator::alloc(size_t size, size_t& allocation) { 20 | if (size==0) return (void*)0x10; 21 | void* ptr; 22 | checkCudaErrors(cudaMallocHost(&ptr, size)); 23 | return ptr; 24 | } 25 | 26 | void CudaHostAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) { 27 | if (size==0) return; 28 | if (no_cuda_error_when_free) return; 29 | checkCudaErrors(cudaFreeHost(mem_ptr)); 30 | } 31 | 32 | } // jittor 33 | 34 | #endif 35 | -------------------------------------------------------------------------------- /python/jittor/src/mem/allocator/cuda_host_allocator.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #pragma once 8 | #ifdef HAS_CUDA 9 | #include "mem/allocator.h" 10 | 11 | namespace jittor { 12 | 13 | struct CudaHostAllocator : Allocator { 14 | inline uint64 flags() const override { return 0; } 15 | const char* name() const override; 16 | void* alloc(size_t size, size_t& allocation) override; 17 | void free(void* mem_ptr, size_t size, const size_t& allocation) override; 18 | }; 19 | 20 | EXTERN_LIB CudaHostAllocator cuda_host_allocator; 21 | 22 | } 23 | 24 | #endif 25 | -------------------------------------------------------------------------------- /python/jittor/src/mem/allocator/cuda_managed_allocator.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #pragma once 8 | #ifdef HAS_CUDA 9 | #include "mem/allocator.h" 10 | 11 | namespace jittor { 12 | 13 | struct CudaManagedAllocator : Allocator { 14 | uint64 flags() const override { return _cuda; } 15 | const char* name() const override; 16 | void* alloc(size_t size, size_t& allocation) override; 17 | void free(void* mem_ptr, size_t size, const size_t& allocation) override; 18 | }; 19 | 20 | EXTERN_LIB CudaManagedAllocator cuda_managed_allocator; 21 | DECLARE_FLAG(int, use_cuda_managed_allocator); 22 | 23 | } 24 | 25 | #endif 26 | -------------------------------------------------------------------------------- /python/jittor/src/mem/allocator/foreign_allocator.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #pragma once 8 | #include "mem/allocator.h" 9 | 10 | namespace jittor { 11 | 12 | struct ForeignAllocator : Allocator { 13 | uint64 flags() const override { return _aligned; } 14 | const char* name() const override; 15 | void* alloc(size_t size, size_t& allocation) override; 16 | void free(void* mem_ptr, size_t size, const size_t& allocation) override; 17 | bool share_with(size_t size, size_t allocation) override; 18 | }; 19 | 20 | void make_foreign_allocation(Allocation& a, void* ptr, size_t size, std::function&& del_func); 21 | 22 | EXTERN_LIB ForeignAllocator foreign_allocator; 23 | 24 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/mem/allocator/nfef_allocator.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #include "mem/allocator/nfef_allocator.h" 8 | #include "var.h" 9 | 10 | namespace jittor { 11 | 12 | DEFINE_FLAG(int, use_nfef_allocator, 0, "Enable never free exact fit allocator"); 13 | 14 | void NFEFAllocator::setup(Allocator* underlying) { 15 | this->underlying = underlying; 16 | } 17 | 18 | const char* NFEFAllocator::name() const {return "nfef";} 19 | 20 | void* NFEFAllocator::alloc(size_t size, size_t& allocation) { 21 | auto iter = freed.find(size); 22 | if (iter == freed.end() || iter->second.empty()) 23 | return underlying->alloc(size, allocation); 24 | auto ptr = iter->second.front(); 25 | iter->second.pop_front(); 26 | return ptr; 27 | } 28 | 29 | void NFEFAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) { 30 | freed[size].push_front(mem_ptr); 31 | } 32 | 33 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/mem/allocator/nfef_allocator.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #pragma once 8 | #include 9 | #include 10 | #include "mem/allocator.h" 11 | 12 | namespace jittor { 13 | 14 | // Never free exact fit allocator 15 | struct NFEFAllocator : Allocator { 16 | Allocator* underlying; 17 | std::unordered_map> freed; 18 | 19 | void setup(Allocator* underlying); 20 | uint64 flags() const override { return underlying->flags(); } 21 | const char* name() const override; 22 | void* alloc(size_t size, size_t& allocation) override; 23 | void free(void* mem_ptr, size_t size, const size_t& allocation) override; 24 | }; 25 | 26 | DECLARE_FLAG(int, use_nfef_allocator); 27 | 28 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/mem/allocator/stat_allocator.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #pragma once 8 | #include "mem/allocator.h" 9 | 10 | namespace jittor { 11 | 12 | struct StatAllocator : Allocator { 13 | Allocator* underlying; 14 | 15 | void setup(Allocator* underlying); 16 | uint64 flags() const override { return underlying->flags(); } 17 | const char* name() const override; 18 | void* alloc(size_t size, size_t& allocation) override; 19 | void free(void* mem_ptr, size_t size, const size_t& allocation) override; 20 | }; 21 | 22 | DECLARE_FLAG(int, use_stat_allocator); 23 | DECLARE_FLAG(size_t, stat_allocator_total_alloc_call); 24 | DECLARE_FLAG(size_t, stat_allocator_total_alloc_byte); 25 | DECLARE_FLAG(size_t, stat_allocator_total_free_call); 26 | DECLARE_FLAG(size_t, stat_allocator_total_free_byte); 27 | 28 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/mem/mem_info.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #pragma once 8 | #include "common.h" 9 | 10 | namespace jittor { 11 | 12 | // @pyjt(display_memory_info) 13 | void display_memory_info(const char* fileline="", bool dump_var=false, bool red_color=false); 14 | 15 | // @pyjt(MemInfo) 16 | struct MemInfo { 17 | // @pyjt(total_cpu_ram) 18 | int64 total_cpu_ram; 19 | // @pyjt(total_cuda_ram) 20 | int64 total_cuda_ram; 21 | // @pyjt(total_cpu_used) 22 | int64 total_cpu_used; 23 | // @pyjt(total_cuda_used) 24 | int64 total_cuda_used; 25 | 26 | inline MemInfo(const MemInfo&) = default; 27 | 28 | MemInfo(); 29 | }; 30 | 31 | EXTERN_LIB MemInfo mem_info; 32 | 33 | // @pyjt(get_mem_info) 34 | inline MemInfo get_mem_info() { return MemInfo(); } 35 | 36 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/misc/cpu_atomic.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #include "misc/cpu_atomic.h" 8 | 9 | namespace jittor { 10 | 11 | std::atomic_flag lock = ATOMIC_FLAG_INIT;; 12 | 13 | } // jittor 14 | -------------------------------------------------------------------------------- /python/jittor/src/misc/cpu_math.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #pragma once 8 | #include "common.h" 9 | 10 | namespace jittor { 11 | 12 | float _erfinv(float y); 13 | double _erfinv(double y); 14 | 15 | } 16 | 17 | -------------------------------------------------------------------------------- /python/jittor/src/misc/cuda_flags.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "common.h" 9 | 10 | 11 | #ifdef HAS_CUDA 12 | #include 13 | 14 | namespace jittor { 15 | 16 | DECLARE_FLAG(int, use_cuda); 17 | DECLARE_FLAG(int, sync_run); 18 | 19 | // @pyjt(get_device_count) 20 | int get_device_count(); 21 | 22 | } // jittor 23 | 24 | #if defined(CUDART_VERSION) && CUDART_VERSION < 10000 25 | #define _cudaLaunchHostFunc(a,b,c) \ 26 | cudaStreamAddCallback(a,b,c,0) 27 | #define CUDA_HOST_FUNC_ARGS cudaStream_t stream, cudaError_t status, void* 28 | #else 29 | #define _cudaLaunchHostFunc(a,b,c) \ 30 | cudaLaunchHostFunc(a,b,c) 31 | #define CUDA_HOST_FUNC_ARGS void* 32 | #endif 33 | 34 | #else 35 | 36 | namespace jittor { 37 | 38 | constexpr int use_cuda = 0; 39 | 40 | inline int get_device_count() { return 0; } 41 | 42 | } // jittor 43 | #endif 44 | -------------------------------------------------------------------------------- /python/jittor/src/misc/deleter.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 
4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "common.h" 9 | #include 10 | 11 | namespace jittor { 12 | 13 | struct Deleter { 14 | std::function del; 15 | inline Deleter(std::function&& func) : del(move(func)) {} 16 | inline Deleter() {} 17 | inline ~Deleter() { if (del) del(); } 18 | }; 19 | 20 | } // jittor 21 | -------------------------------------------------------------------------------- /python/jittor/src/misc/hash.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "common.h" 9 | 10 | namespace jittor { 11 | 12 | uint constexpr const_hash(const char *input) { 13 | return *input ? 
14 | static_cast(*input) + 55 * const_hash(input + 1) : 15 | 0; 16 | } 17 | 18 | /* simple hash function */ 19 | // @pyjt(hash) 20 | inline uint hash(const char* input) { 21 | uint v=0, mul=1; 22 | while (*input) { 23 | v += mul * (uint)*input; 24 | mul *= 55; 25 | input++; 26 | } 27 | return v; 28 | } 29 | 30 | 31 | inline uint64 hash64(const string& input) { 32 | uint64 v=0, mul=1; 33 | for (char c : input) { 34 | v += mul * (uint64)c; 35 | mul *= 257; 36 | } 37 | return v; 38 | } 39 | 40 | } // jittor 41 | -------------------------------------------------------------------------------- /python/jittor/src/misc/intrin.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "common.h" 9 | 10 | namespace jittor { 11 | 12 | static inline int lzcnt(int64 v) { 13 | #ifdef __clang__ 14 | #if __has_feature(__builtin_ia32_lzcnt_u64) 15 | return __builtin_ia32_lzcnt_u64(v); 16 | #else 17 | return v ? __builtin_clzll(v) : 64; 18 | #endif 19 | #else 20 | #ifdef _MSC_VER 21 | unsigned long index; 22 | _BitScanReverse64(&index, v); 23 | return v ? 63-index : 64; 24 | #else 25 | return __builtin_clzll(v); 26 | #endif 27 | #endif 28 | } 29 | 30 | } -------------------------------------------------------------------------------- /python/jittor/src/misc/nan_checker.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 
3 | // This file is subject to the terms and conditions defined in 4 | // file 'LICENSE.txt', which is part of this source code package. 5 | // *************************************************************** 6 | #pragma once 7 | #include "op.h" 8 | #include "var.h" 9 | 10 | namespace jittor { 11 | 12 | bool check_nan(Var* v, Op* op); 13 | 14 | } -------------------------------------------------------------------------------- /python/jittor/src/ops/array_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "op.h" 9 | #include "mem/allocator.h" 10 | 11 | typedef struct _object PyObject; 12 | 13 | namespace jittor { 14 | 15 | struct ArrayArgs { 16 | const void* ptr; 17 | NanoVector shape; 18 | NanoString dtype; 19 | unique_ptr buffer; 20 | }; 21 | 22 | struct ArrayOp : Op { 23 | Var* output; 24 | Allocation allocation; 25 | // @pybind(None) 26 | ArrayOp(const void* ptr, NanoVector shape, NanoString dtype=ns_float32); 27 | 28 | // @pybind(array_) 29 | ArrayOp(ArrayArgs&& args); 30 | 31 | ArrayOp(PyObject* obj); 32 | template 33 | inline T* ptr() { return (T*)allocation.ptr; } 34 | 35 | const char* name() const override { return "array"; } 36 | void run() override; 37 | void jit_prepare(JK& jk) override; 38 | }; 39 | 40 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/ops/binary_op.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. 
// All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"

namespace jittor {

// Op named "binary" over two vars; the NanoString argument selects the
// operation. x and y are presumably the inputs and z the output.
struct BinaryOp : Op {
    Var* x, * y, * z;
    BinaryOp(Var* x, Var* y, NanoString p);

    const char* name() const override { return "binary"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void infer_shape() override;
    DECLARE_jit_run;
};

} // jittor
--------------------------------------------------------------------------------
/python/jittor/src/ops/clone_op.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: 
//     Dun Liang . 
// 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"

namespace jittor {

// Op named "clone" taking a single var; x and y are presumably its input
// and output.
struct CloneOp : Op {
    Var* x, * y;
    CloneOp(Var* x);

    const char* name() const override { return "clone"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void infer_shape() override;
};

// Declared here, defined elsewhere; presumably returns a var equal to x but
// cut off from gradient tracking — confirm at the definition.
VarPtr detach(Var* x);

} // jittor
--------------------------------------------------------------------------------
/python/jittor/src/ops/copy_op.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: 
//     Dun Liang . 
// 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"

namespace jittor {

// Op named "copy" taking a single var. It overrides run() directly rather
// than using the DECLARE_jit_run macro.
struct CopyOp : Op {
    CopyOp(Var* x);

    const char* name() const override { return "copy"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void infer_shape() override;
    void run() override;
};

} // jittor
--------------------------------------------------------------------------------
/python/jittor/src/ops/empty_op.cc:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: 
//     Dun Liang . 
// 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "var.h"
#include "ops/array_op.h"
#include "ops/op_register.h"
#include "ops/empty_op.h"

namespace jittor {

// Sets the _cpu and _cuda node flags (presumably: the op may run on either
// device) and creates one output with the requested shape and dtype.
// No data is written here.
EmptyOp::EmptyOp(NanoVector shape, NanoString dtype) {
    flags.set(NodeFlags::_cpu);
    flags.set(NodeFlags::_cuda);
    create_output(shape, dtype);
}

} // jittor
--------------------------------------------------------------------------------
/python/jittor/src/ops/empty_op.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: 
//     Dun Liang . 
// 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"

namespace jittor {

// Op named "empty": creates an output of the given shape and dtype
// (default ns_float32). The header declares no run, grad or infer_shape
// override.
struct EmptyOp : Op {
    EmptyOp(NanoVector shape, NanoString dtype=ns_float32);

    const char* name() const override { return "empty"; }
};

} // jittor
--------------------------------------------------------------------------------
/python/jittor/src/ops/fuse_transpose_op.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"

namespace jittor {

// Op named "fuse_transpose"; x/y are presumably input/output and `axes` the
// dim permutation. `axes` defaults to an empty NanoVector, whose meaning is
// decided by the implementation.
struct FuseTransposeOp : Op {
    Var* x, * y;
    NanoVector axes;
    FuseTransposeOp(Var* x, NanoVector axes=NanoVector());

    const char* name() const override { return "fuse_transpose"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void infer_shape() override;
    DECLARE_jit_run;
};

} // jittor
--------------------------------------------------------------------------------
/python/jittor/src/ops/random_op.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"

namespace jittor {

// Op named "random" whose `output` presumably has the given shape and dtype
// (default ns_float32); `type` (default ns_uniform) presumably selects the
// sampling distribution.
struct RandomOp : Op {
    Var* output;
    NanoString type;
    RandomOp(NanoVector shape, NanoString dtype=ns_float32, NanoString type=ns_uniform);

    const char* name() const override { return "random"; }
    DECLARE_jit_run;
};

} // jittor
--------------------------------------------------------------------------------
/python/jittor/src/ops/reduce_op.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"

namespace jittor {

// Op named "reduce"; x/y are presumably input/output and the NanoString
// `op` names the reduction. The dims to reduce can be given as one dim, as
// a NanoVector of dims, or as bit masks (the mask overload carries an
// annotation that presumably hides it from Python). The masks are stored
// as uint16, so they can describe at most 16 dims.
struct ReduceOp : Op {
    Var* x, * y;
    uint16 reduce_mask; // i-th bit is 1 if dim-i is reduced
    uint16 keepdims_mask;
    ReduceOp(Var* x, NanoString op, int dim, bool keepdims=false);
    ReduceOp(Var* x, NanoString op, NanoVector dims=NanoVector(), bool keepdims=false);
    // @pybind(None)
    ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask);

    const char* name() const override { return "reduce"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void infer_shape() override;
    DECLARE_jit_run;
};

} // jittor
--------------------------------------------------------------------------------
/python/jittor/src/ops/safe_clip_op.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"


namespace jittor {

// Op named "safe_clip"; the clip bounds default to -1e300 and 1e300.
struct SafeClipOp : Op {
    Var* x, * y;
    float64 left, right;
    /** Safe clip value to a range, and keep
 the gradient passing through.

    * [in] x: input value
    * [in] left: float64 clip min value.
    * [in] right: float64 clip max value.

     */
    // @pybind(safe_clip)
    SafeClipOp(Var* x, float64 left=-1e300, float64 right=1e300);

    const char* name() const override { return "safe_clip"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void infer_shape() override;
    DECLARE_jit_run;
};

} // jittor
--------------------------------------------------------------------------------
/python/jittor/src/ops/setitem_op.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
#include "var_slices.h"

namespace jittor {

// Op named "setitem" built from x, the slices, y and an optional `op`
// (default ns_void). Presumably implements x[slices] = y, or an in-place
// combine with `op` when it is not ns_void — confirm in the implementation.
struct SetitemOp : Op {
    VarSlices vs;
    // map i to related var slice
    NanoVector i_to_vs;
    // map i to related o
    NanoVector i_to_o;
    NanoVector o_shape;
    int first_oid_of_var, var_dim;
    int bmask;
    NanoString op;

    SetitemOp(Var* x, VarSlices&& slices, Var* y, NanoString op=ns_void);

    const char* name() const override { return "setitem"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void grads(Var** dout, VarPtr* dins) override;
    void infer_shape() override;
    void compile_optimize(string& src) override;
    void graph_optimize() override;
    DECLARE_jit_run;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/ops/ternary_op.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"

namespace jittor {

// Op named "ternary" over cond, x and y; z is presumably the output
// (e.g. an element-wise cond ? x : y) — confirm in the implementation.
struct TernaryOp : Op {
    Var* cond, * x, * y, * z;
    TernaryOp(Var* cond, Var* x, Var* y);

    const char* name() const override { return "ternary"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void infer_shape() override;
    DECLARE_jit_run;
};

} // jittor
--------------------------------------------------------------------------------
/python/jittor/src/ops/transpose_op.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"

namespace jittor {

// Op named "transpose"; x/y are presumably input/output and `axes` the dim
// permutation (defaults to an empty NanoVector).
struct TransposeOp : Op {
    Var* x, * y;
    NanoVector axes;
    TransposeOp(Var* x, NanoVector axes=NanoVector());

    const char* name() const override { return "transpose"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void infer_shape() override;
    DECLARE_jit_run;
};

} // jittor
--------------------------------------------------------------------------------
/python/jittor/src/ops/unary_op.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"


namespace jittor {

// Op named "unary"; x/y are presumably input/output and the NanoString
// argument selects the operation. Its binding annotation lists two Python
// names, unary and cast.
struct UnaryOp : Op {
    Var* x, * y;
    // @pybind(unary,cast)
    UnaryOp(Var* x, NanoString op);

    const char* name() const override { return "unary"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void infer_shape() override;
    DECLARE_jit_run;
};

} // jittor
--------------------------------------------------------------------------------
/python/jittor/src/opt/jit_searcher.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"

namespace jittor {

// Flag declared here and defined elsewhere; presumably toggles kernel search.
DECLARE_FLAG(int, jit_search_kernel);

// Tuning helper bound to an OpCompiler. search() walks the given loop-option
// candidates; presumably each choice is timed via
// get_time_of_current_choices(), the fastest is kept in best_choices /
// best_time, and timeout bounds a trial — confirm in the implementation.
struct Searcher {
    OpCompiler* oc;
    int64_t timeout, best_time;
    loop_options_t best_choices;

    Searcher(OpCompiler* oc);
    void reset();
    int64_t get_time_of_current_choices();
    void search(const loop_option_candidates_t& candidates);
};

} // jittor
--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/assume_aligned_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "assume_aligned"; see run() for the transformation.
struct AssumeAlignedPass : Pass {
    AssumeAlignedPass() : Pass("assume_aligned") {};
    void run() override;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/atomic_tuner_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: 
//     Guowei Yang <471184555@qq.com>
//     Dun Liang . 
// 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "atomic" (not "atomic_tuner"); see run() for the transformation.
struct AtomicTunerPass : Pass {
    AtomicTunerPass() : Pass("atomic") {};
    void run() override;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/check_cache_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: 
//     Guoye Yang <498731903@qq.com>
//     Dun Liang . 
// 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "check_cache"; see run() for the transformation.
struct CheckCachePass : Pass {
    CheckCachePass() : Pass("check_cache") {};
    void run() override;
};

} // jittor
--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/compile_shapes_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "compile_shapes"; see run() for the transformation.
struct CompileShapesPass : Pass {
    CompileShapesPass() : Pass("compile_shapes") {};
    void run() override;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/const_var_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "const_var_pass"; see run() for the transformation.
// NOTE(review): unlike the sibling passes, this name keeps the "_pass"
// suffix; left as-is in case anything keys on the exact string.
struct ConstVarPass : Pass {
    ConstVarPass() : Pass("const_var_pass") {};
    void run() override;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/expand_empty_block_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "expand_empty_block"; see run() for the transformation.
struct ExpandEmptyBlockPass : Pass {
    ExpandEmptyBlockPass() : Pass("expand_empty_block") {};
    void run() override;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/fake_main_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "fake_main"; see run() for the transformation.
struct FakeMainPass : Pass {
    FakeMainPass() : Pass("fake_main") {};
    void run() override;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/float_atomic_fix_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: 
//     Dun Liang . 
// 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "float_atomic_fix"; see run() for the transformation.
struct FloatAtomicFixPass : Pass {
    FloatAtomicFixPass() : Pass("float_atomic_fix") {};
    void run() override;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/insert_profile_loop_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "insert_profile_loop"; see run() for the transformation.
struct InsertProfileLoopPass : Pass {
    InsertProfileLoopPass() : Pass("insert_profile_loop") {};
    void run() override;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/loop_to_func_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "loop_to_func"; see run() for the transformation.
struct LoopToFuncPass : Pass {
    LoopToFuncPass() : Pass("loop_to_func") {};
    void run() override;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/loop_var_analyze_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "loop_var_analyze"; see run() for the analysis.
// number_of_ranges is initialized to 0 by the constructor.
struct LoopVarAnalyzePass : Pass {
    // total number of loop ranges
    int number_of_ranges;

    LoopVarAnalyzePass() : Pass("loop_var_analyze"), number_of_ranges(0) {};
    void run() override;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/mark_raw_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "mark_raw"; see run() for the transformation.
struct MarkRawPass : Pass {
    MarkRawPass() : Pass("mark_raw") {};
    void run() override;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/merge_loop_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "merge_loop"; see run() for the transformation.
struct MergeLoopPass : Pass {
    MergeLoopPass() : Pass("merge_loop") {};
    void run() override;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/merge_loop_var_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "merge_loop_var"; see run() for the transformation.
struct MergeLoopVarPass : Pass {
    MergeLoopVarPass() : Pass("merge_loop_var") {};
    void run() override;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/parallel_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
6 | // *************************************************************** 7 | #pragma once 8 | #include "opt/pass/pass.h" 9 | 10 | namespace jittor { 11 | 12 | struct ParallelPass : Pass { 13 | ParallelPass() : Pass("parallel") {}; 14 | void run() override; 15 | }; 16 | 17 | } // jittor 18 | -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/pass.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #include 8 | #include "opt/pass/pass.h" 9 | #include "opt/pass_manager.h" 10 | 11 | namespace jittor { 12 | 13 | Pass::Pass(const string& name): name(name) {} 14 | Pass::~Pass() {} 15 | 16 | void Pass::init(PassManager* pm) { 17 | this->pm = pm; 18 | op = pm->oc->op; 19 | all = &pm->all; 20 | ir = pm->main_ir; 21 | } 22 | 23 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/pass.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
// ***************************************************************
#pragma once
#include "common.h"
#include "fused_op.h"
#include "opt/kernel_ir.h"

namespace jittor {

// Base class of the optimization passes. init() fills op, all, ir and pm
// from a PassManager; subclasses implement run().
struct Pass {
    FusedOp* op;      // taken from pm->oc->op by init()
    KernelIR* all;    // points at the manager's `all` IR
    KernelIR* ir;     // the manager's main IR
    PassManager* pm;  // manager passed to init()
    string name;      // name given to the constructor

    Pass(const string& name);
    virtual ~Pass();

    void init(PassManager* pm);
    virtual void run() = 0;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/remove_intermediate_pass.h:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"

namespace jittor {

// Pass named "remove_intermediate"; see run() for the transformation.
struct RemoveIntermediatePass : Pass {
    RemoveIntermediatePass() : Pass("remove_intermediate") {};
    void run() override;
};

} // jittor

--------------------------------------------------------------------------------
/python/jittor/src/opt/pass/remove_loop_pass.cc:
--------------------------------------------------------------------------------
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved. 
// Maintainers: Dun Liang . 
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
6 | // *************************************************************** 7 | #include 8 | #include "var.h" 9 | #include "opt/pass_manager.h" 10 | #include "opt/pass/remove_loop_pass.h" 11 | 12 | namespace jittor { 13 | 14 | void RemoveLoopPass::run() { 15 | int loop_id=0; 16 | for (size_t i=0; ichildren.size(); i++) { 17 | auto& c = ir->children[i]; 18 | if (c->type == "loop") { 19 | auto choice = op->get_loop_option("remove"+S(loop_id)); 20 | if (choice) { 21 | c->erase(); 22 | i--; 23 | } 24 | loop_id++; 25 | } 26 | } 27 | } 28 | 29 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/remove_loop_pass.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "opt/pass/pass.h" 9 | 10 | namespace jittor { 11 | 12 | // this is a debug pass, remove i-th loop, key: removei 13 | struct RemoveLoopPass : Pass { 14 | RemoveLoopPass() : Pass("remove_loop") {}; 15 | void run() override; 16 | }; 17 | 18 | } // jittor 19 | -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/rename_loop_index_pass.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #include 8 | #include "var.h" 9 | #include "opt/pass_manager.h" 10 | #include "opt/pass/rename_loop_index_pass.h" 11 | 12 | namespace jittor { 13 | 14 | void RenameLoopIndexPass::run() { 15 | // TODO: move out rename_loop_index 16 | ir->rename_loop_index(); 17 | } 18 | 19 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/rename_loop_index_pass.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "opt/pass/pass.h" 9 | 10 | namespace jittor { 11 | 12 | struct RenameLoopIndexPass : Pass { 13 | RenameLoopIndexPass() : Pass("rename_loop_index") {}; 14 | void run() override; 15 | }; 16 | 17 | } // jittor 18 | -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/reorder_loop_pass.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #pragma once 8 | #include "opt/pass/pass.h" 9 | 10 | namespace jittor { 11 | 12 | struct ReorderLoopPass : Pass { 13 | ReorderLoopPass() : Pass("reorder_loop") {}; 14 | void run() override; 15 | vector search_parse_loop_order(); 16 | }; 17 | 18 | } // jittor 19 | -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/replace_for_num_pass.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "opt/pass/pass.h" 9 | 10 | namespace jittor { 11 | 12 | // replace_for_num pass 13 | // T num=opi_x->num; 14 | // for (T i=0; i 16 | // T opi_xshapej = opi_x->shape[j]; ... 17 | // T opi_xstride{DIM-1} = 1; 18 | // T opi_xstride{j} = opi_xstride{j+1} * opi_xshape{j+1} 19 | // for (T i{d}=0; i{d}. 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "opt/pass/pass.h" 9 | 10 | namespace jittor { 11 | 12 | struct RestridePass : Pass { 13 | RestridePass() : Pass("restride") {}; 14 | void run() override; 15 | }; 16 | 17 | } // jittor 18 | -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/shared_reduce_pass.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. 
All Rights Reserved. 3 | // Maintainers: Zheng-Ning Liu . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "opt/pass/pass.h" 9 | 10 | namespace jittor { 11 | 12 | struct SharedReducePass : Pass { 13 | SharedReducePass() : Pass("shared_reduce") {}; 14 | void run() override; 15 | }; 16 | 17 | } // jittor 18 | -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/solve_conflict_define_pass.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #include 8 | #include "var.h" 9 | #include "opt/pass_manager.h" 10 | #include "opt/pass/solve_conflict_define_pass.h" 11 | 12 | namespace jittor { 13 | 14 | void SolveConflictDefinePass::run() { 15 | ir->solve_conflict_define(); 16 | } 17 | 18 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/solve_conflict_define_pass.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #pragma once 8 | #include "opt/pass/pass.h" 9 | 10 | namespace jittor { 11 | 12 | struct SolveConflictDefinePass : Pass { 13 | SolveConflictDefinePass() : Pass("solve_conflict_define") {}; 14 | void run() override; 15 | }; 16 | 17 | } // jittor 18 | -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/split_loop_pass.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "opt/pass/pass.h" 9 | 10 | namespace jittor { 11 | 12 | struct SplitLoopPass : Pass { 13 | int number_of_ranges_after_split; 14 | 15 | SplitLoopPass() : Pass("split_loop"), number_of_ranges_after_split(0) {}; 16 | void run() override; 17 | }; 18 | 19 | } // jittor 20 | -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/unroll_pass.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #pragma once 8 | #include "opt/pass/pass.h" 9 | 10 | namespace jittor { 11 | 12 | struct UnrollPass : Pass { 13 | UnrollPass() : Pass("expand_empty_block") {}; 14 | void run() override; 15 | }; 16 | 17 | } // jittor 18 | -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/use_movnt_pass.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guowei Yang <471184555@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #include 11 | #include "var.h" 12 | #include "opt/pass_manager.h" 13 | #include "opt/pass/use_movnt_pass.h" 14 | 15 | namespace jittor { 16 | 17 | void UseMovntPass::run() { 18 | // TODO: need to test this pass 19 | if (!op->get_loop_option("use_movnt")) 20 | return; 21 | 22 | for (auto& c : ir->children) { 23 | if (c->type != "loop") continue; 24 | c->push_front("//@begin replace \"vmova(.*,.*\\(.*\\))\" \"vmovnt\\g<1>\"", &c->children, true); 25 | c->push_back("//@end", &c->children, true); 26 | } 27 | } 28 | 29 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/use_movnt_pass.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guowei Yang <471184555@qq.com> 5 | // Dun Liang . 
6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #pragma once 11 | #include "opt/pass/pass.h" 12 | 13 | namespace jittor { 14 | 15 | struct UseMovntPass : Pass { 16 | UseMovntPass() : Pass("use_movnt") {}; 17 | void run() override; 18 | }; 19 | 20 | } // jittor 21 | -------------------------------------------------------------------------------- /python/jittor/src/opt/pass/vectorize_pass.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "opt/pass/pass.h" 9 | 10 | namespace jittor { 11 | 12 | struct VectorizePass : Pass { 13 | VectorizePass() : Pass("vectorize") {}; 14 | void run() override; 15 | }; 16 | 17 | } // jittor 18 | -------------------------------------------------------------------------------- /python/jittor/src/opt/tuner/broadcast_tuner.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guowei Yang <471184555@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 
9 | // *************************************************************** 10 | #pragma once 11 | #include "common.h" 12 | #include "var.h" 13 | #include "opt/tuner_manager.h" 14 | 15 | namespace jittor { 16 | 17 | struct BroadcastTuner : Tuner { 18 | BroadcastTuner() : Tuner("broadcast") {} 19 | void run(PassManager* pm, TunerManager* tm); 20 | }; 21 | 22 | } -------------------------------------------------------------------------------- /python/jittor/src/opt/tuner/conv_tuner.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guowei Yang <471184555@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #pragma once 11 | #include "common.h" 12 | #include "var.h" 13 | #include "opt/tuner_manager.h" 14 | 15 | namespace jittor { 16 | 17 | struct ConvTuner : Tuner { 18 | ConvTuner() : Tuner("conv") {} 19 | void forwardTune(FusedOp* fop); 20 | void run(PassManager* pm, TunerManager* tm); 21 | }; 22 | 23 | } -------------------------------------------------------------------------------- /python/jittor/src/opt/tuner/matmul_tuner.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Dun Liang . 5 | // 6 | // This file is subject to the terms and conditions defined in 7 | // file 'LICENSE.txt', which is part of this source code package. 
8 | // *************************************************************** 9 | #pragma once 10 | #include "common.h" 11 | #include "var.h" 12 | #include "opt/tuner_manager.h" 13 | 14 | namespace jittor { 15 | 16 | struct MatmulTuner : Tuner { 17 | MatmulTuner() : Tuner("matmul") {} 18 | void run(PassManager* pm, TunerManager* tm); 19 | }; 20 | 21 | } -------------------------------------------------------------------------------- /python/jittor/src/opt/tuner/reduce_tuner.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guowei Yang <471184555@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #pragma once 11 | #include "common.h" 12 | #include "var.h" 13 | #include "ops/reduce_op.h" 14 | #include "opt/tuner_manager.h" 15 | 16 | namespace jittor { 17 | 18 | struct ReduceTuner : Tuner { 19 | ReduceTuner() : Tuner("reduce") {} 20 | void run(PassManager* pm, TunerManager* tm); 21 | }; 22 | 23 | } -------------------------------------------------------------------------------- /python/jittor/src/opt/tuner/reorder_tuner.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #include "common.h" 8 | #include "opt/tuner/reorder_tuner.h" 9 | #include "opt/pass_manager.h" 10 | #include "opt/pass/loop_var_analyze_pass.h" 11 | #include "opt/pass/split_loop_pass.h" 12 | 13 | namespace jittor { 14 | 15 | void ReorderTuner::run(PassManager* pm, TunerManager* tm) { 16 | auto* lva_pass = pm->get_pass("loop_var_analyze"); 17 | auto* sl_pass = pm->get_pass("split_loop"); 18 | if (!sl_pass || !lva_pass) return; 19 | auto number_of_ranges = lva_pass->number_of_ranges; 20 | auto number_of_ranges_after_split = sl_pass->number_of_ranges_after_split; 21 | for (int i=0; i. 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "common.h" 9 | #include "opt/tuner_manager.h" 10 | 11 | namespace jittor { 12 | 13 | struct ReorderTuner : Tuner { 14 | ReorderTuner() : Tuner("reorder") {} 15 | void run(PassManager* pm, TunerManager* tm); 16 | }; 17 | 18 | } -------------------------------------------------------------------------------- /python/jittor/src/opt/tuner/tuner.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #include "common.h" 8 | #include "opt/tuner/tuner.h" 9 | 10 | namespace jittor { 11 | 12 | Tuner::Tuner(const string& name) : name(name), confidence(0), candidates({}) {}; 13 | Tuner::~Tuner() {} 14 | 15 | void Tuner::add_candidate(const string& key, int value) { 16 | candidates[key].push_back(value); 17 | } 18 | 19 | } -------------------------------------------------------------------------------- /python/jittor/src/opt/tuner/tuner.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "common.h" 9 | 10 | namespace jittor { 11 | 12 | struct Tuner { 13 | string name; 14 | int confidence; 15 | loop_option_candidates_t candidates; 16 | 17 | Tuner(const string& name); 18 | void add_candidate(const string& key, int value); 19 | virtual ~Tuner(); 20 | virtual void run(PassManager* pm, TunerManager* tm) = 0; 21 | }; 22 | 23 | } -------------------------------------------------------------------------------- /python/jittor/src/opt/tuner_manager.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #pragma once 8 | #include "common.h" 9 | #include "opt/tuner/tuner.h" 10 | #include "opt/jit_searcher.h" 11 | 12 | namespace jittor { 13 | 14 | struct TunerManager { 15 | OpCompiler* oc; 16 | Searcher searcher; 17 | Tuner* best_tuner; 18 | 19 | vector> tuners; 20 | 21 | TunerManager(OpCompiler* oc); 22 | string tune(); 23 | 24 | // run and store a tuner, return confidence 25 | template void run_tuner(PassManager* pm); 26 | }; 27 | 28 | } -------------------------------------------------------------------------------- /python/jittor/src/parallel_compiler.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "common.h" 9 | 10 | namespace jittor { 11 | 12 | void parallel_compile_all_ops(vector& queue, vector& range, FusedOp& fused_op, vector& fuse_ops, vector& ops, int64 tt, int force_compile=0); 13 | 14 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/profiler/cache_info.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 
9 | // *************************************************************** 10 | #include "profiler/cache_info.h" 11 | 12 | namespace jittor { 13 | CacheInfo::CacheInfo(unique_ptr* mm) { 14 | check_times = mm->get()->check_times; 15 | tlb_miss_times = mm->get()->tlb->miss_time; 16 | cache_miss_times.clear(); 17 | for (int i = 0; i < (int)mm->get()->caches.size(); ++i) 18 | cache_miss_times.push_back(mm->get()->caches[i]->miss_time); 19 | } 20 | 21 | CacheInfo::CacheInfo() { 22 | check_times = tlb_miss_times = 0; 23 | cache_miss_times.clear(); 24 | } 25 | 26 | } //jittor -------------------------------------------------------------------------------- /python/jittor/src/profiler/cache_info.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com> 5 | // Dun Liang . 6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #pragma once 11 | #include 12 | #include 13 | #include "profiler/memory_checker.h" 14 | 15 | namespace jittor { 16 | struct CacheInfo { 17 | int64_t check_times, tlb_miss_times; 18 | vector cache_miss_times; 19 | CacheInfo(unique_ptr* mm); 20 | CacheInfo(); 21 | }; 22 | 23 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/profiler/memory_checker.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Guoye Yang <498731903@qq.com> 5 | // Dun Liang . 
6 | // 7 | // This file is subject to the terms and conditions defined in 8 | // file 'LICENSE.txt', which is part of this source code package. 9 | // *************************************************************** 10 | #pragma once 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include "common.h" 16 | #include "profiler/replacement.h" 17 | 18 | namespace jittor { 19 | struct MemoryChecker { 20 | Cache* tlb; 21 | vector caches; 22 | size_t page_size; 23 | int64_t check_times; 24 | // translate virtual address to physical address or not 25 | size_t vtop; 26 | 27 | //TODO auto build MemoryChecker 28 | MemoryChecker(Cache* tlb, vector caches, size_t page_size, size_t vtop); 29 | ~MemoryChecker(); 30 | static string get_replace_strategy(int id); 31 | void clear(); 32 | void print_miss(); 33 | void check_hit(size_t vaddr); 34 | }; 35 | 36 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/pybind/py_var_tracer_interface.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include 9 | 10 | namespace jittor { 11 | 12 | // @pyjt(dump_trace_data) 13 | PyObject* dump_trace_data(); 14 | 15 | // @pyjt(clear_trace_data) 16 | void clear_trace_data(); 17 | 18 | } // jittor 19 | -------------------------------------------------------------------------------- /python/jittor/src/pyjt/py_arg_printer.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. 
All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include 9 | #include "common.h" 10 | 11 | namespace jittor { 12 | 13 | struct PyArgPrinter { 14 | PyObject* obj; 15 | const char* name; 16 | }; 17 | std::ostream& operator<<(std::ostream& os, const PyArgPrinter& arg); 18 | 19 | struct PyTupleArgPrinter { 20 | PyObject* obj; 21 | const char* name; 22 | }; 23 | std::ostream& operator<<(std::ostream& os, const PyTupleArgPrinter& args); 24 | 25 | struct PyKwArgPrinter { 26 | PyObject* obj; 27 | }; 28 | std::ostream& operator<<(std::ostream& os, const PyKwArgPrinter& args); 29 | 30 | struct PyFastCallArgPrinter { 31 | PyObject** obj; 32 | int64 n; 33 | PyObject* kw; 34 | }; 35 | std::ostream& operator<<(std::ostream& os, const PyFastCallArgPrinter& args); 36 | 37 | } 38 | -------------------------------------------------------------------------------- /python/jittor/src/pyjt/py_caller.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Dun Liang . 5 | // 6 | // This file is subject to the terms and conditions defined in 7 | // file 'LICENSE.txt', which is part of this source code package. 
8 | // *************************************************************** 9 | #include "pyjt/py_obj_holder.h" 10 | #include "pyjt/py_converter.h" 11 | #include "pyjt/py_caller.h" 12 | 13 | namespace jittor { 14 | 15 | string py_caller(const string& mod_func, const vector& args, const map& kw) { 16 | PyObjHolder mod(PyImport_ImportModule("jittor")); 17 | PyObjHolder func(PyObject_GetAttrString(mod.obj, "python_pass_wrapper")); 18 | PyObjHolder py_name(to_py_object(mod_func)); 19 | PyObjHolder py_args(to_py_tuple(args)); 20 | PyObjHolder py_kw(to_py_object(kw)); 21 | PyObjHolder ret(PyObject_CallFunctionObjArgs(func.obj, py_name.obj, py_args.obj, py_kw.obj, nullptr)); 22 | CHECK(is_type(ret.obj)) << "expect return type string."; 23 | return from_py_object(ret.obj); 24 | } 25 | 26 | } 27 | -------------------------------------------------------------------------------- /python/jittor/src/pyjt/py_caller.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: 4 | // Dun Liang . 5 | // 6 | // This file is subject to the terms and conditions defined in 7 | // file 'LICENSE.txt', which is part of this source code package. 8 | // *************************************************************** 9 | #pragma once 10 | #include "common.h" 11 | 12 | namespace jittor { 13 | 14 | string py_caller(const string& mod_func, const vector& args, const map& kw); 15 | 16 | } 17 | -------------------------------------------------------------------------------- /python/jittor/src/test/test_fast_shared_ptr.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 
4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #include "misc/fast_shared_ptr.h" 8 | 9 | namespace jittor { 10 | 11 | JIT_TEST(fast_shared_ptr) { 12 | unordered_map a; 13 | fast_shared_ptr> ap(move(a)); 14 | ASSERT(ap.ptr==0); 15 | ap = {{"a",1}}; 16 | auto bp = ap; 17 | ASSERT(bp.ptr==ap.ptr && bp.ref_cnt()==2); 18 | ap = nullptr; 19 | ASSERT(ap.ptr==nullptr && bp.ref_cnt()==1); 20 | ap = clone(bp.data()); 21 | ASSERT(ap.data().size()==1 && bp.ref_cnt()==1); 22 | } 23 | 24 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/test/test_op_compiler.cc: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 
6 | // *************************************************************** 7 | #include 8 | #include "op_compiler.h" 9 | 10 | namespace jittor { 11 | 12 | JIT_TEST(regex) { 13 | std::string s(R"( 14 | asdas 15 | void adasd 16 | asdads XxxXxxOp::jit_run() { 17 | xxxx 18 | })"); 19 | std::regex e(R"([^]*\s(\S*Op)::jit_run[^]*)"); 20 | std::smatch cm; 21 | 22 | // std::regex_match ( s, cm, e, std::regex_constants::match_default ); 23 | std::regex_match ( s, cm, e); 24 | 25 | CHECK(cm.size()==2); 26 | CHECK(cm[1]=="XxxXxxOp"); 27 | } 28 | 29 | } // jittor 30 | -------------------------------------------------------------------------------- /python/jittor/src/utils/cache_compile.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "common.h" 9 | 10 | namespace jittor { 11 | namespace jit_compiler { 12 | 13 | string read_all(const string& fname); 14 | void write(const string& fname, const string& src); 15 | bool file_exist(const string& fname); 16 | string join(string a, string b); 17 | bool cache_compile(string cmd, const string& cache_path="", const string& jittor_path=""); 18 | 19 | } // jit_compiler 20 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/utils/flags.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 
4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "utils/log.h" -------------------------------------------------------------------------------- /python/jittor/src/utils/tracer.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include "common.h" 9 | 10 | namespace jittor { 11 | 12 | void print_trace(); 13 | void breakpoint(); 14 | 15 | } // jittor -------------------------------------------------------------------------------- /python/jittor/src/utils/vdp: -------------------------------------------------------------------------------- 1 | #define _P(...) 
-------------------------------------------------------------------------------- /python/jittor/test/system/test_all.sh: -------------------------------------------------------------------------------- 1 | bash python/jittor/test/system/test_cuda10.0_ubuntu16.04.sh 2 | bash python/jittor/test/system/test_cuda10.0_ubuntu18.04.sh 3 | bash python/jittor/test/system/test_cuda11.1_ubuntu16.04.sh 4 | bash python/jittor/test/system/test_cuda11.1_ubuntu18.04.sh 5 | bash python/jittor/test/system/test_cuda11.1_ubuntu20.04.sh 6 | bash python/jittor/test/system/test_nocuda_ubuntu18.04.sh 7 | -------------------------------------------------------------------------------- /python/jittor/test/test.h: -------------------------------------------------------------------------------- 1 | // *************************************************************** 2 | // Copyright (c) 2023 Jittor. All Rights Reserved. 3 | // Maintainers: Dun Liang . 4 | // This file is subject to the terms and conditions defined in 5 | // file 'LICENSE.txt', which is part of this source code package. 6 | // *************************************************************** 7 | #pragma once 8 | #include 9 | 10 | using namespace std; 11 | 12 | void test_main(); 13 | 14 | void expect_error(function func); 15 | 16 | int main() { 17 | try { 18 | test_main(); 19 | } catch (const std::exception& e) { 20 | std::cout << e.what() << std::endl; 21 | return 1; 22 | } 23 | } -------------------------------------------------------------------------------- /python/jittor/test/test_compile_options.py: -------------------------------------------------------------------------------- 1 | # *************************************************************** 2 | # Copyright (c) 2023 Jittor. All Rights Reserved. 3 | # Maintainers: Dun Liang . 4 | # This file is subject to the terms and conditions defined in 5 | # file 'LICENSE.txt', which is part of this source code package. 
6 | # *************************************************************** 7 | import unittest 8 | import jittor as jt 9 | import os 10 | from .test_log import find_log_with_re 11 | from .test_fused_op import retry 12 | 13 | class TestCompileOptions(unittest.TestCase): 14 | def test(self): 15 | a = jt.array([1,2,3]) 16 | a.sync() 17 | assert a.compile_options=={} 18 | a.compile_options = {"compile_shapes":1} 19 | assert a.compile_options=={"compile_shapes":1} 20 | b = a+a 21 | assert b.compile_options=={} 22 | with jt.flag_scope(compile_options={"compile_shapes":1}): 23 | c = a+b 24 | assert c.compile_options=={"compile_shapes":1} 25 | with jt.profile_scope() as report: 26 | c.sync() 27 | assert len(report)==2 and "compile_shapes:1" in report[1][0] 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() -------------------------------------------------------------------------------- /python/jittor/test/test_console.py: -------------------------------------------------------------------------------- 1 | # *************************************************************** 2 | # Copyright (c) 2023 Jittor. All Rights Reserved. 3 | # Maintainers: Dun Liang . 4 | # This file is subject to the terms and conditions defined in 5 | # file 'LICENSE.txt', which is part of this source code package. 
6 | # *************************************************************** 7 | import unittest 8 | import jittor as jt 9 | import numpy as np 10 | from jittor_utils import run_cmd 11 | import sys 12 | 13 | class TestConsole(unittest.TestCase): 14 | def test_console(self): 15 | run_cmd(f"{sys.executable} -m jittor_utils.config --cxx-example > tmp.cc", jt.flags.cache_path) 16 | s = run_cmd(f"{jt.flags.cc_path} tmp.cc $({sys.executable} -m jittor_utils.config --include-flags --libs-flags --cxx-flags) -o tmp.out && ./tmp.out", jt.flags.cache_path) 17 | print(s) 18 | assert "jt.Var" in s 19 | assert "pred.shape 2 1000" in s 20 | 21 | if __name__ == "__main__": 22 | unittest.main() -------------------------------------------------------------------------------- /python/jittor/test/test_cutt.py: -------------------------------------------------------------------------------- 1 | # *************************************************************** 2 | # Copyright (c) 2019 3 | # Guoye Yang <498731903@qq.com> 4 | # Dun Liang . 5 | # All Rights Reserved. 6 | # This file is subject to the terms and conditions defined in 7 | # file 'LICENSE.txt', which is part of this source code package. 
8 | # *************************************************************** 9 | import unittest 10 | import jittor as jt 11 | import numpy as np 12 | from jittor import compile_extern 13 | from .test_log import find_log_with_re 14 | import copy 15 | if jt.has_cuda: 16 | from jittor.compile_extern import cutt_ops 17 | else: 18 | cutt_ops = None 19 | 20 | class TestCutt(unittest.TestCase): 21 | @unittest.skipIf(cutt_ops==None, "Not use cutt, Skip") 22 | @jt.flag_scope(use_cuda=1) 23 | def test(self): 24 | t = cutt_ops.cutt_test("213") 25 | assert t.data == 123 26 | if __name__ == "__main__": 27 | unittest.main() -------------------------------------------------------------------------------- /python/jittor/test/test_emnist.py: -------------------------------------------------------------------------------- 1 | # *************************************************************** 2 | # Copyright (c) 2023 Jittor. All Rights Reserved. 3 | # Maintainers: 4 | # Dun Liang . 5 | # 6 | # This file is subject to the terms and conditions defined in 7 | # file 'LICENSE.txt', which is part of this source code package. 
8 | # *************************************************************** 9 | import unittest 10 | import jittor as jt 11 | from jittor.dataset.mnist import EMNIST, MNIST 12 | import jittor.transform as transform 13 | 14 | @unittest.skipIf(True, f"skip emnist test") 15 | class TestEMNIST(unittest.TestCase): 16 | def test_emnist(self): 17 | import pylab as pl 18 | # emnist_dataset = EMNIST() 19 | emnist_dataset = EMNIST() 20 | for imgs, labels in emnist_dataset: 21 | print(imgs.shape, labels.shape) 22 | print(labels.max(), labels.min()) 23 | # imgs = imgs.transpose(0,1,3,2).transpose(1,2,0,3)[0].reshape(28, -1) 24 | imgs = imgs.transpose(1,2,0,3)[0].reshape(28, -1) 25 | print(labels) 26 | pl.imshow(imgs), pl.show() 27 | break 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /python/jittor/test/test_flags.py: -------------------------------------------------------------------------------- 1 | # *************************************************************** 2 | # Copyright (c) 2023 Jittor. All Rights Reserved. 3 | # Maintainers: Dun Liang . 4 | # This file is subject to the terms and conditions defined in 5 | # file 'LICENSE.txt', which is part of this source code package. 
6 | # *************************************************************** 7 | import unittest 8 | import jittor as jt 9 | from .test_core import expect_error 10 | 11 | class TestFlags(unittest.TestCase): 12 | def test_error(self): 13 | def check(): jt.flags.asdasd=1 14 | expect_error(check) 15 | 16 | def test_get_set(self): 17 | prev = jt.flags.log_v 18 | jt.flags.log_v=1 19 | assert jt.flags.log_v == 1 20 | jt.flags.log_v=prev 21 | assert jt.flags.log_v == prev 22 | 23 | def test_scope(self): 24 | prev = jt.flags.log_v 25 | with jt.flag_scope(log_v=1): 26 | assert jt.flags.log_v == 1 27 | assert jt.flags.log_v == prev 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /python/jittor/test/test_jit_tests.py: -------------------------------------------------------------------------------- 1 | # *************************************************************** 2 | # Copyright (c) 2023 Jittor. All Rights Reserved. 3 | # Maintainers: Dun Liang . 4 | # This file is subject to the terms and conditions defined in 5 | # file 'LICENSE.txt', which is part of this source code package. 
6 | # *************************************************************** 7 | import unittest 8 | import jittor as jt 9 | from jittor import LOG 10 | 11 | def test(name): 12 | doc = eval(f"jt.tests.{name}.__doc__") 13 | doc = doc[doc.find("From"):].strip() 14 | LOG.i(f"Run test {name} {doc}") 15 | exec(f"jt.tests.{name}()") 16 | 17 | tests = [ name for name in dir(jt.tests) if not name.startswith("__") ] 18 | src = "class TestJitTests(unittest.TestCase):\n" 19 | for name in tests: 20 | doc = eval(f"jt.tests.{name}.__doc__") 21 | doc = doc[doc.find("From"):].strip() 22 | src += f""" 23 | def test_{name}(self): 24 | test("{name}") 25 | """ 26 | 27 | LOG.vvv("eval src\n"+src) 28 | exec(src) 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /python/jittor/test/test_lock.py: -------------------------------------------------------------------------------- 1 | # *************************************************************** 2 | # Copyright (c) 2023 Jittor. All Rights Reserved. 3 | # Maintainers: 4 | # Wenyang Zhou <576825820@qq.com> 5 | # Dun Liang . 6 | # 7 | # This file is subject to the terms and conditions defined in 8 | # file 'LICENSE.txt', which is part of this source code package. 
9 | # *************************************************************** 10 | import unittest 11 | import os, sys 12 | import jittor as jt 13 | import jittor_utils as jit_utils 14 | 15 | class TestLock(unittest.TestCase): 16 | def test(self): 17 | if os.environ.get('lock_full_test', '0') == '1': 18 | cache_path = os.path.join(jit_utils.home(), ".cache", "jittor", "lock") 19 | assert os.system(f"rm -rf {cache_path}") == 0 20 | cmd = f"cache_name=lock {sys.executable} -m jittor.test.test_example" 21 | else: 22 | cmd = f"{sys.executable} -m jittor.test.test_example" 23 | print("run cmd twice", cmd) 24 | assert os.system(f"{cmd} & {cmd} & wait %1 && wait %2") == 0 25 | 26 | 27 | if __name__ == "__main__": 28 | unittest.main() -------------------------------------------------------------------------------- /python/jittor/test/test_mkl_test_op.py: -------------------------------------------------------------------------------- 1 | # *************************************************************** 2 | # Copyright (c) 2023 Jittor. All Rights Reserved. 3 | # Maintainers: Dun Liang . 4 | # This file is subject to the terms and conditions defined in 5 | # file 'LICENSE.txt', which is part of this source code package. 6 | # *************************************************************** 7 | import unittest 8 | import jittor as jt 9 | import os 10 | 11 | @unittest.skipIf(not jt.compile_extern.use_mkl, "Not use mkl, Skip") 12 | class TestMklTestOp(unittest.TestCase): 13 | def test(self): 14 | assert jt.mkl_ops.mkl_test().data==123 15 | 16 | if __name__ == "__main__": 17 | unittest.main() 18 | -------------------------------------------------------------------------------- /python/jittor/test/test_nccl.py: -------------------------------------------------------------------------------- 1 | # *************************************************************** 2 | # Copyright (c) 2023 Jittor. All Rights Reserved. 3 | # Maintainers: 4 | # Dun Liang . 
5 | # 6 | # This file is subject to the terms and conditions defined in 7 | # file 'LICENSE.txt', which is part of this source code package. 8 | # *************************************************************** 9 | import jittor as jt 10 | import unittest 11 | 12 | @unittest.skipIf(jt.compile_extern.nccl_ops is None, "no nccl found") 13 | class TestNccl(unittest.TestCase): 14 | @jt.flag_scope(use_cuda=1) 15 | def test_nccl(self): 16 | assert jt.compile_extern.nccl_ops.nccl_test("").data == 123 17 | 18 | if __name__ == "__main__": 19 | unittest.main() 20 | -------------------------------------------------------------------------------- /python/jittor/utils/bench_klo.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 3 | jt.flags.use_device = 1 4 | n = 100000 5 | 6 | with jt.profile_scope(10, 10) as rep: 7 | jt.code([2], "float32", [], 8 | cuda_header='''__global__ void kernel(float* a) {}''', 9 | cuda_src=f''' 10 | for (int i=0; i<{n}; i++) kernel<<<1,32>>>(out0_p); 11 | ''').sync() 12 | 13 | avg_ns = float(rep[1][4]) / n 14 | print("kernel launch overhead(ns):", avg_ns) 15 | -------------------------------------------------------------------------------- /python/jittor/utils/converter_server.py: -------------------------------------------------------------------------------- 1 | from flask import Flask 2 | from flask import request 3 | from flask import jsonify 4 | app = Flask(__name__) 5 | import json 6 | 7 | from jittor.utils.pytorch_converter import convert 8 | 9 | @app.route('/', methods=["GET", "POST"]) 10 | def hello(): 11 | msg = request 12 | data = msg.data.decode("utf-8") 13 | try: 14 | data = json.loads(data) 15 | src = data["src"] 16 | pjmap = json.loads(data["pjmap"]) 17 | jt_src = convert(src, pjmap) 18 | except Exception as e: 19 | jt_src = str(e) 20 | response = jsonify(jt_src=jt_src) 21 | 22 | # Enable Access-Control-Allow-Origin 23 | response.headers.add("Access-Control-Allow-Origin", 
"*") 24 | return response 25 | 26 | if __name__ == '__main__': 27 | app.run(host="0.0.0.0") -------------------------------------------------------------------------------- /python/jittor/utils/data.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jittor/jittor/444154c2a1e63ee4a0af07831a0c54e2ebb7a561/python/jittor/utils/data.gz -------------------------------------------------------------------------------- /python/jittor/utils/dlink_compiler.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import re 4 | cmds = sys.argv[1:] 5 | def replace(cmds, s, t): 6 | return [ c.replace(s,t) for c in cmds ] 7 | def remove(cmds, ss): 8 | rets = [] 9 | for cmd in cmds: 10 | found = True 11 | for s in ss: 12 | if s in cmd: 13 | found = False 14 | break 15 | if found: 16 | rets.append(cmd) 17 | return rets 18 | 19 | cmds1 = remove(cmds, [".o"]) 20 | cmds1 = replace(cmds1, ".so", ".o") 21 | cmds2 = replace(cmds, "-dc", "") 22 | cmds2 = replace(cmds2, ".cu", ".o") 23 | ret = os.system(" ".join(cmds1).replace("-x cu", "")) 24 | if ret: exit(ret) 25 | ret = os.system(" ".join(cmds2).replace("-x cu", "")) 26 | if ret: exit(ret) -------------------------------------------------------------------------------- /python/jittor/utils/tracer.py: -------------------------------------------------------------------------------- 1 | # *************************************************************** 2 | # Copyright (c) 2023 Jittor. All Rights Reserved. 3 | # Maintainers: 4 | # Dun Liang . 5 | # 6 | # 7 | # This file is subject to the terms and conditions defined in 8 | # file 'LICENSE.txt', which is part of this source code package. 
9 | # *************************************************************** 10 | import jittor as jt 11 | 12 | def fill_module_name(m, name): 13 | ps = [] 14 | stack = [] 15 | def callback(parents, k, v, n): 16 | stack.append(str(k)) 17 | for k2, p in v.__dict__.items(): 18 | if isinstance(p, jt.Var): 19 | ps.append(p) 20 | p.name(".".join(stack[1:]+[str(k2)])) 21 | v._trace_name = str(k) 22 | def callback_leave(parents, k, v, n): 23 | stack.pop() 24 | m.dfs([], name, callback, callback_leave) 25 | -------------------------------------------------------------------------------- /python/jittor/vcompiler/__init__.py: -------------------------------------------------------------------------------- 1 | from .vcompiler import * -------------------------------------------------------------------------------- /python/jittor/version: -------------------------------------------------------------------------------- 1 | 939b29514b2e5cc591053aab614efd569772585d 2 | -------------------------------------------------------------------------------- /python/jittor_utils/class/motd: -------------------------------------------------------------------------------- 1 | ★★★★★★★★★★★★★★★★★★★★★ 2 | Welcome to use Jittor 3 | Please put the file under /root directory 4 | ★★★★★★★★★★★★★★★★★★★★★ 5 | 欢迎使用Jittor 6 | 请将文件放置在/root目录下 7 | 本docker已经安装好cuda环境 8 | 相关链接: 9 | * [Jittor官网](https://cg.cs.tsinghua.edu.cn/jittor/) 10 | * [Jittor教程](https://cg.cs.tsinghua.edu.cn/jittor/tutorial/) 11 | * [Jittor模型库](https://cg.cs.tsinghua.edu.cn/jittor/resources/) 12 | * [Jittor文档](https://cg.cs.tsinghua.edu.cn/jittor/assets/docs/index.html) 13 | * [Github](https://github.com/jittor/jittor), [Gitee](https://gitee.com/jittor/jittor) 14 | * [Jittor 论坛](https://discuss.jittor.org/) 15 | * 即时通信: QQ Group(761222083) 16 | 17 | 欢迎大家star,fork并在QQ群或者论坛向我们提出宝贵的意见和建议。 18 | 19 | 注意:请不要开启无密码保护的jupyter notebook或vscode server 20 | ★★★★★★★★★★★★★★★★★★★★★ 21 | 
-------------------------------------------------------------------------------- /python/jittor_utils/class/setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | command = sys.argv[1] 4 | if (command == 'ssh'): 5 | port = sys.argv[2] 6 | data = open("/etc/ssh/sshd_config", "r").readlines() 7 | data[12] = 'Port ' + port + '\nPermitRootLogin yes\n' 8 | f = open("/etc/ssh/sshd_config", "w") 9 | f.writelines(data) 10 | f.close() 11 | os.system("service ssh restart") 12 | elif (command == 'passwd'): 13 | passwd = sys.argv[2] 14 | os.system("echo root:"+passwd+" | chpasswd") 15 | else: 16 | print('command error') 17 | -------------------------------------------------------------------------------- /python/jittor_utils/github_release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | version=$1 4 | text=$2 5 | branch=$(git rev-parse --abbrev-ref HEAD) 6 | # repo_full_name=$(git config --get remote.origin.url | sed 's/.*:\/\/github.com\///;s/.git$//') 7 | repo_full_name=$(git config --get remote.origin.url | sed 's/.*github.com://;s/.git$//') 8 | token=$(git config --global github.token) 9 | 10 | generate_post_data() 11 | { 12 | cat <