├── .clang-format ├── .github ├── ISSUE_TEMPLATE │ └── a-new-kernel.md └── workflows │ └── build.yml ├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── Makefile ├── README.md ├── docs ├── develop │ └── index.md ├── images │ └── README-1.png └── index.md ├── scripts └── compare │ ├── compare.py │ ├── run_actual.py │ ├── run_onnx.py │ └── validate.sh ├── src ├── 00common │ ├── CMakeLists.txt │ ├── README.md │ ├── include │ │ ├── common.h │ │ └── common │ │ │ ├── bf16_t.h │ │ │ ├── data_type.h │ │ │ ├── error_handler.h │ │ │ ├── fp16_t.h │ │ │ ├── natural.h │ │ │ ├── range.h │ │ │ ├── rc.hpp │ │ │ └── slice.h │ ├── src │ │ └── data_type.cc │ └── test │ │ └── test.cpp ├── 01graph_topo │ ├── CMakeLists.txt │ ├── README.md │ ├── include │ │ ├── graph_topo.h │ │ └── graph_topo │ │ │ ├── builder.hpp │ │ │ ├── container.h │ │ │ ├── inplace_modifier.h │ │ │ ├── linked_graph.hpp │ │ │ ├── polymorph_graph.hpp │ │ │ └── searcher.hh │ ├── src │ │ ├── container.cc │ │ ├── inplace_modifier.cc │ │ └── searcher.cc │ └── test │ │ ├── test_graph_topo.cpp │ │ ├── test_linked.cpp │ │ ├── test_modifier.cpp │ │ └── topo.h ├── 02hardware │ ├── CMakeLists.txt │ ├── README.md │ ├── include │ │ └── hardware │ │ │ ├── device.h │ │ │ ├── device_manager.h │ │ │ ├── devices │ │ │ ├── cpu.h │ │ │ ├── mlu.h │ │ │ └── nvidia.h │ │ │ ├── functions.h │ │ │ ├── mem_offset_calculator.h │ │ │ ├── mem_pool.h │ │ │ └── memory.h │ ├── src │ │ ├── device.cc │ │ ├── device_manager.cpp │ │ ├── devices │ │ │ ├── cpu │ │ │ │ ├── device.cc │ │ │ │ ├── memory.cc │ │ │ │ └── memory.hh │ │ │ ├── mlu │ │ │ │ ├── device.cc │ │ │ │ ├── functions.cc │ │ │ │ ├── functions.hh │ │ │ │ ├── memory.cc │ │ │ │ └── memory.hh │ │ │ └── nvidia │ │ │ │ ├── device.cc │ │ │ │ ├── memory.cc │ │ │ │ └── memory.hh │ │ ├── mem_offset_calculator.cc │ │ └── mem_pool.cc │ └── test │ │ ├── test_mem_offset_calculator.cpp │ │ └── test_mem_pool.cpp ├── 03runtime │ ├── CMakeLists.txt │ ├── README.md │ ├── include │ │ └── runtime 
│ │ │ ├── resource.h │ │ │ └── stream.h │ └── src │ │ ├── resource.cpp │ │ └── stream.cc ├── 04kernel │ ├── CMakeLists.txt │ ├── README.md │ ├── cmake │ │ └── FindNCCL.cmake │ ├── cuda │ │ ├── CMakeLists.txt │ │ ├── include │ │ │ └── kernel │ │ │ │ └── cuda │ │ │ │ ├── concat.cuh │ │ │ │ ├── expand.cuh │ │ │ │ ├── functions.cuh │ │ │ │ ├── gather.cuh │ │ │ │ ├── pad.cuh │ │ │ │ ├── scatter_nd.cuh │ │ │ │ ├── slice.cuh │ │ │ │ ├── split.cuh │ │ │ │ ├── threads_distributer.cuh │ │ │ │ ├── transpose.cuh │ │ │ │ └── where.cuh │ │ └── src │ │ │ ├── concat.cu │ │ │ ├── expand.cu │ │ │ ├── functions.cu │ │ │ ├── gather.cu │ │ │ ├── macro.cuh │ │ │ ├── pad.cu │ │ │ ├── scatter_nd.cu │ │ │ ├── slice.cu │ │ │ ├── split.cu │ │ │ ├── threads_distributer.cu │ │ │ ├── transpose.cu │ │ │ └── where.cu │ ├── include │ │ └── kernel │ │ │ ├── allocators.h │ │ │ ├── attributes │ │ │ ├── broadcaster.h │ │ │ ├── communication.h │ │ │ ├── expand_info.h │ │ │ ├── gather_info.h │ │ │ ├── mat_mul_info.h │ │ │ ├── mat_mul_integer_info.h │ │ │ ├── pad_info.h │ │ │ ├── pool_attributes.h │ │ │ ├── scatter_nd_info.h │ │ │ ├── slice_info.h │ │ │ ├── softmax_info.h │ │ │ ├── split_info.h │ │ │ └── transpose_info.h │ │ │ ├── blob.hh │ │ │ ├── collector.h │ │ │ ├── collectors │ │ │ ├── all_reduce.h │ │ │ ├── attention.h │ │ │ ├── batch_normalization.h │ │ │ ├── cast.h │ │ │ ├── clip.h │ │ │ ├── concat.h │ │ │ ├── conv.h │ │ │ ├── dequantize_linear.h │ │ │ ├── dynamic_quantize_linear.h │ │ │ ├── gather.h │ │ │ ├── global_pool.h │ │ │ ├── hard_sigmoid.h │ │ │ ├── mat_mul.h │ │ │ ├── mat_mul_integer.h │ │ │ ├── pad.h │ │ │ ├── pool.h │ │ │ ├── reduce.h │ │ │ ├── rms_normalization.h │ │ │ ├── scatter_nd.h │ │ │ ├── select.h │ │ │ ├── simple_binary.h │ │ │ ├── simple_unary.h │ │ │ ├── slice.h │ │ │ ├── softmax.h │ │ │ ├── split.h │ │ │ ├── transpose.h │ │ │ └── where.h │ │ │ ├── graph.h │ │ │ ├── kernel.h │ │ │ ├── layout.h │ │ │ └── tensor.h │ ├── src │ │ ├── allocators │ │ │ ├── flat_allocator.cpp │ │ 
│ └── reusable_allocator.cpp │ │ ├── attributes │ │ │ ├── broadcaster.cc │ │ │ ├── expand_info.cc │ │ │ ├── gather_info.cc │ │ │ ├── mat_mul_info.cc │ │ │ ├── mat_mul_integer_info.cc │ │ │ ├── pad_2d_info.cc │ │ │ ├── pad_2d_info.h │ │ │ ├── pad_info.cc │ │ │ ├── pool_attributes.cc │ │ │ ├── scatter_nd_info.cc │ │ │ ├── slice_info.cc │ │ │ ├── softmax_info.cc │ │ │ ├── split_info.cc │ │ │ └── transpose_info.cc │ │ ├── blob.cc │ │ ├── collectors │ │ │ ├── all_reduce.cc │ │ │ ├── attention.cc │ │ │ ├── batch_normalization.cc │ │ │ ├── cast.cc │ │ │ ├── clip.cc │ │ │ ├── concat.cc │ │ │ ├── conv.cc │ │ │ ├── dequantize_linear.cc │ │ │ ├── dynamic_quantize_linear.cc │ │ │ ├── gather.cc │ │ │ ├── global_pool.cc │ │ │ ├── hard_sigmoid.cc │ │ │ ├── mat_mul.cc │ │ │ ├── mat_mul_integer.cc │ │ │ ├── pad.cc │ │ │ ├── pool.cc │ │ │ ├── reduce.cc │ │ │ ├── rms_normalization.cc │ │ │ ├── scatter_nd.cc │ │ │ ├── select.cc │ │ │ ├── simple_binary.cc │ │ │ ├── simple_unary.cc │ │ │ ├── slice.cc │ │ │ ├── softmax.cc │ │ │ ├── split.cc │ │ │ ├── transpose.cc │ │ │ └── where.cc │ │ ├── generator │ │ │ ├── nvrtc_repo.cc │ │ │ └── nvrtc_repo.h │ │ ├── graph.cc │ │ ├── kernel.cc │ │ ├── kernels │ │ │ ├── all_reduce │ │ │ │ ├── nccl_kernel.cc │ │ │ │ ├── nccl_kernel.cu │ │ │ │ └── nccl_kernel.hh │ │ │ ├── attention │ │ │ │ ├── cuda_kernel.cc │ │ │ │ └── cuda_kernel.hh │ │ │ ├── batch_normalization │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cudnn_kernel.cc │ │ │ │ └── cudnn_kernel.hh │ │ │ ├── cast │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ └── cuda_kernel.hh │ │ │ ├── clip │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ ├── cuda_kernel.cu │ │ │ │ └── cuda_kernel.hh │ │ │ ├── concat │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ └── cuda_kernel.hh │ │ │ ├── conv │ │ │ │ ├── cudnn_kernel.cc │ │ │ │ └── cudnn_kernel.hh │ │ │ ├── dequantize_linear │ 
│ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ └── cuda_kernel.hh │ │ │ ├── dynamic_quantize_linear │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ ├── cuda_kernel.cu │ │ │ │ └── cuda_kernel.hh │ │ │ ├── expand │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ ├── cuda_kernel.cu │ │ │ │ └── cuda_kernel.hh │ │ │ ├── extra_padding │ │ │ │ ├── extra_padding.cu │ │ │ │ └── extra_padding.cuh │ │ │ ├── gather │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ └── cuda_kernel.hh │ │ │ ├── hard_sigmoid │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ └── cuda_kernel.hh │ │ │ ├── mat_mul │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cublas_kernel.cc │ │ │ │ ├── cublas_kernel.cu │ │ │ │ └── cublas_kernel.hh │ │ │ ├── mat_mul_common │ │ │ │ └── cpu_template.hpp │ │ │ ├── mat_mul_integer │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cublas_kernel.cc │ │ │ │ ├── cublas_kernel.cu │ │ │ │ └── cublas_kernel.hh │ │ │ ├── pad │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ ├── cuda_kernel.cu │ │ │ │ └── cuda_kernel.hh │ │ │ ├── pool │ │ │ │ ├── cudnn_kernel.cc │ │ │ │ └── cudnn_kernel.hh │ │ │ ├── reduce │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cudnn_kernel.cc │ │ │ │ └── cudnn_kernel.hh │ │ │ ├── rms_normalization │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ └── cuda_kernel.hh │ │ │ ├── scatter_nd │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ ├── cuda_kernel.cu │ │ │ │ └── cuda_kernel.hh │ │ │ ├── select │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ └── cuda_kernel.hh │ │ │ ├── simple_binary │ │ │ │ ├── binary_cudnn.cc │ │ │ │ ├── binary_cudnn.hh │ │ │ │ ├── cpu_kernel.cc │ │ │ │ 
├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ └── cuda_kernel.hh │ │ │ ├── simple_unary │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ ├── cuda_kernel.hh │ │ │ │ ├── cudnn_activation_kernel.cc │ │ │ │ └── cudnn_activation_kernel.hh │ │ │ ├── slice │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ ├── cuda_kernel.cu │ │ │ │ └── cuda_kernel.hh │ │ │ ├── softmax │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ ├── cuda_kernel.cu │ │ │ │ ├── cuda_kernel.hh │ │ │ │ ├── cudnn_kernel.cc │ │ │ │ └── cudnn_kernel.hh │ │ │ ├── split │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ └── cuda_kernel.hh │ │ │ ├── transpose │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── cuda_kernel.cc │ │ │ │ ├── cuda_kernel.cu │ │ │ │ └── cuda_kernel.hh │ │ │ └── where │ │ │ │ ├── cpu_kernel.cc │ │ │ │ ├── cpu_kernel.hh │ │ │ │ ├── where_cuda.cc │ │ │ │ ├── where_cuda.cu │ │ │ │ └── where_cuda.hh │ │ ├── tensor.cc │ │ └── utilities │ │ │ └── cuda │ │ │ ├── cublas_context.cu │ │ │ ├── cublas_context.hh │ │ │ ├── cublaslt_context.cu │ │ │ ├── cublaslt_context.hh │ │ │ ├── cudnn_context.cu │ │ │ ├── cudnn_context.hh │ │ │ ├── cudnn_functions.cpp │ │ │ ├── cudnn_functions.h │ │ │ ├── nccl_communicator.cu │ │ │ └── nccl_communicator.hh │ └── test │ │ ├── attributes │ │ ├── test_broadcaster.cpp │ │ ├── test_expand_info.cpp │ │ ├── test_gather_info.cpp │ │ ├── test_pool_attributes.cpp │ │ ├── test_scatter_nd_info.cpp │ │ ├── test_split_info.cpp │ │ └── test_transpose_info.cpp │ │ ├── generator │ │ └── test_cuda.cpp │ │ └── kernels │ │ ├── all_reduce │ │ └── test_allreduce_nccl.cpp │ │ ├── cast │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp │ │ ├── clip │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp │ │ ├── concat │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp │ │ ├── dequantize_linear │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp │ │ ├── 
dynamic_quantize_linear │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp │ │ ├── expand │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp │ │ ├── gather │ │ ├── test_gather_cpu.cpp │ │ └── test_gather_cuda.cpp │ │ ├── hard_sigmoid │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp │ │ ├── mat_mul │ │ ├── test_cpu.cpp │ │ └── test_cublas.cpp │ │ ├── mat_mul_integer │ │ ├── test_cpu_kernel.cpp │ │ └── test_cublas_kernel.cpp │ │ ├── pad │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp │ │ ├── pool │ │ └── test_cudnn.cpp │ │ ├── reduce │ │ ├── test_cpu.cpp │ │ └── test_cudnn.cpp │ │ ├── rms_normalization │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp │ │ ├── scatter_nd │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp │ │ ├── select │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp │ │ ├── simple_binary │ │ ├── test_binary_cpu.cpp │ │ ├── test_binary_cuda.cpp │ │ └── test_binary_cudnn.cpp │ │ ├── simple_unary │ │ ├── test_cpu.cpp │ │ ├── test_cuda.cpp │ │ └── test_cudnn.cpp │ │ ├── slice │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp │ │ ├── softmax │ │ ├── test_cpu.cpp │ │ ├── test_cuda.cpp │ │ └── test_cudnn.cpp │ │ ├── split │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp │ │ ├── transpose │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp │ │ └── where │ │ ├── test_cpu.cpp │ │ └── test_cuda.cpp ├── 05computation │ ├── CMakeLists.txt │ ├── README.md │ ├── include │ │ └── computation │ │ │ ├── graph.h │ │ │ ├── operator.h │ │ │ └── operators │ │ │ ├── all_reduce.h │ │ │ ├── attention.h │ │ │ ├── batch_normalization.h │ │ │ ├── broadcast.h │ │ │ ├── cast.h │ │ │ ├── clip.h │ │ │ ├── compair.h │ │ │ ├── concat.h │ │ │ ├── conv.h │ │ │ ├── cum_sum.h │ │ │ ├── dequantize_linear.h │ │ │ ├── dynamic_quantize_linear.h │ │ │ ├── einsum.h │ │ │ ├── gather.h │ │ │ ├── gather_elements.h │ │ │ ├── global_pool.h │ │ │ ├── hard_sigmoid.h │ │ │ ├── identity.h │ │ │ ├── mat_mul.h │ │ │ ├── mat_mul_integer.h │ │ │ ├── pad.h │ │ │ ├── pool.h │ │ │ ├── reduce.h │ │ │ ├── reshape.h │ │ │ ├── rms_normalization.h │ │ │ ├── scatter_nd.h │ │ │ ├── 
select.h │ │ │ ├── simple_binary.h │ │ │ ├── simple_unary.h │ │ │ ├── slice.h │ │ │ ├── softmax.h │ │ │ ├── split.h │ │ │ ├── transpose.h │ │ │ └── where.h │ ├── src │ │ ├── graph.cc │ │ ├── operator.cc │ │ ├── operators │ │ │ ├── all_reduce.cc │ │ │ ├── attention.cc │ │ │ ├── batch_normalization.cc │ │ │ ├── broadcast.cc │ │ │ ├── cast.cc │ │ │ ├── clip.cc │ │ │ ├── compair.cc │ │ │ ├── concat.cc │ │ │ ├── conv.cc │ │ │ ├── cum_sum.cc │ │ │ ├── dequantize_linear.cc │ │ │ ├── dynamic_quantize_linear.cc │ │ │ ├── einsum.cc │ │ │ ├── gather.cc │ │ │ ├── gather_elements.cc │ │ │ ├── global_pool.cc │ │ │ ├── hard_sigmoid.cc │ │ │ ├── identity.cc │ │ │ ├── mat_mul.cc │ │ │ ├── mat_mul_integer.cc │ │ │ ├── pad.cc │ │ │ ├── pool.cc │ │ │ ├── reduce.cc │ │ │ ├── reshape.cc │ │ │ ├── rms_normalization.cc │ │ │ ├── scatter_nd.cc │ │ │ ├── select.cc │ │ │ ├── simple_binary.cc │ │ │ ├── simple_unary.cc │ │ │ ├── slice.cc │ │ │ ├── softmax.cc │ │ │ ├── split.cc │ │ │ ├── transpose.cc │ │ │ └── where.cc │ │ └── transfomation │ │ │ └── layout_permutation.cpp │ └── test │ │ └── test_transpose.cpp ├── 06frontend │ ├── CMakeLists.txt │ ├── README.md │ ├── include │ │ └── frontend │ │ │ ├── graph.h │ │ │ ├── infer.h │ │ │ ├── operator.h │ │ │ └── tensor.h │ ├── src │ │ ├── common.cpp │ │ ├── graph.cc │ │ ├── infer.cc │ │ ├── operator.cc │ │ └── tensor.cc │ └── test │ │ └── test_subgraph.cpp.bak ├── 07onnx │ ├── CMakeLists.txt │ ├── README.md │ ├── include │ │ └── onnx │ │ │ └── operators.h │ ├── src │ │ ├── operators.cpp │ │ └── operators │ │ │ ├── batch_normalization.cc │ │ │ ├── batch_normalization.hh │ │ │ ├── cast.cc │ │ │ ├── cast.hh │ │ │ ├── clip.cc │ │ │ ├── clip.hh │ │ │ ├── common.cpp │ │ │ ├── common.h │ │ │ ├── compair.cc │ │ │ ├── compair.hh │ │ │ ├── concat.cc │ │ │ ├── concat.hh │ │ │ ├── constant.cc │ │ │ ├── constant.hh │ │ │ ├── constant_of_shape.cc │ │ │ ├── constant_of_shape.hh │ │ │ ├── conv.cc │ │ │ ├── conv.hh │ │ │ ├── cum_sum.cc │ │ │ ├── cum_sum.hh │ │ │ ├── 
dequantize_linear.cc │ │ │ ├── dequantize_linear.hh │ │ │ ├── dynamic_quantize_linear.cc │ │ │ ├── dynamic_quantize_linear.hh │ │ │ ├── einsum.cc │ │ │ ├── einsum.hh │ │ │ ├── expand.cc │ │ │ ├── expand.hh │ │ │ ├── flatten.cc │ │ │ ├── flatten.hh │ │ │ ├── gather.cc │ │ │ ├── gather.hh │ │ │ ├── gather_elements.cc │ │ │ ├── gather_elements.hh │ │ │ ├── gemm.cc │ │ │ ├── gemm.hh │ │ │ ├── global_pool.cc │ │ │ ├── global_pool.hh │ │ │ ├── hard_sigmoid.cc │ │ │ ├── hard_sigmoid.hh │ │ │ ├── mat_mul.cc │ │ │ ├── mat_mul.hh │ │ │ ├── mat_mul_integer.cc │ │ │ ├── mat_mul_integer.hh │ │ │ ├── pad.cc │ │ │ ├── pad.hh │ │ │ ├── pool.cc │ │ │ ├── pool.hh │ │ │ ├── pool_type.h │ │ │ ├── range.cc │ │ │ ├── range.hh │ │ │ ├── reduce.cc │ │ │ ├── reduce.hh │ │ │ ├── reshape.cc │ │ │ ├── reshape.hh │ │ │ ├── scatter_nd.cc │ │ │ ├── scatter_nd.hh │ │ │ ├── select.cc │ │ │ ├── select.hh │ │ │ ├── shape.cc │ │ │ ├── shape.hh │ │ │ ├── simple_binary.cc │ │ │ ├── simple_binary.hh │ │ │ ├── simple_unary.cc │ │ │ ├── simple_unary.hh │ │ │ ├── slice.cc │ │ │ ├── slice.hh │ │ │ ├── softmax.cc │ │ │ ├── softmax.hh │ │ │ ├── split.cc │ │ │ ├── split.hh │ │ │ ├── squeeze.cc │ │ │ ├── squeeze.hh │ │ │ ├── tile.cc │ │ │ ├── tile.hh │ │ │ ├── transpose.cc │ │ │ ├── transpose.hh │ │ │ ├── unsqueeze.cc │ │ │ ├── unsqueeze.hh │ │ │ ├── where.cc │ │ │ └── where.hh │ └── test │ │ ├── test_clip.cpp │ │ ├── test_concat.cpp │ │ ├── test_conv.cpp │ │ ├── test_cum_sum.cpp │ │ ├── test_einsum.cpp │ │ ├── test_expand.cpp │ │ ├── test_hard_sigmoid.cpp │ │ ├── test_pad.cpp │ │ ├── test_shape.cpp │ │ ├── test_simple_unary.cpp │ │ ├── test_squeeze.cpp │ │ └── test_unsqueeze.cpp ├── 08-01llm │ ├── CMakeLists.txt │ ├── README.md │ ├── include │ │ └── llm │ │ │ └── operators.h │ ├── src │ │ ├── operators.cpp │ │ └── operators │ │ │ ├── attention.cc │ │ │ ├── attention.hh │ │ │ ├── common.h │ │ │ ├── mat_mul.cc │ │ │ ├── mat_mul.hh │ │ │ ├── rms_normalization.cc │ │ │ └── rms_normalization.hh │ └── test │ │ └── 
test_rms_normalization.cpp ├── 08communication │ ├── CMakeLists.txt │ ├── README.md │ ├── include │ │ └── communication │ │ │ └── operators.h │ └── src │ │ ├── operators.cpp │ │ └── operators │ │ ├── all_gather.cc │ │ ├── all_gather.hh │ │ ├── all_reduce.cc │ │ ├── all_reduce.hh │ │ └── common.h └── 09python_ffi │ ├── CMakeLists.txt │ ├── README.md │ ├── pyproject.toml │ └── src │ ├── compiler.cc │ ├── compiler.h │ ├── executor.cc │ ├── executor.h │ ├── functions.cpp │ ├── functions.h │ ├── import.cpp │ ├── import.h │ ├── main.cpp │ └── refactor_graph │ ├── __init__.py │ └── onnx.py └── utilities ├── .cargo └── config.toml ├── .gitignore ├── Cargo.lock ├── Cargo.toml └── src ├── format.rs ├── infer.rs ├── main.rs └── make.rs /.github/ISSUE_TEMPLATE/a-new-kernel.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: A new kernel 3 | about: Add new kernel to the project. 4 | title: "[kernel]" 5 | labels: kernel 6 | assignees: '' 7 | 8 | --- 9 | 10 | - [ ] 单元测试测例 11 | - [ ] cpu kernel 12 | - [ ] cuda kernel 13 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build and test cpu 2 | on: 3 | push: 4 | paths-ignore: 5 | - '**.md' 6 | - 'LICENSE' 7 | pull_request: 8 | paths: 9 | - '**.md' 10 | - 'LICENSE' 11 | 12 | jobs: 13 | build: 14 | name: Build 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | type: [debug, release] 20 | steps: 21 | 22 | - uses: actions/checkout@v3 23 | with: 24 | submodules: recursive 25 | 26 | - name: build ${{ matrix.type }} 27 | run: make TYPE=${{ matrix.type }} 28 | 29 | - name: test ${{ matrix.type }} 30 | run: make test TYPE=${{ matrix.type }} 31 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | .vscode/ 2 | .cache/ 3 | /build/ 4 | 5 | *.egg-info/ 6 | __pycache__/ 7 | *.so 8 | *.log 9 | *.onnx 10 | *.pb 11 | *.bin 12 | *.npy 13 | *.meta 14 | 15 | compile_commands.json 16 | 17 | /scripts/*.py 18 | !/scripts/format.py 19 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rd-party/fmt"] 2 | path = 3rd-party/fmt 3 | url = git@github.com:fmtlib/fmt.git 4 | [submodule "3rd-party/googletest"] 5 | path = 3rd-party/googletest 6 | url = git@github.com:google/googletest.git 7 | [submodule "3rd-party/backward-cpp"] 8 | path = 3rd-party/backward-cpp 9 | url = git@github.com:bombela/backward-cpp.git 10 | [submodule "3rd-party/fmtlog"] 11 | path = 3rd-party/fmtlog 12 | url = git@github.com:MengRao/fmtlog.git 13 | [submodule "3rd-party/result"] 14 | path = 3rd-party/result 15 | url = git@github.com:willowell/result.git 16 | [submodule "3rd-party/abseil-cpp"] 17 | path = 3rd-party/abseil-cpp 18 | url = git@github.com:abseil/abseil-cpp.git 19 | [submodule "src/09python_ffi/pybind11"] 20 | path = src/09python_ffi/pybind11 21 | url = git@github.com:pybind/pybind11.git 22 | [submodule "3rd-party/cccl"] 23 | path = 3rd-party/cccl 24 | url = git@github.com:NVIDIA/cccl.git 25 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY : build install-python reconfig clean clean-log format test 2 | 3 | TYPE ?= Debug 4 | CUDA ?= OFF 5 | KUNLUN ?= OFF 6 | BANG ?= OFF 7 | 8 | CMAKE_EXTRA = 9 | # CMAKE_EXTRA += -DCMAKE_CXX_COMPILER= 10 | 11 | build: 12 | mkdir -p build 13 | cmake -Bbuild -DCMAKE_BUILD_TYPE=$(TYPE) -DUSE_CUDA=$(CUDA) -DUSE_KUNLUN=$(KUNLUN) -DUSE_BANG=$(BANG) $(CMAKE_EXTRA) 14 | make -j -C build 15 | 16 | install-python: 
build 17 | cp build/src/09python_ffi/python_ffi*.so src/09python_ffi/src/refactor_graph 18 | pip install -e src/09python_ffi/ 19 | 20 | reconfig: 21 | @rm -f build/CMakeCache.txt 22 | @rm -rf build/CMakeFiles 23 | @echo "configuration cache removed." 24 | 25 | clean: 26 | rm -rf build 27 | 28 | clean-log: 29 | rm -rf log 30 | 31 | test: 32 | make test -j -Cbuild 33 | -------------------------------------------------------------------------------- /docs/images/README-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InfiniTensor/RefactorGraph/9eddc91ebfbac72fe0c6f78b21243c7cd20a153b/docs/images/README-1.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # 文档 2 | 3 | ## 目录 4 | 5 | - [开发文档](develop/index.md) 6 | - [参考资料](#参考资料) 7 | - [Nvidia 优化](#nvidia-优化) 8 | 9 | ## 参考资料 10 | 11 | ### Nvidia 优化 12 | 13 | - [深入浅出 GPU 优化系列](https://www.zhihu.com/people/ban-zhuan-yuan-shi-chao-ji-gun/posts) 14 | - [Transpose](https://zhuanlan.zhihu.com/p/582664676) 15 | - [Reduce](https://zhuanlan.zhihu.com/p/559549740) 16 | -------------------------------------------------------------------------------- /scripts/compare/validate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Save the current working directory 4 | original_directory=$(pwd) 5 | 6 | # Get the directory of the script 7 | script_directory=$(dirname "$(readlink -f "$0")") 8 | 9 | # Change to the script's directory 10 | cd "$script_directory" 11 | 12 | if [ "$#" -eq 1 ]; then 13 | echo "Validating model: $1. Random inputs will be generated." 14 | python run_actual.py --model $1 --gen_input 15 | python run_onnx.py --model $1 16 | python compare.py > result.txt 17 | echo "Compare results saved in result.txt." 
18 | else 19 | echo "Please provide an onnx file path as a single argument." 20 | fi 21 | 22 | # Change back to the original working directory 23 | cd "$original_directory" 24 | -------------------------------------------------------------------------------- /src/00common/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12 FATAL_ERROR) 2 | project(common VERSION 0.0.0 LANGUAGES CXX) 3 | message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION}) 4 | 5 | file(GLOB_RECURSE COMMON_SRC src/*.cc) 6 | add_library(common STATIC ${COMMON_SRC}) 7 | target_link_libraries(common PUBLIC fmt) 8 | target_include_directories(common PUBLIC include) 9 | 10 | file(GLOB_RECURSE COMMON_TEST test/*.cpp) 11 | if(COMMON_TEST) 12 | add_executable(common_test ${COMMON_TEST}) 13 | add_test(common_test common_test) 14 | target_link_libraries(common_test common GTest::gtest_main Backward::Object) 15 | endif() 16 | -------------------------------------------------------------------------------- /src/00common/README.md: -------------------------------------------------------------------------------- 1 | # 共享库 2 | 3 | 提供一些基础的类型和函数定义。包括: 4 | 5 | | 头文件 | 功能 6 | |:------|:- 7 | | [bf16_t.h](include/common/bf16_t.h) | 定义 bf16 结构体。 8 | | [data_type.h](include/common/data_type.h) | 定义 DataType 类型,包含支持的数据类型枚举、与类型关键字的对应关系、计算大小等功能。 9 | | [error_handler.h](include/common/error_handler.h) | 定义异常类型,以及构造各类异常信息的宏。 10 | | [fp16_t.h](include/common/fp16_t.h) | 定义 fp16 结构体(half,符合 IEEE754 的 16 位浮点数类型)。 11 | | [natural.h](include/common/natural.h) | 定义自然数迭代器,支持生成从指定整数开始递增的数字。 12 | | [range.h](include/common/range.h) | 定义自然数范围,表示指定开始结束的一串自然数并提供范围上的自然数迭代器。 13 | | [rc.h](include/common/rc.h) | 定义非原子的引用计数智能指针。 14 | | [slice.h](include/common/slice.h) | 提供类似 `std::range` 的功能,用开始和结束指针指示一段连续内存并提供指针作为迭代器。 15 | -------------------------------------------------------------------------------- /src/00common/test/test.cpp: 
-------------------------------------------------------------------------------- 1 | #include "common.h" 2 | #include 3 | 4 | using namespace refactor; 5 | 6 | TEST(graph_topo, Builder) { 7 | float val = 2047; 8 | fp16_t ans(val); 9 | EXPECT_EQ(ans.to_f32(), val); 10 | fmt::println("{}", ans.to_string()); 11 | } 12 | -------------------------------------------------------------------------------- /src/01graph_topo/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12 FATAL_ERROR) 2 | project(graph_topo VERSION 0.0.0 LANGUAGES CXX) 3 | message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION}) 4 | 5 | file(GLOB_RECURSE GRAPH_TOPO_SRC src/*.cc src/*.cpp) 6 | add_library(graph_topo STATIC ${GRAPH_TOPO_SRC}) 7 | target_link_libraries(graph_topo PUBLIC common) 8 | target_include_directories(graph_topo PUBLIC include) 9 | 10 | file(GLOB_RECURSE GRAPH_TOPO_TEST test/*.cpp) 11 | if(GRAPH_TOPO_TEST) 12 | add_executable(graph_topo_test ${GRAPH_TOPO_TEST}) 13 | add_test(graph_topo_test graph_topo_test) 14 | target_link_libraries(graph_topo_test graph_topo GTest::gtest_main Backward::Object) 15 | endif() 16 | -------------------------------------------------------------------------------- /src/01graph_topo/README.md: -------------------------------------------------------------------------------- 1 | # 图拓扑 2 | 3 | 图拓扑结构的增、改、查,包含且仅包含图的拓扑结构。 4 | 5 | 核心结构通过 [*include/graph_topo/container.hpp*](include/graph_topo/container.hpp) 中的 `GraphTopo` 模板类构造和存储。 6 | 7 | 由于拓扑结构内部不保存所有连接关系,要快速查询连接关系就需要通过其他结构的缓存。[*include/graph_topo/searcher.hpp*](include/graph_topo/searcher.hpp) 中定义的 `Searcher` 结构可以缓存图的全局输入输出、每个节点的输入输出、节点和节点之间的前驱后继关系等信息,当 `GraphTopo` 构造完毕,可以利用它构造一个 *Searcher* 对象,在访问拓扑结构时也直接通过 *Searcher* 访问。 8 | -------------------------------------------------------------------------------- /src/01graph_topo/include/graph_topo.h: 
-------------------------------------------------------------------------------- 1 | #ifndef GRAPH_TOPO_GRAPH_TOPO_H 2 | #define GRAPH_TOPO_GRAPH_TOPO_H 3 | 4 | #include "graph_topo/builder.hpp" 5 | #include "graph_topo/inplace_modifier.h" 6 | #include "graph_topo/polymorph_graph.hpp" 7 | #include "graph_topo/searcher.hh" 8 | 9 | #endif// GRAPH_TOPO_GRAPH_TOPO_H 10 | -------------------------------------------------------------------------------- /src/02hardware/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12 FATAL_ERROR) 2 | project(hardware VERSION 0.0.0 LANGUAGES CXX) 3 | message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION}) 4 | 5 | file(GLOB_RECURSE HARDWARE_SRC src/*.cc src/*.cpp) 6 | add_library(hardware STATIC ${HARDWARE_SRC} ${HARDWARE_CUDA_SRC}) 7 | target_link_libraries(hardware PUBLIC common) 8 | target_include_directories(hardware PUBLIC include) 9 | 10 | if(USE_CUDA) 11 | target_include_directories(hardware PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 12 | endif() 13 | 14 | file(GLOB_RECURSE HARDWARE_TEST test/*.cpp) 15 | if(HARDWARE_TEST) 16 | add_executable(hardware_test ${HARDWARE_TEST}) 17 | add_test(hardware_test hardware_test) 18 | target_link_libraries(hardware_test hardware GTest::gtest_main Backward::Object) 19 | endif() 20 | -------------------------------------------------------------------------------- /src/02hardware/README.md: -------------------------------------------------------------------------------- 1 | # 计算设备抽象层 2 | -------------------------------------------------------------------------------- /src/02hardware/include/hardware/device_manager.h: -------------------------------------------------------------------------------- 1 | #ifndef HARDWARE_DEVICE_MANAGER_H 2 | #define HARDWARE_DEVICE_MANAGER_H 3 | 4 | #include "device.h" 5 | 6 | namespace refactor::hardware::device { 7 | 8 | Arc fetch(Device::Type); 9 | Arc 
fetch(Device::Type, int32_t card); 10 | Arc init(Device::Type, int32_t card, std::string_view args); 11 | 12 | }// namespace refactor::hardware::device 13 | 14 | #endif// HARDWARE_DEVICE_MANAGER_H 15 | -------------------------------------------------------------------------------- /src/02hardware/include/hardware/devices/cpu.h: -------------------------------------------------------------------------------- 1 | #ifndef HARDWARE_DEVICES_CPU_H 2 | #define HARDWARE_DEVICES_CPU_H 3 | 4 | #include "../device.h" 5 | 6 | namespace refactor::hardware { 7 | 8 | class Cpu final : public Device { 9 | public: 10 | Cpu(); 11 | 12 | Type type() const noexcept final { 13 | return Type::Cpu; 14 | } 15 | }; 16 | 17 | }// namespace refactor::hardware 18 | 19 | #endif// HARDWARE_DEVICES_CPU_H 20 | -------------------------------------------------------------------------------- /src/02hardware/include/hardware/devices/mlu.h: -------------------------------------------------------------------------------- 1 | #ifndef HARDWARE_DEVICES_MLU_H 2 | #define HARDWARE_DEVICES_MLU_H 3 | 4 | #include "../device.h" 5 | 6 | namespace refactor::hardware { 7 | 8 | class Mlu final : public Device { 9 | public: 10 | explicit Mlu(int32_t card); 11 | void setContext() const noexcept final; 12 | Type type() const noexcept final { 13 | return Type::Mlu; 14 | } 15 | }; 16 | 17 | }// namespace refactor::hardware 18 | 19 | #endif// HARDWARE_DEVICES_MLU_H 20 | -------------------------------------------------------------------------------- /src/02hardware/include/hardware/devices/nvidia.h: -------------------------------------------------------------------------------- 1 | #ifndef HARDWARE_DEVICES_NVIDIA_H 2 | #define HARDWARE_DEVICES_NVIDIA_H 3 | 4 | #include "../device.h" 5 | 6 | namespace refactor::hardware { 7 | 8 | class Nvidia final : public Device { 9 | public: 10 | explicit Nvidia(int32_t card); 11 | void setContext() const final; 12 | Type type() const noexcept final { 13 | return Type::Nvidia; 14 | 
} 15 | }; 16 | 17 | }// namespace refactor::hardware 18 | 19 | #endif// HARDWARE_DEVICES_NVIDIA_H 20 | -------------------------------------------------------------------------------- /src/02hardware/include/hardware/functions.h: -------------------------------------------------------------------------------- 1 | #ifndef MEM_MANAGER_FUNCTIONS_H 2 | #define MEM_MANAGER_FUNCTIONS_H 3 | 4 | #include 5 | 6 | namespace refactor::hardware { 7 | 8 | constexpr size_t align2power(size_t size, int bits) { 9 | auto mask = (1 << bits) - 1; 10 | return (size + mask) & ~mask; 11 | } 12 | 13 | constexpr size_t alignBytes(size_t size, int bytes) { 14 | return (size + bytes - 1) / bytes * bytes; 15 | } 16 | 17 | }// namespace refactor::hardware 18 | 19 | #endif// MEM_MANAGER_FUNCTIONS_H 20 | -------------------------------------------------------------------------------- /src/02hardware/include/hardware/memory.h: -------------------------------------------------------------------------------- 1 | #ifndef HARDWARE_MEMORY_H 2 | #define HARDWARE_MEMORY_H 3 | 4 | #include 5 | 6 | namespace refactor::hardware { 7 | 8 | class Memory { 9 | public: 10 | virtual ~Memory() = default; 11 | virtual void *malloc(size_t) = 0; 12 | virtual void free(void *) = 0; 13 | virtual void *copyHD(void *dst, void const *src, size_t bytes) const = 0; 14 | virtual void *copyDH(void *dst, void const *src, size_t bytes) const = 0; 15 | virtual void *copyDD(void *dst, void const *src, size_t bytes) const = 0; 16 | }; 17 | 18 | }// namespace refactor::hardware 19 | 20 | #endif// HARDWARE_MEMORY_H 21 | -------------------------------------------------------------------------------- /src/02hardware/src/devices/cpu/device.cc: -------------------------------------------------------------------------------- 1 | #include "hardware/devices/cpu.h" 2 | #include "hardware/mem_pool.h" 3 | #include "memory.hh" 4 | 5 | namespace refactor::hardware { 6 | 7 | static Arc cpuMemory() { 8 | static auto instance = 
std::make_shared(); 9 | return instance; 10 | } 11 | 12 | Cpu::Cpu() : Device(0, cpuMemory()) {} 13 | 14 | }// namespace refactor::hardware 15 | -------------------------------------------------------------------------------- /src/02hardware/src/devices/cpu/memory.cc: -------------------------------------------------------------------------------- 1 | #include "memory.hh" 2 | #include 3 | #include 4 | 5 | namespace refactor::hardware { 6 | using M = CpuMemory; 7 | 8 | void *M::malloc(size_t size) { 9 | return std::malloc(size); 10 | } 11 | void M::free(void *ptr) { 12 | std::free(ptr); 13 | } 14 | void *M::copyHD(void *dst, void const *src, size_t bytes) const { 15 | return std::memcpy(dst, src, bytes); 16 | } 17 | void *M::copyDH(void *dst, void const *src, size_t bytes) const { 18 | return std::memcpy(dst, src, bytes); 19 | } 20 | void *M::copyDD(void *dst, void const *src, size_t bytes) const { 21 | return std::memcpy(dst, src, bytes); 22 | } 23 | 24 | }// namespace refactor::hardware 25 | -------------------------------------------------------------------------------- /src/02hardware/src/devices/cpu/memory.hh: -------------------------------------------------------------------------------- 1 | #ifndef HARDWARE_DEVICES_CPU_MEMORY_HH 2 | #define HARDWARE_DEVICES_CPU_MEMORY_HH 3 | 4 | #include "hardware/memory.h" 5 | 6 | namespace refactor::hardware { 7 | 8 | class CpuMemory final : public Memory { 9 | void *malloc(size_t) final; 10 | void free(void *) final; 11 | void *copyHD(void *dst, void const *src, size_t bytes) const final; 12 | void *copyDH(void *dst, void const *src, size_t bytes) const final; 13 | void *copyDD(void *dst, void const *src, size_t bytes) const final; 14 | }; 15 | 16 | }// namespace refactor::hardware 17 | 18 | #endif// HARDWARE_DEVICES_CPU_MEMORY_HH 19 | -------------------------------------------------------------------------------- /src/02hardware/src/devices/mlu/functions.cc: 
-------------------------------------------------------------------------------- 1 | #include "functions.hh" 2 | 3 | namespace refactor::hardware { 4 | 5 | #ifdef USE_BANG 6 | int getDeviceCount() { 7 | unsigned deviceCount; 8 | BANG_ASSERT(cnrtGetDeviceCount(&deviceCount)); 9 | return static_cast(deviceCount); 10 | } 11 | void setDevice(int device) { 12 | BANG_ASSERT(cnrtSetDevice(device)); 13 | } 14 | MemInfo getMemInfo() { 15 | MemInfo memInfo; 16 | BANG_ASSERT(cnrtMemGetInfo(&memInfo.free, &memInfo.total)); 17 | return memInfo; 18 | } 19 | #endif 20 | 21 | }// namespace refactor::hardware 22 | -------------------------------------------------------------------------------- /src/02hardware/src/devices/mlu/functions.hh: -------------------------------------------------------------------------------- 1 | #ifndef HARDWARE_DEVICES_MLU_FUNCTIONS_CUH 2 | #define HARDWARE_DEVICES_MLU_FUNCTIONS_CUH 3 | 4 | #include "common.h" 5 | 6 | #ifdef USE_BANG 7 | #include "cnrt.h" 8 | 9 | #define BANG_ASSERT(STATUS) \ 10 | if (auto status = (STATUS); status != CNRT_RET_SUCCESS) { \ 11 | RUNTIME_ERROR(fmt::format("bang failed on \"" #STATUS "\" with \"{}\" ({})", \ 12 | cnrtGetErrorStr(status), (int) status)); \ 13 | } 14 | #endif 15 | 16 | namespace refactor::hardware { 17 | 18 | struct MemInfo { 19 | size_t free, total; 20 | }; 21 | 22 | int getDeviceCount(); 23 | void setDevice(int device); 24 | MemInfo getMemInfo(); 25 | 26 | }// namespace refactor::hardware 27 | 28 | #endif// HARDWARE_DEVICES_NVIDIA_FUNCTIONS_CUH 29 | -------------------------------------------------------------------------------- /src/02hardware/src/devices/mlu/memory.hh: -------------------------------------------------------------------------------- 1 | #ifndef HARDWARE_DEVICES_MLU_MEMORY_CUH 2 | #define HARDWARE_DEVICES_MLU_MEMORY_CUH 3 | 4 | #include "hardware/memory.h" 5 | 6 | namespace refactor::hardware { 7 | 8 | class MluMemory final : public Memory { 9 | void *malloc(size_t) final; 10 | void 
free(void *) final; 11 | void *copyHD(void *dst, void const *src, size_t bytes) const final; 12 | void *copyDH(void *dst, void const *src, size_t bytes) const final; 13 | void *copyDD(void *dst, void const *src, size_t bytes) const final; 14 | }; 15 | 16 | }// namespace refactor::hardware 17 | 18 | #endif// HARDWARE_DEVICES_MLU_MEMORY_HH 19 | -------------------------------------------------------------------------------- /src/02hardware/src/devices/nvidia/memory.hh: -------------------------------------------------------------------------------- 1 | #ifndef HARDWARE_DEVICES_NVIDIA_MEMORY_CUH 2 | #define HARDWARE_DEVICES_NVIDIA_MEMORY_CUH 3 | 4 | #include "hardware/memory.h" 5 | 6 | namespace refactor::hardware { 7 | 8 | class NvidiaMemory final : public Memory { 9 | void *malloc(size_t) final; 10 | void free(void *) final; 11 | void *copyHD(void *dst, void const *src, size_t bytes) const final; 12 | void *copyDH(void *dst, void const *src, size_t bytes) const final; 13 | void *copyDD(void *dst, void const *src, size_t bytes) const final; 14 | }; 15 | 16 | }// namespace refactor::hardware 17 | 18 | #endif// HARDWARE_DEVICES_NVIDIA_MEMORY_CUH 19 | -------------------------------------------------------------------------------- /src/03runtime/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12 FATAL_ERROR) 2 | project(runtime VERSION 0.0.0 LANGUAGES CXX) 3 | message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION}) 4 | 5 | file(GLOB_RECURSE RUNTIME_SRC src/*.cc src/*.cpp) 6 | add_library(runtime STATIC ${RUNTIME_SRC}) 7 | target_link_libraries(runtime PUBLIC graph_topo hardware) 8 | target_include_directories(runtime PUBLIC include) 9 | 10 | file(GLOB_RECURSE RUNTIME_TEST test/*.cpp) 11 | if(RUNTIME_TEST) 12 | add_executable(runtime_test ${RUNTIME_TEST}) 13 | add_test(runtime_test runtime_test) 14 | target_link_libraries(runtime_test runtime GTest::gtest_main 
Backward::Object) 15 | endif() 16 | -------------------------------------------------------------------------------- /src/03runtime/README.md: -------------------------------------------------------------------------------- 1 |  -------------------------------------------------------------------------------- /src/03runtime/src/resource.cpp: -------------------------------------------------------------------------------- 1 | #include "runtime/resource.h" 2 | 3 | namespace refactor::runtime { 4 | 5 | auto Resource::is(size_t id) const noexcept -> bool { 6 | return resourceTypeId() == id; 7 | } 8 | 9 | auto Resources::fetch(size_t id) noexcept -> Resource * { 10 | auto it = _internal.find(id); 11 | return it != _internal.end() ? it->second.get() : nullptr; 12 | } 13 | auto Resources::fetchOrStore(ResourceBox resource) noexcept -> Resource * { 14 | auto [it, ok] = _internal.try_emplace(resource->resourceTypeId(), std::move(resource)); 15 | return it->second.get(); 16 | } 17 | auto Resources::fetchOrStore(size_t id, std::function fn) -> Resource * { 18 | auto it = _internal.find(id); 19 | if (it == _internal.end()) { 20 | std::tie(it, std::ignore) = _internal.insert({id, fn()}); 21 | } 22 | return it->second.get(); 23 | } 24 | 25 | }// namespace refactor::runtime 26 | -------------------------------------------------------------------------------- /src/04kernel/README.md: -------------------------------------------------------------------------------- 1 |  -------------------------------------------------------------------------------- /src/04kernel/cuda/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12 FATAL_ERROR) 2 | project(kernel_cuda) 3 | 4 | file(GLOB_RECURSE KERNEL_CUDA_SUB_SRC src/*.cu) 5 | 6 | add_library(kernel_cuda STATIC ${KERNEL_CUDA_SUB_SRC}) 7 | target_link_libraries(kernel_cuda PUBLIC common) 8 | target_include_directories(kernel_cuda PUBLIC include) 9 | 10 | 
file(GLOB_RECURSE KERNEL_CUDA_TEST test/*.cu) 11 | if(KERNEL_CUDA_TEST) 12 | add_executable(kernel_cuda_test ${KERNEL_CUDA_TEST}) 13 | add_test(kernel_cuda_test kernel_cuda_test) 14 | target_link_libraries(kernel_cuda_test kernel_cuda GTest::gtest_main Backward::Object) 15 | endif() 16 | -------------------------------------------------------------------------------- /src/04kernel/cuda/include/kernel/cuda/concat.cuh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CUDA_CONCAT_CUH 2 | #define KERNEL_CUDA_CONCAT_CUH 3 | 4 | #include "threads_distributer.cuh" 5 | 6 | namespace refactor::kernel::cuda { 7 | 8 | void launchConcat( 9 | KernelLaunchParameters const &, 10 | void const **inputs, unsigned int const *segments, void *output, 11 | unsigned int inputCount, 12 | unsigned int sum, 13 | unsigned int sub); 14 | 15 | }// namespace refactor::kernel::cuda 16 | 17 | #endif// KERNEL_CUDA_CONCAT_CUH 18 | -------------------------------------------------------------------------------- /src/04kernel/cuda/include/kernel/cuda/expand.cuh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CUDA_EXPAND_CUH 2 | #define KERNEL_CUDA_EXPAND_CUH 3 | 4 | #include "threads_distributer.cuh" 5 | 6 | namespace refactor::kernel::cuda { 7 | 8 | namespace expand { 9 | struct DimStride { 10 | unsigned int i, o; 11 | }; 12 | }// namespace expand 13 | 14 | void launchExpand( 15 | KernelLaunchParameters const &, 16 | void const *data, expand::DimStride const *strides, void *output, 17 | unsigned int rank, 18 | unsigned int eleSize); 19 | 20 | }// namespace refactor::kernel::cuda 21 | 22 | #endif// KERNEL_CUDA_EXPAND_CUH 23 | -------------------------------------------------------------------------------- /src/04kernel/cuda/include/kernel/cuda/functions.cuh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CUDA_FUNCTIONS_CUH 2 | #define 
KERNEL_CUDA_FUNCTIONS_CUH 3 | 4 | namespace refactor::kernel::cuda { 5 | 6 | int currentDevice(); 7 | 8 | void sync(); 9 | 10 | void setCudaDevice(int); 11 | 12 | void copyOut(void *dst, const void *src, size_t size); 13 | 14 | }// namespace refactor::kernel::cuda 15 | 16 | #endif// KERNEL_CUDA_FUNCTIONS_CUH 17 | -------------------------------------------------------------------------------- /src/04kernel/cuda/include/kernel/cuda/gather.cuh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CUDA_GATHER_CUH 2 | #define KERNEL_CUDA_GATHER_CUH 3 | 4 | #include "threads_distributer.cuh" 5 | 6 | namespace refactor::kernel::cuda { 7 | 8 | void launchGather( 9 | KernelLaunchParameters const &, 10 | void const *data, void const *indices, void *output, 11 | bool i64, 12 | unsigned int batch, 13 | unsigned int unit, 14 | unsigned int midSizeI, 15 | unsigned int midSizeO); 16 | 17 | }// namespace refactor::kernel::cuda 18 | 19 | #endif// KERNEL_CUDA_GATHER_CUH 20 | -------------------------------------------------------------------------------- /src/04kernel/cuda/include/kernel/cuda/pad.cuh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CUDA_PAD_CUH 2 | #define KERNEL_CUDA_PAD_CUH 3 | 4 | #include "threads_distributer.cuh" 5 | #include 6 | 7 | namespace refactor::kernel::cuda { 8 | 9 | struct PadDimInfo { 10 | unsigned int strideI, strideO, padS, dimI; 11 | }; 12 | 13 | void launchPad( 14 | KernelLaunchParameters const &, 15 | uint8_t const *src, uint8_t const *src_const, 16 | PadDimInfo const *dims, void *output, 17 | unsigned int rank, 18 | unsigned int blockSize); 19 | 20 | }// namespace refactor::kernel::cuda 21 | 22 | #endif// KERNEL_CUDA_PAD_CUH 23 | -------------------------------------------------------------------------------- /src/04kernel/cuda/include/kernel/cuda/scatter_nd.cuh: -------------------------------------------------------------------------------- 1 | 
#ifndef KERNEL_CUDA_SCATTER_ND_CUH 2 | #define KERNEL_CUDA_SCATTER_ND_CUH 3 | 4 | #include "threads_distributer.cuh" 5 | 6 | namespace refactor::kernel::cuda { 7 | 8 | void launchScatterND( 9 | KernelLaunchParameters const &, 10 | void const *data, 11 | void const *indices, 12 | void const *updates, 13 | void *output, 14 | unsigned int const *strides, 15 | size_t rank, 16 | unsigned int blockCount, 17 | size_t blockSize); 18 | 19 | }// namespace refactor::kernel::cuda 20 | 21 | #endif// KERNEL_CUDA_SCATTER_ND_CUH 22 | -------------------------------------------------------------------------------- /src/04kernel/cuda/include/kernel/cuda/slice.cuh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CUDA_SLICE_CUH 2 | #define KERNEL_CUDA_SLICE_CUH 3 | 4 | #include "threads_distributer.cuh" 5 | 6 | namespace refactor::kernel::cuda { 7 | 8 | struct SliceDimInfo { 9 | unsigned int strideO, skip; 10 | int strideI; 11 | }; 12 | 13 | void launchSlice( 14 | KernelLaunchParameters const &, 15 | void const *src, SliceDimInfo const *dims, void *output, 16 | unsigned int rank, 17 | unsigned int blockSize); 18 | 19 | }// namespace refactor::kernel::cuda 20 | 21 | #endif// KERNEL_CUDA_SLICE_CUH 22 | -------------------------------------------------------------------------------- /src/04kernel/cuda/include/kernel/cuda/split.cuh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CUDA_SPLIT_CUH 2 | #define KERNEL_CUDA_SPLIT_CUH 3 | 4 | #include "threads_distributer.cuh" 5 | 6 | namespace refactor::kernel::cuda { 7 | 8 | void launchSplit( 9 | KernelLaunchParameters const &, 10 | void const *data, unsigned int const *segments, void **outputs, 11 | unsigned int outputCount, 12 | unsigned int sum, 13 | unsigned int sub); 14 | 15 | }// namespace refactor::kernel::cuda 16 | 17 | #endif// KERNEL_CUDA_SPLIT_CUH 18 | 
-------------------------------------------------------------------------------- /src/04kernel/cuda/include/kernel/cuda/threads_distributer.cuh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CUDA_THREADS_DISTRIBUTER_CUH 2 | #define KERNEL_CUDA_THREADS_DISTRIBUTER_CUH 3 | 4 | namespace refactor::kernel::cuda { 5 | 6 | /// @brief 内核的启动参数。 7 | struct KernelLaunchParameters { 8 | /// @brief 网格中块的数量和块中线程的数量。 9 | int gridSize, blockSize; 10 | /// @brief 要处理任务总量。 11 | size_t n; 12 | /// @brief 用于执行内核的流。 13 | void *stream; 14 | }; 15 | 16 | class ThreadsDistributer { 17 | int _maxGridSize; 18 | 19 | public: 20 | ThreadsDistributer(); 21 | 22 | KernelLaunchParameters operator()(size_t n) const; 23 | }; 24 | 25 | }// namespace refactor::kernel::cuda 26 | 27 | #endif// KERNEL_CUDA_THREADS_DISTRIBUTER_CUH 28 | -------------------------------------------------------------------------------- /src/04kernel/cuda/include/kernel/cuda/transpose.cuh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CUDA_TRANSPOSE_CUH 2 | #define KERNEL_CUDA_TRANSPOSE_CUH 3 | 4 | #include "threads_distributer.cuh" 5 | 6 | namespace refactor::kernel::cuda { 7 | 8 | namespace transpose { 9 | struct DimStride { 10 | unsigned int i, o; 11 | }; 12 | }// namespace transpose 13 | 14 | void launchTranspose( 15 | KernelLaunchParameters const &, 16 | void const *data, transpose::DimStride const *strides, void *output, 17 | unsigned int rank, 18 | unsigned int eleSize); 19 | 20 | }// namespace refactor::kernel::cuda 21 | 22 | #endif// KERNEL_CUDA_TRANSPOSE_CUH 23 | -------------------------------------------------------------------------------- /src/04kernel/cuda/include/kernel/cuda/where.cuh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CUDA_WHERE_CUH 2 | #define KERNEL_CUDA_WHERE_CUH 3 | 4 | #include "threads_distributer.cuh" 5 | 6 | namespace 
refactor::kernel::cuda { 7 | 8 | void launchWhere( 9 | KernelLaunchParameters const &, 10 | unsigned int const *strides, 11 | void const *c, 12 | void const *x, 13 | void const *y, 14 | void *output, 15 | unsigned int rank, 16 | unsigned int eleSize); 17 | 18 | }// namespace refactor::kernel::cuda 19 | 20 | #endif// KERNEL_CUDA_WHERE_CUH 21 | -------------------------------------------------------------------------------- /src/04kernel/cuda/src/functions.cu: -------------------------------------------------------------------------------- 1 | #include "kernel/cuda/functions.cuh" 2 | #include "macro.cuh" 3 | #include 4 | 5 | namespace refactor::kernel::cuda { 6 | 7 | int currentDevice() { 8 | int device; 9 | CUDA_ASSERT(cudaGetDevice(&device)); 10 | return device; 11 | } 12 | 13 | void sync() { 14 | CUDA_ASSERT(cudaDeviceSynchronize()); 15 | } 16 | 17 | void copyOut(void *dst, const void *src, size_t size) { 18 | sync(); 19 | CUDA_ASSERT(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost)); 20 | } 21 | 22 | void setCudaDevice(int id) { 23 | cudaSetDevice(id); 24 | } 25 | 26 | }// namespace refactor::kernel::cuda 27 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/allocators.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_ALLOCATOR_H 2 | #define KERNEL_ALLOCATOR_H 3 | 4 | #include "graph.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | AllocScheme flatAllocate( 9 | graph_topo::GraphTopo const &, 10 | std::vector, 11 | std::vector const &, 12 | size_t); 13 | 14 | AllocScheme reusableAllocate( 15 | graph_topo::GraphTopo const &, 16 | std::vector, 17 | std::vector const &, 18 | size_t); 19 | 20 | }// namespace refactor::kernel 21 | 22 | #endif// KERNEL_ALLOCATOR_H 23 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/attributes/broadcaster.h: 
-------------------------------------------------------------------------------- 1 | #ifndef KERNEL_BROADCASTER_H 2 | #define KERNEL_BROADCASTER_H 3 | 4 | #include "../tensor.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | /// @brief 优化用于计算的通用广播描述。 9 | struct Broadcaster { 10 | /// @brief 所有输入输出的各维度步长。 11 | std::vector strides; 12 | /// @brief 输出的总大小和输入的数量。 13 | dim_t outputsCount, inputsCount; 14 | 15 | explicit Broadcaster(std::vector>); 16 | explicit Broadcaster(TensorRefs const &inputs); 17 | void locate(dim_t k, dim_t ans[]) const noexcept; 18 | bool needBroadcast() const noexcept; 19 | }; 20 | 21 | }// namespace refactor::kernel 22 | 23 | #endif// KERNEL_BROADCASTER_H 24 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/attributes/communication.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_COMMUNICATION_ATTRIBUTES_H 2 | #define KERNEL_COMMUNICATION_ATTRIBUTES_H 3 | 4 | namespace refactor::kernel { 5 | enum class AllReduceType { 6 | Sum, 7 | Avg, 8 | Min, 9 | Max, 10 | Prod 11 | }; 12 | } 13 | 14 | #endif 15 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/attributes/expand_info.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_EXPAND_INFO_H 2 | #define KERNEL_EXPAND_INFO_H 3 | 4 | #include "../tensor.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | /// @brief 优化用于计算的单向广播描述。 9 | struct ExpandInfo { 10 | struct Dim { 11 | dim_t i, o; 12 | 13 | bool operator==(Dim const &) const noexcept; 14 | bool operator!=(Dim const &) const noexcept; 15 | }; 16 | 17 | /// @brief 所有输入输出的各维度步长。 18 | std::vector strides; 19 | dim_t blockCount, blockSize; 20 | 21 | ExpandInfo(DataType, slice_t input, slice_t output); 22 | ExpandInfo(Tensor const &input, Tensor const &output); 23 | ExpandInfo reform(dim_t maxblockSize) const noexcept; 24 | void 
reformAssign(dim_t maxblockSize) noexcept; 25 | }; 26 | 27 | }// namespace refactor::kernel 28 | 29 | #endif// KERNEL_EXPAND_INFO_H 30 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/attributes/gather_info.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_GATHER_INFO_H 2 | #define KERNEL_GATHER_INFO_H 3 | 4 | #include "../tensor.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | /// @brief 优化用于计算的 Gather 描述。 9 | struct GatherInfo { 10 | /// @brief Gather 的计算是 `prefix` 次将 `postfix` 大小的数据从输入拷贝到输出。 11 | /// 其中输入的总大小是 `prefix * midSizeI * postfix`, 12 | /// 输出的总大小是 `prefix * midSizeO * postfix`, 13 | /// 通过保存这两个值可以计算输入输出位置的偏移。 14 | dim_t prefix, postfix, midSizeI, midSizeO; 15 | /// @brief `indices` 的数据类型,可以是 `I32` 或 `I64`。 16 | DataType idxType; 17 | 18 | GatherInfo(dim_t axis, Tensor const &data, Tensor const &indices) noexcept; 19 | }; 20 | 21 | }// namespace refactor::kernel 22 | 23 | #endif// KERNEL_GATHER_INFO_H 24 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/attributes/mat_mul_info.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_MAT_MUL_INFO_H 2 | #define KERNEL_MAT_MUL_INFO_H 3 | 4 | #include "kernel/attributes/broadcaster.h" 5 | #include "kernel/attributes/expand_info.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct MatMulInfo { 10 | DataType dataType; 11 | float alpha, beta; 12 | bool transA, transB; 13 | dim_t m, k, n; 14 | // Expand operation info for biasd 15 | std::optional biasExpand; 16 | // A 2-directional broadcaster that deals with dimensions before the last 2 dimensions 17 | Broadcaster broadcaster; 18 | 19 | MatMulInfo(Tensor const &, Tensor const &, 20 | std::optional>, 21 | bool, bool, float, float); 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_MAT_MUL_INFO_H 27 | 
-------------------------------------------------------------------------------- /src/04kernel/include/kernel/attributes/mat_mul_integer_info.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_MAT_MUL_INTEGER_INFO_H 2 | #define KERNEL_MAT_MUL_INTEGER_INFO_H 3 | 4 | #include "kernel/attributes/broadcaster.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct MatMulIntegerInfo { 9 | struct Input { 10 | bool 11 | withZeroPoint, 12 | signed_, 13 | scalar; 14 | 15 | Input(TensorRefs const &, size_t i) noexcept; 16 | }; 17 | 18 | Input a, b; 19 | dim_t m, k, n; 20 | Broadcaster broadcaster; 21 | 22 | explicit MatMulIntegerInfo(TensorRefs const &inputs) noexcept; 23 | dim_t batch() const noexcept; 24 | }; 25 | 26 | }// namespace refactor::kernel 27 | 28 | #endif// KERNEL_MAT_MUL_INTEGER_INFO_H 29 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/attributes/scatter_nd_info.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SCATTER_ND_INFO_H 2 | #define KERNEL_SCATTER_ND_INFO_H 3 | 4 | #include "../tensor.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | /// @brief 优化用于计算的 ScatterND 描述。 9 | struct ScatterNDInfo { 10 | dim_t prefix, blockCount; 11 | std::vector strides; 12 | size_t blockSize; 13 | 14 | ScatterNDInfo(Tensor const &data, 15 | Tensor const &indices); 16 | }; 17 | 18 | }// namespace refactor::kernel 19 | 20 | #endif// KERNEL_SCATTER_ND_INFO_H 21 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/attributes/softmax_info.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SOFTMAX_INFO_H 2 | #define KERNEL_SOFTMAX_INFO_H 3 | 4 | #include "../tensor.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct SoftmaxInfo { 9 | dim_t pre, mid, post; 10 | DataType type; 11 | 12 | SoftmaxInfo(Tensor 
const &data, dim_t axis) noexcept; 13 | }; 14 | 15 | }// namespace refactor::kernel 16 | 17 | #endif// KERNEL_SOFTMAX_INFO_H 18 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/attributes/split_info.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SPLIT_INFO_H 2 | #define KERNEL_SPLIT_INFO_H 3 | 4 | #include "../tensor.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | /// @brief 优化用于计算的 Split 描述。 9 | struct SplitInfo { 10 | /// @brief 要拷贝的次数和每次拷贝的大小,即所有片段的总大小。 11 | /// NOTICE 要拷贝的次数最小值可以取到 1,表示把数据分成几个连续的块。 12 | dim_t blockCount, sum; 13 | /// @brief 要拷贝的每个片段的大小,已经考虑了每个数据的大小。 14 | absl::InlinedVector segments; 15 | 16 | SplitInfo(dim_t axis, TensorRefs const &outputs); 17 | dim_t unit(dim_t maxBlockSize) const noexcept; 18 | }; 19 | 20 | }// namespace refactor::kernel 21 | 22 | #endif// KERNEL_SPLIT_INFO_H 23 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/attributes/transpose_info.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_TRANPOSE_INFO_H 2 | #define KERNEL_TRANPOSE_INFO_H 3 | 4 | #include "common.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | using Shape = absl::InlinedVector; 9 | using Permutation = Shape; 10 | 11 | /// @brief 优化用于计算的转置描述。 12 | struct TransposeInfo { 13 | struct Dimension { 14 | dim_t strideI, strideO; 15 | }; 16 | 17 | /// @brief 转置信息包含 `(1+1)rank` 个元素。 18 | /// 由于 rank 常常取 4,参数总数也往往至少有 8 个。 19 | /// 如果使用 uint32_t 并 inline,则共 8x4+8 = 40 字节, 20 | /// 这样拷贝开销还是可以接受的。 21 | absl::InlinedVector dims; 22 | dim_t blockSize, blockCount; 23 | 24 | TransposeInfo(DataType, Shape const &, Permutation const &); 25 | dim_t locate(dim_t) const noexcept; 26 | }; 27 | 28 | }// namespace refactor::kernel 29 | 30 | #endif// KERNEL_TRANPOSE_INFO_H 31 | 
-------------------------------------------------------------------------------- /src/04kernel/include/kernel/blob.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_BLOB_H 2 | #define KERNEL_BLOB_H 3 | 4 | #include "common.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | /// @brief 一次初始化的内存块。 9 | class Blob { 10 | /// @brief ! NOTICE 指针必须非空。 11 | void *_ptr; 12 | 13 | explicit Blob(size_t); 14 | 15 | public: 16 | Blob(Blob const &) = delete; 17 | Blob(Blob &&) = delete; 18 | ~Blob(); 19 | 20 | static std::pair, void *> share(size_t); 21 | operator void const *() const noexcept; 22 | template T const *get() const noexcept { 23 | return reinterpret_cast(_ptr); 24 | } 25 | }; 26 | 27 | }// namespace refactor::kernel 28 | 29 | #endif// KERNEL_BLOB_H 30 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collector.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CANDIDATE_H 2 | #define KERNEL_CANDIDATE_H 3 | 4 | #include "hardware/device.h" 5 | #include "kernel.h" 6 | #include "tensor.h" 7 | 8 | namespace refactor::kernel { 9 | 10 | class InfoCollector { 11 | protected: 12 | hardware::Device::Type _target; 13 | constexpr explicit InfoCollector(decltype(_target) target) 14 | : _target(target) {} 15 | 16 | public: 17 | virtual ~InfoCollector() = default; 18 | virtual std::vector 19 | filter(TensorRefs inputs, TensorRefs outputs) const = 0; 20 | }; 21 | 22 | using CollectorBox = std::unique_ptr; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_CANDIDATE_H 27 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/all_reduce.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_COLLECTOR_ALL_REDUCE_H 2 | #define KERNEL_COLLECTOR_ALL_REDUCE_H 3 | 4 | #include "../collector.h" 5 | 
#include "kernel/attributes/communication.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct AllReduceCollector final : public InfoCollector { 10 | 11 | AllReduceType type; 12 | 13 | constexpr AllReduceCollector(decltype(_target) target, AllReduceType type_) noexcept 14 | : InfoCollector(target), type(type_) {} 15 | 16 | std::vector 17 | filter(TensorRefs inputs, TensorRefs outputs) const final; 18 | }; 19 | }// namespace refactor::kernel 20 | 21 | #endif 22 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/attention.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_ATTENTION_H 2 | #define KERNEL_ATTENTION_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct AttentionCollector final : public InfoCollector { 9 | dim_t maxSeqLen; 10 | 11 | AttentionCollector(decltype(_target), decltype(maxSeqLen)) noexcept; 12 | 13 | std::vector 14 | filter(TensorRefs inputs, TensorRefs outputs) const final; 15 | }; 16 | 17 | }// namespace refactor::kernel 18 | 19 | #endif// KERNEL_ATTENTION_H 20 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/batch_normalization.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_BATCH_NORMALIZATION_H 2 | #define KERNEL_BATCH_NORMALIZATION_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct BatchNormalizationCollector final : public InfoCollector { 9 | float epsilon; 10 | 11 | constexpr BatchNormalizationCollector(decltype(_target) target, float epsilon_) noexcept 12 | : InfoCollector(target), epsilon(epsilon_) {} 13 | 14 | std::vector 15 | filter(TensorRefs inputs, TensorRefs outputs) const final; 16 | }; 17 | 18 | }// namespace refactor::kernel 19 | 20 | #endif// KERNEL_BATCH_NORMALIZATION_H 21 | 
-------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/cast.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CAST_H 2 | #define KERNEL_CAST_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct CastCollector final : public InfoCollector { 9 | 10 | explicit CastCollector(decltype(_target)) noexcept; 11 | 12 | std::vector 13 | filter(TensorRefs inputs, TensorRefs outputs) const final; 14 | }; 15 | 16 | }// namespace refactor::kernel 17 | 18 | #endif// KERNEL_CAST_H 19 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/clip.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CLIP_H 2 | #define KERNEL_CLIP_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct ClipCollector final : public InfoCollector { 9 | 10 | explicit ClipCollector(decltype(_target)) noexcept; 11 | 12 | std::vector 13 | filter(TensorRefs inputs, TensorRefs outputs) const final; 14 | }; 15 | 16 | }// namespace refactor::kernel 17 | 18 | #endif// KERNEL_CLIP_H 19 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/concat.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CONCAT_H 2 | #define KERNEL_CONCAT_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct ConcatCollector final : public InfoCollector { 9 | uint32_t axis; 10 | 11 | constexpr ConcatCollector(decltype(_target) target, uint32_t axis_) noexcept 12 | : InfoCollector(target), axis(axis_) {} 13 | 14 | std::vector 15 | filter(TensorRefs inputs, TensorRefs outputs) const final; 16 | }; 17 | 18 | }// namespace refactor::kernel 19 | 20 | #endif// KERNEL_CONCAT_H 21 | 
-------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/conv.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CONV_H 2 | #define KERNEL_CONV_H 3 | 4 | #include "../attributes/pool_attributes.h" 5 | #include "../collector.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct ConvCollector final : public InfoCollector { 10 | PoolAttributes poolAttrs; 11 | 12 | ConvCollector(decltype(_target), PoolAttributes) noexcept; 13 | 14 | std::vector 15 | filter(TensorRefs inputs, TensorRefs outputs) const final; 16 | }; 17 | 18 | }// namespace refactor::kernel 19 | 20 | #endif// KERNEL_CONV_H 21 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/dequantize_linear.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_DEQUANTIZE_LINEAR_H 2 | #define KERNEL_DEQUANTIZE_LINEAR_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct DequantizeLinearCollector final : public InfoCollector { 9 | 10 | explicit DequantizeLinearCollector(decltype(_target)) noexcept; 11 | 12 | std::vector 13 | filter(TensorRefs inputs, TensorRefs outputs) const final; 14 | }; 15 | 16 | }// namespace refactor::kernel 17 | 18 | #endif// KERNEL_DEQUANTIZE_LINEAR_H 19 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/dynamic_quantize_linear.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_DYNAMIC_QUANTIZE_LINEAR_H 2 | #define KERNEL_DYNAMIC_QUANTIZE_LINEAR_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct DynamicQuantizeLinearCollector final : public InfoCollector { 9 | 10 | explicit DynamicQuantizeLinearCollector(decltype(_target)) noexcept; 11 | 12 | std::vector 13 | 
filter(TensorRefs inputs, TensorRefs outputs) const final; 14 | }; 15 | 16 | }// namespace refactor::kernel 17 | 18 | #endif// KERNEL_DYNAMIC_QUANTIZE_LINEAR_H 19 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/gather.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_GATHER_H 2 | #define KERNEL_GATHER_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct GatherCollector final : public InfoCollector { 9 | uint32_t axis; 10 | 11 | constexpr GatherCollector(decltype(_target) target, uint32_t axis_) noexcept 12 | : InfoCollector(target), axis(axis_) {} 13 | 14 | std::vector 15 | filter(TensorRefs inputs, TensorRefs outputs) const final; 16 | }; 17 | 18 | }// namespace refactor::kernel 19 | 20 | #endif// KERNEL_GATHER_H 21 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/global_pool.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_GLOBAL_POOL_H 2 | #define KERNEL_GLOBAL_POOL_H 3 | 4 | #include "../attributes/pool_attributes.h" 5 | #include "../collector.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct GlobalPoolCollector final : public InfoCollector { 10 | PoolType type; 11 | 12 | GlobalPoolCollector(decltype(_target), PoolType) noexcept; 13 | 14 | std::vector 15 | filter(TensorRefs inputs, TensorRefs outputs) const final; 16 | }; 17 | 18 | }// namespace refactor::kernel 19 | 20 | #endif// KERNEL_GLOBAL_POOL_H 21 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/hard_sigmoid.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_HARD_SIGMOIG_H 2 | #define KERNEL_HARD_SIGMOIG_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct 
HardSigmoidCollector final : public InfoCollector { 9 | float alpha, beta; 10 | 11 | constexpr HardSigmoidCollector(decltype(_target) target, float alpha_, float beta_) noexcept 12 | : InfoCollector(target), alpha(alpha_), beta(beta_) {} 13 | 14 | std::vector 15 | filter(TensorRefs inputs, TensorRefs outputs) const final; 16 | }; 17 | }// namespace refactor::kernel 18 | 19 | #endif// KERNEL_HARD_SIGMOIG_H 20 | 21 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/mat_mul.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_MAT_MUL_H 2 | #define KERNEL_MAT_MUL_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct MatMulCollector final : public InfoCollector { 9 | float alpha, beta; 10 | bool transA, transB; 11 | 12 | constexpr MatMulCollector(decltype(_target) target, float alpha_, float beta_, bool transA_, bool transB_) noexcept 13 | : InfoCollector(target), alpha(alpha_), beta(beta_), transA(transA_), transB(transB_) {} 14 | 15 | std::vector 16 | filter(TensorRefs inputs, TensorRefs outputs) const final; 17 | }; 18 | 19 | }// namespace refactor::kernel 20 | 21 | #endif// KERNEL_MAT_MUL_H 22 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/mat_mul_integer.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_MAT_MUL_INTEGER_H 2 | #define KERNEL_MAT_MUL_INTEGER_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct MatMulIntegerCollector final : public InfoCollector { 9 | 10 | constexpr MatMulIntegerCollector(decltype(_target) target) noexcept 11 | : InfoCollector(target) {} 12 | 13 | std::vector 14 | filter(TensorRefs inputs, TensorRefs outputs) const final; 15 | }; 16 | 17 | }// namespace refactor::kernel 18 | 19 | #endif// KERNEL_MAT_MUL_INTEGER_H 20 | 
-------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/pad.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_PAD_H 2 | #define KERNEL_PAD_H 3 | 4 | #include "../attributes/pad_info.h" 5 | #include "../collector.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct PadCollector final : public InfoCollector { 10 | PadDimension dims; 11 | PadType mode; 12 | 13 | explicit PadCollector(decltype(_target) target, PadDimension const &dims_, PadType mode_) noexcept 14 | : InfoCollector(target), dims(std::move(dims_)), mode(mode_) {} 15 | 16 | std::vector 17 | filter(TensorRefs inputs, TensorRefs outputs) const final; 18 | }; 19 | }// namespace refactor::kernel 20 | 21 | #endif// KERNEL_PAD_H 22 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/pool.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_POOL_H 2 | #define KERNEL_POOL_H 3 | 4 | #include "../attributes/pool_attributes.h" 5 | #include "../collector.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct PoolCollector final : public InfoCollector { 10 | PoolType type; 11 | bool ceil; 12 | KernelShape kernelShape; 13 | PoolAttributes attributes; 14 | 15 | PoolCollector(decltype(_target), PoolType, bool, KernelShape, PoolAttributes) noexcept; 16 | 17 | std::vector 18 | filter(TensorRefs inputs, TensorRefs outputs) const final; 19 | }; 20 | 21 | }// namespace refactor::kernel 22 | 23 | #endif// KERNEL_POOL_H 24 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/reduce.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_REDUCE_H 2 | #define KERNEL_REDUCE_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | using Axes = 
absl::InlinedVector; 9 | 10 | enum class ReduceType { 11 | Mean, 12 | L1, 13 | L2, 14 | LogSum, 15 | LogSumExp, 16 | Max, 17 | Min, 18 | Prod, 19 | Sum, 20 | SumSquare, 21 | }; 22 | 23 | struct ReduceCollector final : public InfoCollector { 24 | ReduceType reduceType; 25 | Axes axes; 26 | 27 | ReduceCollector(decltype(_target), ReduceType, Axes) noexcept; 28 | 29 | std::vector filter(TensorRefs inputs, TensorRefs outputs) const final; 30 | }; 31 | }// namespace refactor::kernel 32 | 33 | #endif// KERNEL_REDUCE_H 34 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/rms_normalization.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_RMS_NORMALIZATION_H 2 | #define KERNEL_RMS_NORMALIZATION_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct RmsNormalizationCollector final : public InfoCollector { 9 | float epsilon; 10 | 11 | constexpr RmsNormalizationCollector(decltype(_target) target, float epsilon_) noexcept 12 | : InfoCollector(target), epsilon(epsilon_) {} 13 | 14 | std::vector 15 | filter(TensorRefs inputs, TensorRefs outputs) const final; 16 | }; 17 | 18 | }// namespace refactor::kernel 19 | 20 | #endif// KERNEL_RMS_NORMALIZATION_H 21 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/scatter_nd.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SCATTER_ND_H 2 | #define KERNEL_SCATTER_ND_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct ScatterNDCollector final : public InfoCollector { 9 | 10 | explicit ScatterNDCollector(decltype(_target)) noexcept; 11 | 12 | std::vector 13 | filter(TensorRefs inputs, TensorRefs outputs) const final; 14 | }; 15 | 16 | }// namespace refactor::kernel 17 | 18 | #endif// KERNEL_SCATTER_ND_H 19 | 
-------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/select.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SELECT_H 2 | #define KERNEL_SELECT_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | enum class SelectType { 9 | Max, 10 | Min, 11 | }; 12 | 13 | std::string_view opName(SelectType type); 14 | 15 | struct SelectCollector final : public InfoCollector { 16 | SelectType selectType; 17 | 18 | SelectCollector(decltype(_target), SelectType) noexcept; 19 | 20 | std::vector 21 | filter(TensorRefs inputs, TensorRefs outputs) const final; 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_SELECT_H 27 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/simple_binary.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SIMPLE_BINARY_H 2 | #define KERNEL_SIMPLE_BINARY_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | enum class SimpleBinaryType : uint8_t { 9 | Add, 10 | Sub, 11 | Mul, 12 | Div, 13 | Pow, 14 | And, 15 | Or, 16 | Xor, 17 | Mod, 18 | Fmod, 19 | }; 20 | 21 | std::string_view opName(SimpleBinaryType type); 22 | 23 | struct SimpleBinaryCollector final : public InfoCollector { 24 | SimpleBinaryType type; 25 | 26 | constexpr SimpleBinaryCollector(decltype(_target) target, SimpleBinaryType type_) noexcept 27 | : InfoCollector(target), type(type_) {} 28 | 29 | std::vector 30 | filter(TensorRefs inputs, TensorRefs outputs) const final; 31 | }; 32 | 33 | }// namespace refactor::kernel 34 | 35 | #endif// KERNEL_SIMPLE_BINARY_H 36 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/slice.h: -------------------------------------------------------------------------------- 1 | 
#ifndef KERNEL_SLICE_H 2 | #define KERNEL_SLICE_H 3 | 4 | #include "../attributes/slice_info.h" 5 | #include "../collector.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct SliceCollector final : public InfoCollector { 10 | Dimensions dimentions; 11 | 12 | SliceCollector(decltype(_target), Dimensions) noexcept; 13 | 14 | std::vector 15 | filter(TensorRefs inputs, TensorRefs outputs) const final; 16 | }; 17 | 18 | }// namespace refactor::kernel 19 | 20 | #endif// KERNEL_SLICE_H 21 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/softmax.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SOFTMAX_H 2 | #define KERNEL_SOFTMAX_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct SoftmaxCollector final : public InfoCollector { 9 | dim_t axis; 10 | 11 | constexpr SoftmaxCollector(decltype(_target) target, dim_t axis_) noexcept 12 | : InfoCollector(target), axis(axis_) {} 13 | 14 | std::vector 15 | filter(TensorRefs inputs, TensorRefs outputs) const final; 16 | }; 17 | 18 | }// namespace refactor::kernel 19 | 20 | #endif// KERNEL_SOFTMAX_H 21 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/split.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SPLIT_H 2 | #define KERNEL_SPLIT_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct SplitCollector final : public InfoCollector { 9 | uint32_t axis; 10 | 11 | constexpr SplitCollector(decltype(_target) target, uint32_t axis_) noexcept 12 | : InfoCollector(target), axis(axis_) {} 13 | 14 | std::vector 15 | filter(TensorRefs inputs, TensorRefs outputs) const final; 16 | }; 17 | 18 | }// namespace refactor::kernel 19 | 20 | #endif// KERNEL_SPLIT_H 21 | 
-------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/transpose.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_TRANSPOSE_H 2 | #define KERNEL_TRANSPOSE_H 3 | 4 | #include "../attributes/transpose_info.h" 5 | #include "../collector.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct TransposeCollector final : public InfoCollector { 10 | Permutation perm; 11 | 12 | TransposeCollector(decltype(_target), decltype(perm)) noexcept; 13 | 14 | std::vector 15 | filter(TensorRefs inputs, TensorRefs outputs) const final; 16 | }; 17 | 18 | }// namespace refactor::kernel 19 | 20 | #endif// KERNEL_TRANSPOSE_H 21 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/collectors/where.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_WHERE_H 2 | #define KERNEL_WHERE_H 3 | 4 | #include "../collector.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct WhereCollector final : public InfoCollector { 9 | constexpr WhereCollector(decltype(_target) target) noexcept 10 | : InfoCollector(target) {} 11 | 12 | std::vector 13 | filter(TensorRefs inputs, TensorRefs outputs) const final; 14 | }; 15 | 16 | }// namespace refactor::kernel 17 | 18 | #endif// KERNEL_WHERE_H 19 | -------------------------------------------------------------------------------- /src/04kernel/include/kernel/kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_KERNEL_H 2 | #define KERNEL_KERNEL_H 3 | 4 | #include "runtime/stream.h" 5 | #include 6 | 7 | namespace refactor::kernel { 8 | using runtime::Resources; 9 | using runtime::Routine; 10 | using RoutineWorkspace = runtime::Node; 11 | 12 | class Kernel { 13 | public: 14 | virtual ~Kernel() = default; 15 | virtual size_t kernelTypeId() const = 0; 16 | virtual std::string_view 
description() const = 0; 17 | virtual RoutineWorkspace lower(Resources &) const; 18 | 19 | template 20 | bool is(Args &&...args) const noexcept { 21 | return this->kernelTypeId() == T::typeId(std::forward(args)...); 22 | } 23 | }; 24 | 25 | using KernelBox = std::unique_ptr; 26 | 27 | }// namespace refactor::kernel 28 | 29 | #endif// KERNEL_KERNEL_H 30 | -------------------------------------------------------------------------------- /src/04kernel/src/attributes/pad_2d_info.cc: -------------------------------------------------------------------------------- 1 | #include "pad_2d_info.h" 2 | #include 3 | 4 | namespace refactor::kernel { 5 | 6 | Pad2DInfo::Pad2DInfo(DataType dt, slice_t input, ddim_t const *pads) 7 | : blockCount(std::accumulate(input.begin(), input.end() - 2, 1, std::multiplies<>())), 8 | blockSize(dt.size()), 9 | hw(input.end()[-1] * input.end()[-2]), 10 | w(input.end()[-1]), 11 | padHW(pads[0] - pads[2]), 12 | padW(pads[1] - pads[3]) {} 13 | 14 | }// namespace refactor::kernel 15 | -------------------------------------------------------------------------------- /src/04kernel/src/attributes/pad_2d_info.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_PAD_2D_INFO_H 2 | #define KERNEL_PAD_2D_INFO_H 3 | 4 | #include "kernel/tensor.h" 5 | 6 | namespace refactor::kernel { 7 | /// @brief 优化用于计算的 Slice 描述。 8 | struct Pad2DInfo { 9 | dim_t blockCount, blockSize, hw, w; 10 | ddim_t padHW, padW; 11 | 12 | Pad2DInfo(DataType, slice_t input, ddim_t const *pads); 13 | }; 14 | 15 | }// namespace refactor::kernel 16 | 17 | #endif// KERNEL_PAD_2D_INFO_H 18 | -------------------------------------------------------------------------------- /src/04kernel/src/attributes/scatter_nd_info.cc: -------------------------------------------------------------------------------- 1 | #include "kernel/attributes/scatter_nd_info.h" 2 | #include 3 | 4 | namespace refactor::kernel { 5 | 6 | #define K (indices.shape.back()) 7 | 8 | 
ScatterNDInfo::ScatterNDInfo( 9 | Tensor const &data, 10 | Tensor const &indices) 11 | : prefix(std::accumulate( 12 | indices.shape.begin(), 13 | indices.shape.begin() + indices.shape.size() - 1, 14 | 1, 15 | std::multiplies())), 16 | blockCount(data.shape[0]), 17 | strides(K, 1), 18 | blockSize(std::accumulate( 19 | data.shape.begin() + K, 20 | data.shape.end(), 21 | data.dataType.size(), 22 | std::multiplies())) { 23 | for (auto i : range0_(K - 1).rev()) { 24 | strides[i] = strides[i + 1] * data.shape[i + 1]; 25 | } 26 | blockCount *= strides[0]; 27 | } 28 | 29 | }// namespace refactor::kernel 30 | -------------------------------------------------------------------------------- /src/04kernel/src/attributes/softmax_info.cc: -------------------------------------------------------------------------------- 1 | #include "kernel/attributes/softmax_info.h" 2 | #include 3 | 4 | namespace refactor::kernel { 5 | 6 | SoftmaxInfo::SoftmaxInfo(Tensor const &data, dim_t axis) noexcept 7 | : pre(0), mid(0), post(0), type(data.dataType) { 8 | 9 | auto axisIt = data.shape.begin() + axis; 10 | pre = std::accumulate(data.shape.begin(), axisIt, 1, std::multiplies()); 11 | mid = *axisIt++; 12 | post = std::accumulate(axisIt, data.shape.end(), 1, std::multiplies()); 13 | }; 14 | 15 | }// namespace refactor::kernel 16 | -------------------------------------------------------------------------------- /src/04kernel/src/blob.cc: -------------------------------------------------------------------------------- 1 | #include "kernel/blob.hh" 2 | #include 3 | 4 | namespace refactor::kernel { 5 | 6 | Blob::Blob(size_t bytes) : _ptr(std::malloc(bytes)) {} 7 | Blob::~Blob() { std::free(std::exchange(_ptr, nullptr)); } 8 | 9 | std::pair, void *> 10 | Blob::share(size_t bytes) { 11 | auto blob = Arc(new Blob(bytes)); 12 | auto ptr = blob->_ptr; 13 | return {std::move(blob), ptr}; 14 | } 15 | Blob::operator void const *() const noexcept { return _ptr; } 16 | 17 | }// namespace refactor::kernel 18 | 
-------------------------------------------------------------------------------- /src/04kernel/src/collectors/all_reduce.cc: -------------------------------------------------------------------------------- 1 | #include "kernel/collectors/all_reduce.h" 2 | #include "../kernels/all_reduce/nccl_kernel.hh" 3 | namespace refactor::kernel { 4 | std::vector 5 | AllReduceCollector::filter(TensorRefs inputs, TensorRefs outputs) const { 6 | std::vector ans; 7 | switch (_target) { 8 | case decltype(_target)::Cpu: 9 | break; 10 | case decltype(_target)::Nvidia: 11 | if (auto ptr = AllReduceNccl::build(type, inputs[0], outputs[0]); ptr) { 12 | ans.emplace_back(std::move(ptr)); 13 | } 14 | break; 15 | default: 16 | UNREACHABLEX(void, "Unknown target"); 17 | } 18 | return ans; 19 | } 20 | }// namespace refactor::kernel 21 | -------------------------------------------------------------------------------- /src/04kernel/src/collectors/where.cc: -------------------------------------------------------------------------------- 1 | #include "kernel/collectors/where.h" 2 | #include "../kernels/where/cpu_kernel.hh" 3 | #include "../kernels/where/where_cuda.hh" 4 | 5 | namespace refactor::kernel { 6 | 7 | std::vector 8 | WhereCollector::filter(TensorRefs inputs, TensorRefs) const { 9 | std::vector ans; 10 | switch (_target) { 11 | case decltype(_target)::Cpu: 12 | if (auto ptr = WhereCpu::build(inputs); ptr) { 13 | ans.emplace_back(std::move(ptr)); 14 | } 15 | break; 16 | case decltype(_target)::Nvidia: 17 | if (auto ptr = WhereCuda::build(inputs); ptr) { 18 | ans.emplace_back(std::move(ptr)); 19 | } 20 | break; 21 | default: 22 | UNREACHABLEX(void, "Unknown target"); 23 | } 24 | return ans; 25 | } 26 | 27 | }// namespace refactor::kernel 28 | -------------------------------------------------------------------------------- /src/04kernel/src/kernel.cc: -------------------------------------------------------------------------------- 1 | #include "kernel/kernel.h" 2 | 3 | namespace 
refactor::kernel { 4 | 5 | RoutineWorkspace Kernel::lower(Resources &) const { 6 | RUNTIME_ERROR(fmt::format("lower not implemented for {}", description())); 7 | } 8 | 9 | }// namespace refactor::kernel 10 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/all_reduce/nccl_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "nccl_kernel.hh" 2 | #include "../../utilities/cuda/nccl_communicator.hh" 3 | #include 4 | namespace refactor::kernel { 5 | using K = AllReduceNccl; 6 | using DT = DataType; 7 | using namespace nccl; 8 | 9 | auto K::lower(Resources &res) const noexcept -> RoutineWorkspace{ 10 | return [count = size, 11 | redOp = getRedOp(opType), 12 | ncclDataType = getNcclDataType(dataType)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { 13 | auto communicator = res.fetch(); 14 | auto input = inputs[0]; 15 | auto output = outputs[0]; 16 | checkNcclError(ncclAllReduce(input, output, count, ncclDataType, 17 | redOp, communicator->get(), 0));// TODO: use default stream for now 18 | }; 19 | } 20 | }// namespace refactor::kernel 21 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/all_reduce/nccl_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_ALLREDUCE_NCCL_KERNEL_HH 2 | #define KERNEL_ALLREDUCE_NCCL_KERNEL_HH 3 | 4 | #include "kernel/collectors/all_reduce.h" 5 | #include "kernel/tensor.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct AllReduceNccl final : public Kernel { 10 | AllReduceType opType; 11 | DataType dataType; 12 | size_t size; 13 | 14 | AllReduceNccl(AllReduceType, DataType, size_t) noexcept; 15 | 16 | static KernelBox build(AllReduceType, Tensor const &, Tensor const &) noexcept; 17 | static size_t typeId() noexcept; 18 | 19 | size_t kernelTypeId() const noexcept final; 20 | 
std::string_view description() const noexcept final; 21 | #ifdef USE_CUDA 22 | RoutineWorkspace lower(Resources &) const noexcept final; 23 | #endif 24 | }; 25 | 26 | }// namespace refactor::kernel 27 | 28 | #endif// KERNEL_ALLREDUCE_NCCL_KERNEL_HH 29 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/attention/cuda_kernel.cc: -------------------------------------------------------------------------------- 1 | #include "cuda_kernel.hh" 2 | 3 | namespace refactor::kernel { 4 | using K = AttentionCuda; 5 | 6 | K::AttentionCuda(decltype(info) info_) noexcept 7 | : Kernel(), info(info_) {} 8 | 9 | auto K::build(decltype(info) info) noexcept -> KernelBox { 10 | #ifndef USE_CUDA 11 | return nullptr; 12 | #endif 13 | 14 | return std::make_unique(info); 15 | } 16 | auto K::typeId() noexcept -> size_t { 17 | static uint8_t ID = 1; 18 | return reinterpret_cast(&ID); 19 | } 20 | 21 | auto K::kernelTypeId() const noexcept -> size_t { 22 | return typeId(); 23 | } 24 | auto K::description() const noexcept -> std::string_view { 25 | return "Performing multihead attention on Nvidia gpu"; 26 | } 27 | 28 | }// namespace refactor::kernel 29 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/attention/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_ATTENTION_CUDA_KERNEL_HH 2 | #define KERNEL_ATTENTION_CUDA_KERNEL_HH 3 | 4 | #include "kernel/kernel.h" 5 | #include "kernel/tensor.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct AttentionCuda final : public Kernel { 10 | struct { 11 | DataType dataType; 12 | dim_t batch, nHead, nKVHead, pastSeqLen, seqLen, cacheLen, headDim; 13 | bool resetCache; 14 | } info; 15 | 16 | AttentionCuda(decltype(info)) noexcept; 17 | 18 | static KernelBox build(decltype(info)) noexcept; 19 | static size_t typeId() noexcept; 20 | 21 | size_t kernelTypeId() const noexcept 
final; 22 | std::string_view description() const noexcept final; 23 | #ifdef USE_CUDA 24 | RoutineWorkspace lower(Resources &) const final; 25 | #endif 26 | }; 27 | 28 | }// namespace refactor::kernel 29 | 30 | #endif// KERNEL_ATTENTION_CUDA_KERNEL_HH 31 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/batch_normalization/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_BATCH_NORMALIZATION_CPU_KERNEL_HH 2 | #define KERNEL_BATCH_NORMALIZATION_CPU_KERNEL_HH 3 | 4 | #include "kernel/kernel.h" 5 | #include "kernel/tensor.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct BatchNormalization final : public Kernel { 10 | float epsilon; 11 | DataType dts[3]; 12 | Shape shape; 13 | 14 | BatchNormalization(float, DataType, DataType, DataType, Shape) noexcept; 15 | 16 | static KernelBox build(float, TensorRefs) noexcept; 17 | static size_t typeId() noexcept; 18 | 19 | size_t kernelTypeId() const noexcept final; 20 | std::string_view description() const noexcept final; 21 | RoutineWorkspace lower(Resources &) const noexcept final; 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_BATCH_NORMALIZATION_CPU_KERNEL_HH 27 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/cast/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CLIP_CPU_KERNEL_HH 2 | #define KERNEL_CLIP_CPU_KERNEL_HH 3 | 4 | #include "kernel/kernel.h" 5 | #include "kernel/tensor.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct CastCpu final : public Kernel { 10 | DataType from, to; 11 | size_t size; 12 | 13 | CastCpu(decltype(from), decltype(to), decltype(size)) noexcept; 14 | 15 | static KernelBox build(Tensor const &, Tensor const &) noexcept; 16 | static size_t typeId() noexcept; 17 | 18 | size_t kernelTypeId() const noexcept final; 19 
| std::string_view description() const noexcept final; 20 | RoutineWorkspace lower(Resources &) const noexcept final; 21 | }; 22 | 23 | }// namespace refactor::kernel 24 | 25 | #endif// KERNEL_CLIP_CPU_KERNEL_HH 26 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/cast/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CLIP_CUDA_KERNEL_HH 2 | #define KERNEL_CLIP_CUDA_KERNEL_HH 3 | 4 | #include "kernel/kernel.h" 5 | #include "kernel/tensor.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct CastCuda final : public Kernel { 10 | DataType from, to; 11 | size_t size; 12 | 13 | CastCuda(decltype(from), decltype(to), decltype(size)) noexcept; 14 | 15 | static KernelBox build(Tensor const &, Tensor const &) noexcept; 16 | static size_t typeId() noexcept; 17 | 18 | size_t kernelTypeId() const noexcept final; 19 | std::string_view description() const noexcept final; 20 | #ifdef USE_CUDA 21 | RoutineWorkspace lower(Resources &) const final; 22 | #endif 23 | }; 24 | 25 | }// namespace refactor::kernel 26 | 27 | #endif// KERNEL_CLIP_CUDA_KERNEL_HH 28 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/clip/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CLIP_CPU_KERNEL_HH 2 | #define KERNEL_CLIP_CPU_KERNEL_HH 3 | 4 | #include "kernel/kernel.h" 5 | #include "kernel/tensor.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct ClipCpu final : public Kernel { 10 | DataType dataType; 11 | size_t size; 12 | bool hasMax; 13 | 14 | ClipCpu(decltype(dataType), decltype(size), decltype(hasMax)) noexcept; 15 | 16 | static KernelBox build(Tensor const &, bool hasMax) noexcept; 17 | static size_t typeId() noexcept; 18 | 19 | size_t kernelTypeId() const noexcept final; 20 | std::string_view description() const noexcept final; 21 | 
RoutineWorkspace lower(Resources &) const noexcept final; 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_CLIP_CPU_KERNEL_HH 27 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/clip/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CLIP_CUDA_KERNEL_HH 2 | #define KERNEL_CLIP_CUDA_KERNEL_HH 3 | 4 | #include "kernel/kernel.h" 5 | #include "kernel/tensor.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct ClipCuda final : public Kernel { 10 | DataType dataType; 11 | size_t size; 12 | bool hasMax; 13 | 14 | ClipCuda(decltype(dataType), decltype(size), decltype(hasMax)) noexcept; 15 | 16 | static KernelBox build(Tensor const &, bool hasMax) noexcept; 17 | static size_t typeId() noexcept; 18 | 19 | size_t kernelTypeId() const noexcept final; 20 | std::string_view description() const noexcept final; 21 | #ifdef USE_CUDA 22 | RoutineWorkspace lower(Resources &) const noexcept final; 23 | #endif 24 | }; 25 | 26 | }// namespace refactor::kernel 27 | 28 | #endif// KERNEL_CLIP_CUDA_KERNEL_HH 29 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/concat/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CONCAT_CPU_KERNEL_HH 2 | #define KERNEL_CONCAT_CPU_KERNEL_HH 3 | 4 | #include "kernel/attributes/split_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct ConcatCpu final : public Kernel { 10 | SplitInfo info; 11 | 12 | explicit ConcatCpu(SplitInfo) noexcept; 13 | 14 | static KernelBox build(SplitInfo) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | RoutineWorkspace lower(Resources &) const noexcept final; 20 | }; 21 | 22 | }// namespace refactor::kernel 
23 | 24 | #endif// KERNEL_CONCAT_CPU_KERNEL_HH 25 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/concat/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CONCAT_CUDA_KERNEL_HH 2 | #define KERNEL_CONCAT_CUDA_KERNEL_HH 3 | 4 | #include "kernel/attributes/split_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct ConcatCuda final : public Kernel { 10 | SplitInfo info; 11 | 12 | explicit ConcatCuda(SplitInfo) noexcept; 13 | 14 | static KernelBox build(SplitInfo) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | #ifdef USE_CUDA 20 | RoutineWorkspace lower(Resources &) const final; 21 | #endif 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_CONCAT_CUDA_KERNEL_HH 27 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/dequantize_linear/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_DEQUANTIZE_LINEAR_CPU_KERNEL_HH 2 | #define KERNEL_DEQUANTIZE_LINEAR_CPU_KERNEL_HH 3 | 4 | #include "kernel/kernel.h" 5 | #include "kernel/tensor.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct DequantizeLinearCpu final : public Kernel { 10 | DataType from; 11 | size_t size; 12 | bool withZeroPoint; 13 | 14 | DequantizeLinearCpu( 15 | decltype(from), 16 | decltype(size), 17 | decltype(withZeroPoint)) noexcept; 18 | 19 | static KernelBox build(TensorRefs const &, Tensor const &) noexcept; 20 | static size_t typeId() noexcept; 21 | 22 | size_t kernelTypeId() const noexcept final; 23 | std::string_view description() const noexcept final; 24 | RoutineWorkspace lower(Resources &) const noexcept final; 25 | }; 26 | 27 | }// namespace refactor::kernel 28 | 29 | #endif// 
KERNEL_DEQUANTIZE_LINEAR_CPU_KERNEL_HH 30 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_DYNAMIC_QUANTIZE_LINEAR_CPU_KERNEL_HH 2 | #define KERNEL_DYNAMIC_QUANTIZE_LINEAR_CPU_KERNEL_HH 3 | 4 | #include "kernel/kernel.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct DynamicQuantizeLinearCpu final : public Kernel { 9 | size_t size; 10 | 11 | explicit DynamicQuantizeLinearCpu(decltype(size)) noexcept; 12 | 13 | static KernelBox build(decltype(size)) noexcept; 14 | static size_t typeId() noexcept; 15 | 16 | size_t kernelTypeId() const noexcept final; 17 | std::string_view description() const noexcept final; 18 | RoutineWorkspace lower(Resources &) const noexcept final; 19 | }; 20 | 21 | }// namespace refactor::kernel 22 | 23 | #endif// KERNEL_DYNAMIC_QUANTIZE_LINEAR_CPU_KERNEL_HH 24 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cc: -------------------------------------------------------------------------------- 1 | #include "cuda_kernel.hh" 2 | 3 | namespace refactor::kernel { 4 | using K = DynamicQuantizeLinearCuda; 5 | 6 | K::DynamicQuantizeLinearCuda(decltype(size) size_) noexcept 7 | : Kernel(), size(size_) {} 8 | 9 | auto K::build(decltype(size) size) noexcept -> KernelBox { 10 | return std::make_unique(size); 11 | } 12 | 13 | auto K::typeId() noexcept -> size_t { 14 | static uint8_t ID = 1; 15 | return reinterpret_cast(&ID); 16 | } 17 | 18 | auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } 19 | auto K::description() const noexcept -> std::string_view { 20 | return "Performing dynamic quantize linear using Nvidia GPU"; 21 | } 22 | 23 | }// namespace refactor::kernel 24 | 
-------------------------------------------------------------------------------- /src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_DYNAMIC_QUANTIZE_LINEAR_CUDA_KERNEL_HH 2 | #define KERNEL_DYNAMIC_QUANTIZE_LINEAR_CUDA_KERNEL_HH 3 | 4 | #include "kernel/kernel.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct DynamicQuantizeLinearCuda final : public Kernel { 9 | size_t size; 10 | 11 | explicit DynamicQuantizeLinearCuda(decltype(size)) noexcept; 12 | 13 | static KernelBox build(decltype(size)) noexcept; 14 | static size_t typeId() noexcept; 15 | 16 | size_t kernelTypeId() const noexcept final; 17 | std::string_view description() const noexcept final; 18 | #ifdef USE_CUDA 19 | RoutineWorkspace lower(Resources &) const final; 20 | #endif 21 | }; 22 | 23 | }// namespace refactor::kernel 24 | 25 | #endif// KERNEL_DYNAMIC_QUANTIZE_LINEAR_CUDA_KERNEL_HH 26 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/expand/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_EXPAND_CPU_KERNEL_HH 2 | #define KERNEL_EXPAND_CPU_KERNEL_HH 3 | 4 | #include "kernel/attributes/expand_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct ExpandCpu final : public Kernel { 10 | ExpandInfo info; 11 | 12 | explicit ExpandCpu(ExpandInfo) noexcept; 13 | 14 | static KernelBox build(ExpandInfo) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | RoutineWorkspace lower(Resources &) const noexcept final; 20 | }; 21 | 22 | }// namespace refactor::kernel 23 | 24 | #endif// KERNEL_EXPAND_CPU_KERNEL_HH 25 | -------------------------------------------------------------------------------- 
/src/04kernel/src/kernels/expand/cuda_kernel.cc: -------------------------------------------------------------------------------- 1 | #include "cuda_kernel.hh" 2 | 3 | namespace refactor::kernel { 4 | using K = ExpandCuda; 5 | 6 | K::ExpandCuda(ExpandInfo info_) noexcept 7 | : Kernel(), info(info_.reform(16)) {} 8 | 9 | auto K::build(ExpandInfo info) noexcept -> KernelBox { 10 | #ifndef USE_CUDA 11 | return nullptr; 12 | #endif 13 | 14 | return std::make_unique(std::move(info)); 15 | } 16 | auto K::typeId() noexcept -> size_t { 17 | static uint8_t ID = 1; 18 | return reinterpret_cast(&ID); 19 | } 20 | 21 | auto K::kernelTypeId() const noexcept -> size_t { 22 | return typeId(); 23 | } 24 | auto K::description() const noexcept -> std::string_view { 25 | return "Performing expand operation using CUDA"; 26 | } 27 | 28 | }// namespace refactor::kernel 29 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/expand/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_EXPAND_CUDA_KERNEL_HH 2 | #define KERNEL_EXPAND_CUDA_KERNEL_HH 3 | 4 | #include "kernel/attributes/expand_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct ExpandCuda final : public Kernel { 10 | ExpandInfo info; 11 | 12 | explicit ExpandCuda(ExpandInfo) noexcept; 13 | 14 | static KernelBox build(ExpandInfo) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | #ifdef USE_CUDA 20 | RoutineWorkspace lower(Resources &) const noexcept final; 21 | #endif 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_EXPAND_CUDA_KERNEL_HH 27 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/gather/cpu_kernel.hh: 
-------------------------------------------------------------------------------- 1 | #ifndef KERNEL_GATHER_CPU_KERNEL_HH 2 | #define KERNEL_GATHER_CPU_KERNEL_HH 3 | 4 | #include "kernel/attributes/gather_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct GatherCpu final : public Kernel { 10 | GatherInfo info; 11 | 12 | explicit GatherCpu(GatherInfo) noexcept; 13 | 14 | static KernelBox build(GatherInfo) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | 20 | RoutineWorkspace lower(Resources &) const noexcept final; 21 | }; 22 | 23 | }// namespace refactor::kernel 24 | 25 | #endif// KERNEL_GATHER_CPU_KERNEL_HH 26 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/gather/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_GATHER_CUDA_KERNEL_HH 2 | #define KERNEL_GATHER_CUDA_KERNEL_HH 3 | 4 | #include "kernel/attributes/gather_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct GatherCuda final : public Kernel { 10 | GatherInfo info; 11 | 12 | explicit GatherCuda(GatherInfo) noexcept; 13 | 14 | static KernelBox build(GatherInfo) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | #ifdef USE_CUDA 20 | RoutineWorkspace lower(Resources &) const noexcept final; 21 | #endif 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_TRANSPOSE_CUDA_KERNEL_HH 27 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/hard_sigmoid/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_HARD_SIGMOID_CPU_KERNEL_HH 2 | #define 
KERNEL_HARD_SIGMOID_CPU_KERNEL_HH 3 | 4 | #include "kernel/collectors/hard_sigmoid.h" 5 | #include "kernel/tensor.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct HardSigmoidCpu final : public Kernel { 10 | float alpha, beta; 11 | DataType dataType; 12 | size_t size; 13 | 14 | explicit HardSigmoidCpu(float, float, DataType, size_t) noexcept; 15 | 16 | static KernelBox build(float, float, Tensor const &) noexcept; 17 | static size_t typeId() noexcept; 18 | 19 | size_t kernelTypeId() const noexcept final; 20 | std::string_view description() const noexcept final; 21 | RoutineWorkspace lower(Resources &) const noexcept final; 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_HARD_SIGMOID_CPU_KERNEL_HH 27 | 28 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/hard_sigmoid/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH 2 | #define KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH 3 | 4 | #include "kernel/collectors/hard_sigmoid.h" 5 | #include "kernel/tensor.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct HardSigmoidCuda final : public Kernel { 10 | float alpha, beta; 11 | DataType dataType; 12 | size_t size; 13 | 14 | explicit HardSigmoidCuda(float, float, DataType, size_t) noexcept; 15 | 16 | static KernelBox build(float, float, Tensor const &) noexcept; 17 | static size_t typeId() noexcept; 18 | 19 | size_t kernelTypeId() const noexcept final; 20 | std::string_view description() const noexcept final; 21 | #ifdef USE_CUDA 22 | RoutineWorkspace lower(Resources &) const final; 23 | #endif 24 | }; 25 | 26 | }// namespace refactor::kernel 27 | 28 | #endif// KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH 29 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/mat_mul/cpu_kernel.hh: 
-------------------------------------------------------------------------------- 1 | #ifndef KERNEL_MATMUL_CPU_KERNEL_HH 2 | #define KERNEL_MATMUL_CPU_KERNEL_HH 3 | 4 | #include "kernel/attributes/mat_mul_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct MatMulCPU final : public Kernel { 10 | MatMulInfo info; 11 | 12 | explicit MatMulCPU(decltype(info)) noexcept; 13 | 14 | static KernelBox build(decltype(info)) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | 20 | RoutineWorkspace lower(Resources &) const noexcept final; 21 | }; 22 | 23 | }// namespace refactor::kernel 24 | 25 | #endif// KERNEL_MATMUL_CPU_KERNEL_HH 26 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/mat_mul/cublas_kernel.cc: -------------------------------------------------------------------------------- 1 | #include "cublas_kernel.hh" 2 | 3 | namespace refactor::kernel { 4 | using K = MatMulCublas; 5 | using DT = DataType; 6 | 7 | K::MatMulCublas(decltype(info) info_) noexcept 8 | : Kernel(), info(std::move(info_)) {} 9 | 10 | auto K::build(decltype(info) info) noexcept -> KernelBox { 11 | #ifndef USE_CUDA 12 | return nullptr; 13 | #endif 14 | 15 | return info.dataType.isIeee754() || info.dataType == DT::I8 16 | ? 
std::make_unique(std::move(info)) 17 | : nullptr; 18 | } 19 | 20 | auto K::typeId() noexcept -> size_t { 21 | static uint8_t ID = 1; 22 | return reinterpret_cast(&ID); 23 | } 24 | 25 | auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } 26 | auto K::description() const noexcept -> std::string_view { 27 | return "Performing MatMul using CUBLAS"; 28 | } 29 | 30 | }// namespace refactor::kernel 31 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/mat_mul/cublas_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_MATMUL_CUBLAS_KERNEL_HH 2 | #define KERNEL_MATMUL_CUBLAS_KERNEL_HH 3 | 4 | #include "kernel/attributes/mat_mul_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct MatMulCublas final : public Kernel { 10 | MatMulInfo info; 11 | 12 | explicit MatMulCublas(decltype(info)) noexcept; 13 | 14 | static KernelBox build(decltype(info)) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | #ifdef USE_CUDA 20 | RoutineWorkspace lower(Resources &) const noexcept final; 21 | #endif 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_MATMUL_CUBLAS_KERNEL_HH 27 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_MATMUL_INTEGER_CPU_KERNEL_HH 2 | #define KERNEL_MATMUL_INTEGER_CPU_KERNEL_HH 3 | 4 | #include "kernel/attributes/mat_mul_integer_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct MatMulIntegerCpu final : public Kernel { 10 | MatMulIntegerInfo info; 11 | 12 | explicit MatMulIntegerCpu(decltype(info)) noexcept; 13 | 14 | static KernelBox 
build(decltype(info)) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | 20 | RoutineWorkspace lower(Resources &) const noexcept final; 21 | }; 22 | 23 | }// namespace refactor::kernel 24 | 25 | #endif// KERNEL_MATMUL_INTEGER_CPU_KERNEL_HH 26 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cc: -------------------------------------------------------------------------------- 1 | #include "cublas_kernel.hh" 2 | 3 | namespace refactor::kernel { 4 | using K = MatMulIntegerCublas; 5 | using DT = DataType; 6 | 7 | K::MatMulIntegerCublas(decltype(info) info_) noexcept 8 | : Kernel(), info(std::move(info_)) {} 9 | 10 | auto K::build(decltype(info) info) noexcept -> KernelBox { 11 | #ifndef USE_CUDA 12 | return nullptr; 13 | #endif 14 | 15 | return std::make_unique(std::move(info)); 16 | } 17 | 18 | auto K::typeId() noexcept -> size_t { 19 | static uint8_t ID = 1; 20 | return reinterpret_cast(&ID); 21 | } 22 | 23 | auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } 24 | auto K::description() const noexcept -> std::string_view { 25 | return "Performing MatMulInteger using CUBLAS"; 26 | } 27 | 28 | }// namespace refactor::kernel 29 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_MATMUL_CUBLAS_KERNEL_HH 2 | #define KERNEL_MATMUL_CUBLAS_KERNEL_HH 3 | 4 | #include "kernel/attributes/mat_mul_integer_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct MatMulIntegerCublas final : public Kernel { 10 | MatMulIntegerInfo info; 11 | 12 | explicit MatMulIntegerCublas(decltype(info)) noexcept; 13 | 14 | static KernelBox 
build(decltype(info)) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | #ifdef USE_CUDA 20 | RoutineWorkspace lower(Resources &) const noexcept final; 21 | #endif 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_MATMUL_CUBLAS_KERNEL_HH 27 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/pad/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_PAD_CPU_KERNEL_HH 2 | #define KERNEL_PAD_CPU_KERNEL_HH 3 | 4 | #include "kernel/attributes/pad_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct PadCpu final : public Kernel { 10 | PadInfo info; 11 | PadType mode; 12 | size_t valueLength; 13 | 14 | explicit PadCpu(PadInfo, PadType, size_t) noexcept; 15 | 16 | static KernelBox build(PadInfo, PadType, std::optional>) noexcept; 17 | static size_t typeId() noexcept; 18 | 19 | size_t kernelTypeId() const noexcept final; 20 | std::string_view description() const noexcept final; 21 | RoutineWorkspace lower(Resources &) const noexcept final; 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_PAD_CPU_KERNEL_HH 27 | 28 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/pad/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_PAD_CUDA_HH 2 | #define KERNEL_PAD_CUDA_HH 3 | 4 | #include "kernel/attributes/pad_info.h" 5 | #include "kernel/collectors/pad.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct PadCuda final : public Kernel { 10 | PadInfo info; 11 | PadType mode; 12 | size_t valueLength; 13 | 14 | PadCuda(PadInfo, PadType, size_t) noexcept; 15 | static KernelBox build(PadInfo, PadType, std::optional>) noexcept; 16 | static size_t 
typeId() noexcept; 17 | 18 | size_t kernelTypeId() const noexcept final; 19 | std::string_view description() const noexcept final; 20 | #ifdef USE_CUDA 21 | RoutineWorkspace lower(Resources &) const noexcept final; 22 | #endif 23 | }; 24 | 25 | }// namespace refactor::kernel 26 | 27 | #endif//KERNEL_PAD_CUDA_HH 28 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/scatter_nd/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SCATTER_ND_CPU_KERNEL_HH 2 | #define KERNEL_SCATTER_ND_CPU_KERNEL_HH 3 | 4 | #include "kernel/attributes/scatter_nd_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct ScatterNDCpu final : public Kernel { 10 | ScatterNDInfo info; 11 | 12 | explicit ScatterNDCpu(decltype(info)); 13 | 14 | static KernelBox build(decltype(info)) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | RoutineWorkspace lower(Resources &) const noexcept final; 20 | }; 21 | 22 | }// namespace refactor::kernel 23 | 24 | #endif// KERNEL_SCATTER_ND_CPU_KERNEL_HH 25 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/scatter_nd/cuda_kernel.cc: -------------------------------------------------------------------------------- 1 | #include "cuda_kernel.hh" 2 | 3 | namespace refactor::kernel { 4 | using K = ScatterNDCuda; 5 | 6 | K::ScatterNDCuda(decltype(info) info_) 7 | : Kernel(), info(std::move(info_)) {} 8 | 9 | auto K::build(decltype(info) info) noexcept -> KernelBox { 10 | #ifndef USE_CUDA 11 | return nullptr; 12 | #endif 13 | 14 | return std::make_unique(std::move(info)); 15 | } 16 | auto K::typeId() noexcept -> size_t { 17 | static uint8_t ID = 1; 18 | return reinterpret_cast(&ID); 19 | } 20 | 21 | auto K::kernelTypeId() const noexcept -> size_t { 
22 | return typeId(); 23 | } 24 | auto K::description() const noexcept -> std::string_view { 25 | return "Performing scatterNd operation on Nvidia GPU"; 26 | } 27 | 28 | }// namespace refactor::kernel 29 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/scatter_nd/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SCATTER_ND_CUDA_KERNEL_HH 2 | #define KERNEL_SCATTER_ND_CUDA_KERNEL_HH 3 | 4 | #include "kernel/attributes/scatter_nd_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct ScatterNDCuda final : public Kernel { 10 | ScatterNDInfo info; 11 | 12 | explicit ScatterNDCuda(decltype(info)); 13 | 14 | static KernelBox build(decltype(info)) noexcept; 15 | 16 | static KernelBox build(Tensor const &, bool hasMax) noexcept; 17 | static size_t typeId() noexcept; 18 | 19 | size_t kernelTypeId() const noexcept final; 20 | std::string_view description() const noexcept final; 21 | #ifdef USE_CUDA 22 | RoutineWorkspace lower(Resources &) const noexcept final; 23 | #endif 24 | }; 25 | 26 | }// namespace refactor::kernel 27 | 28 | #endif// KERNEL_SCATTER_ND_CUDA_KERNEL_HH 29 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/simple_binary/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_BINARY_BASIC_CPU_HH 2 | #define KERNEL_BINARY_BASIC_CPU_HH 3 | 4 | #include "kernel/attributes/broadcaster.h" 5 | #include "kernel/collectors/simple_binary.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct BinaryCpu final : public Kernel { 10 | DataType dataType; 11 | SimpleBinaryType opType; 12 | Broadcaster broadcaster; 13 | 14 | BinaryCpu(SimpleBinaryType, DataType, Broadcaster) noexcept; 15 | 16 | static KernelBox build(SimpleBinaryType, 17 | Tensor const &, 18 | Tensor const &) noexcept; 19 | static 
size_t typeId() noexcept; 20 | 21 | size_t kernelTypeId() const noexcept final; 22 | std::string_view description() const noexcept final; 23 | RoutineWorkspace lower(Resources &) const noexcept final; 24 | }; 25 | 26 | }// namespace refactor::kernel 27 | 28 | #endif// KERNEL_BINARY_BASIC_CPU_HH 29 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/simple_unary/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SIMPLE_UNARY_CPU_KERNEL_HH 2 | #define KERNEL_SIMPLE_UNARY_CPU_KERNEL_HH 3 | 4 | #include "kernel/collectors/simple_unary.h" 5 | #include "kernel/tensor.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct SimpleUnaryCpu final : public Kernel { 10 | DataType dataType; 11 | SimpleUnaryType opType; 12 | size_t size; 13 | 14 | SimpleUnaryCpu(SimpleUnaryType, DataType, size_t) noexcept; 15 | 16 | static KernelBox build(SimpleUnaryType, Tensor const &) noexcept; 17 | static size_t typeId() noexcept; 18 | 19 | size_t kernelTypeId() const noexcept final; 20 | std::string_view description() const noexcept final; 21 | RoutineWorkspace lower(Resources &) const noexcept final; 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_SIMPLE_UNARY_CPU_KERNEL_HH 27 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/simple_unary/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SIMPLE_UNARY_CUDA_KERNEL_HH 2 | #define KERNEL_SIMPLE_UNARY_CUDA_KERNEL_HH 3 | 4 | #include "kernel/collectors/simple_unary.h" 5 | #include "kernel/tensor.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct SimpleUnaryCuda final : public Kernel { 10 | DataType dataType; 11 | SimpleUnaryType opType; 12 | size_t size; 13 | 14 | SimpleUnaryCuda(SimpleUnaryType, DataType, size_t) noexcept; 15 | 16 | static KernelBox build(SimpleUnaryType, 
Tensor const &) noexcept; 17 | static size_t typeId() noexcept; 18 | 19 | size_t kernelTypeId() const noexcept final; 20 | std::string_view description() const noexcept final; 21 | #ifdef USE_CUDA 22 | RoutineWorkspace lower(Resources &) const final; 23 | #endif 24 | }; 25 | 26 | }// namespace refactor::kernel 27 | 28 | #endif// KERNEL_SIMPLE_UNARY_CUDA_KERNEL_HH 29 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/simple_unary/cudnn_activation_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_ACTIVATION_CUDNN_KERNEL_HH 2 | #define KERNEL_ACTIVATION_CUDNN_KERNEL_HH 3 | 4 | #include "kernel/collectors/simple_unary.h" 5 | 6 | namespace refactor::kernel { 7 | 8 | struct ActivationCudnn final : public Kernel { 9 | SimpleUnaryType type; 10 | DataType dataType; 11 | int size; 12 | 13 | ActivationCudnn(SimpleUnaryType, DataType, int) noexcept; 14 | 15 | static KernelBox build(SimpleUnaryType, Tensor const &) noexcept; 16 | static size_t typeId() noexcept; 17 | 18 | size_t kernelTypeId() const noexcept final; 19 | std::string_view description() const noexcept final; 20 | #ifdef USE_CUDA 21 | RoutineWorkspace lower(Resources &) const final; 22 | #endif 23 | }; 24 | 25 | }// namespace refactor::kernel 26 | 27 | #endif// KERNEL_ACTIVATION_CUDNN_KERNEL_HH 28 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/slice/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SPLIT_CPU_KERNEL_HH 2 | #define KERNEL_SPLIT_CPU_KERNEL_HH 3 | 4 | #include "kernel/attributes/slice_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct SliceCpu final : public Kernel { 10 | SliceInfo info; 11 | 12 | explicit SliceCpu(SliceInfo) noexcept; 13 | 14 | static KernelBox build(SliceInfo) noexcept; 15 | static size_t typeId() 
noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | RoutineWorkspace lower(Resources &) const noexcept final; 20 | }; 21 | 22 | }// namespace refactor::kernel 23 | 24 | #endif// KERNEL_SPLIT_CPU_KERNEL_HH 25 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/slice/cuda_kernel.cc: -------------------------------------------------------------------------------- 1 | #include "cuda_kernel.hh" 2 | 3 | namespace refactor::kernel { 4 | using K = SliceCuda; 5 | 6 | K::SliceCuda(SliceInfo info_) noexcept 7 | : Kernel(), info(info_.reform(16)) {} 8 | 9 | auto K::build(SliceInfo info) noexcept -> KernelBox { 10 | #ifndef USE_CUDA 11 | return nullptr; 12 | #endif 13 | 14 | return std::make_unique(std::move(info)); 15 | } 16 | auto K::typeId() noexcept -> size_t { 17 | static uint8_t ID = 1; 18 | return reinterpret_cast(&ID); 19 | } 20 | 21 | auto K::kernelTypeId() const noexcept -> size_t { 22 | return typeId(); 23 | } 24 | auto K::description() const noexcept -> std::string_view { 25 | return "Performing slice operation using CUDA"; 26 | } 27 | 28 | }// namespace refactor::kernel 29 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/slice/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SPLIT_CUDA_KERNEL_HH 2 | #define KERNEL_SPLIT_CUDA_KERNEL_HH 3 | 4 | #include "kernel/attributes/slice_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct SliceCuda final : public Kernel { 10 | SliceInfo info; 11 | 12 | explicit SliceCuda(SliceInfo) noexcept; 13 | 14 | static KernelBox build(SliceInfo) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | #ifdef USE_CUDA 20 | RoutineWorkspace 
lower(Resources &) const noexcept final; 21 | #endif 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_SPLIT_CUDA_KERNEL_HH 27 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/softmax/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SOFTMAX_CPU_KERNEL_HH 2 | #define KERNEL_SOFTMAX_CPU_KERNEL_HH 3 | 4 | #include "kernel/attributes/softmax_info.h" 5 | #include "kernel/collectors/softmax.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct SoftmaxCpu final : public Kernel { 10 | SoftmaxInfo info; 11 | 12 | explicit SoftmaxCpu(SoftmaxInfo) noexcept; 13 | 14 | static KernelBox build(SoftmaxInfo) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | RoutineWorkspace lower(Resources &) const noexcept final; 20 | }; 21 | 22 | }// namespace refactor::kernel 23 | 24 | #endif// KERNEL_SOFTMAX_CPU_KERNEL_HH 25 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/softmax/cuda_kernel.cc: -------------------------------------------------------------------------------- 1 | #include "cuda_kernel.hh" 2 | 3 | namespace refactor::kernel { 4 | using K = SoftmaxCuda; 5 | 6 | K::SoftmaxCuda(SoftmaxInfo info_) noexcept 7 | : Kernel(), info(std::move(info_)) {} 8 | 9 | auto K::build(SoftmaxInfo info) noexcept -> KernelBox { 10 | #ifndef USE_CUDA 11 | return nullptr; 12 | #endif 13 | 14 | return info.type.isFloat() 15 | ? 
std::make_unique(std::move(info)) 16 | : nullptr; 17 | } 18 | 19 | auto K::typeId() noexcept -> size_t { 20 | static uint8_t ID = 1; 21 | return reinterpret_cast(&ID); 22 | } 23 | 24 | auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } 25 | auto K::description() const noexcept -> std::string_view { 26 | return "Performing Softmax using CUDA"; 27 | } 28 | 29 | }// namespace refactor::kernel 30 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/softmax/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SOFTMAX_CUDA_HH 2 | #define KERNEL_SOFTMAX_CUDA_HH 3 | 4 | #include "kernel/attributes/softmax_info.h" 5 | #include "kernel/collectors/softmax.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct SoftmaxCuda final : public Kernel { 10 | SoftmaxInfo info; 11 | 12 | SoftmaxCuda(SoftmaxInfo) noexcept; 13 | static KernelBox build(SoftmaxInfo) noexcept; 14 | static size_t typeId() noexcept; 15 | 16 | size_t kernelTypeId() const noexcept final; 17 | std::string_view description() const noexcept final; 18 | #ifdef USE_CUDA 19 | RoutineWorkspace lower(Resources &) const noexcept final; 20 | #endif 21 | }; 22 | 23 | }// namespace refactor::kernel 24 | 25 | #endif//KERNEL_SOFTMAX_CUDA_HH 26 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/split/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SPLIT_CPU_KERNEL_HH 2 | #define KERNEL_SPLIT_CPU_KERNEL_HH 3 | 4 | #include "kernel/attributes/split_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct SplitCpu final : public Kernel { 10 | SplitInfo info; 11 | 12 | explicit SplitCpu(SplitInfo) noexcept; 13 | 14 | static KernelBox build(SplitInfo) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const 
noexcept final; 18 | std::string_view description() const noexcept final; 19 | RoutineWorkspace lower(Resources &) const noexcept final; 20 | }; 21 | 22 | }// namespace refactor::kernel 23 | 24 | #endif// KERNEL_SPLIT_CPU_KERNEL_HH 25 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/split/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_SPLIT_CUDA_KERNEL_HH 2 | #define KERNEL_SPLIT_CUDA_KERNEL_HH 3 | 4 | #include "kernel/attributes/split_info.h" 5 | #include "kernel/kernel.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct SplitCuda final : public Kernel { 10 | SplitInfo info; 11 | 12 | explicit SplitCuda(SplitInfo) noexcept; 13 | 14 | static KernelBox build(SplitInfo) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | #ifdef USE_CUDA 20 | RoutineWorkspace lower(Resources &) const noexcept final; 21 | #endif 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// KERNEL_SPLIT_CUDA_KERNEL_HH 27 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/transpose/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_TRANSPOSE_CPU_KERNEL_HH 2 | #define KERNEL_TRANSPOSE_CPU_KERNEL_HH 3 | 4 | #include "kernel/collectors/transpose.h" 5 | #include "kernel/tensor.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct TransposeCpu final : public Kernel { 10 | TransposeInfo info; 11 | 12 | TransposeCpu(TransposeInfo) noexcept; 13 | 14 | static KernelBox build(TransposeInfo) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | RoutineWorkspace lower(Resources &) const noexcept final; 20 | }; 21 | 22 | }// 
namespace refactor::kernel 23 | 24 | #endif// KERNEL_TRANSPOSE_CPU_KERNEL_HH 25 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/transpose/cuda_kernel.cc: -------------------------------------------------------------------------------- 1 | #include "cuda_kernel.hh" 2 | 3 | namespace refactor::kernel { 4 | using K = TransposeCuda; 5 | using Info = TransposeInfo; 6 | 7 | K::TransposeCuda(Info info_) noexcept 8 | : Kernel(), info(std::move(info_)) {} 9 | 10 | auto K::build(Info info) noexcept -> KernelBox { 11 | #ifndef USE_CUDA 12 | return nullptr; 13 | #endif 14 | return std::make_unique(std::move(info)); 15 | } 16 | auto K::typeId() noexcept -> size_t { 17 | static uint8_t ID = 1; 18 | return reinterpret_cast(&ID); 19 | } 20 | 21 | auto K::kernelTypeId() const noexcept -> size_t { 22 | return typeId(); 23 | } 24 | auto K::description() const noexcept -> std::string_view { 25 | return "Performing transpose operation using CUDA"; 26 | } 27 | 28 | }// namespace refactor::kernel 29 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/transpose/cuda_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_TRANSPOSE_CUDA_KERNEL_HH 2 | #define KERNEL_TRANSPOSE_CUDA_KERNEL_HH 3 | 4 | #include "kernel/collectors/transpose.h" 5 | #include "kernel/tensor.h" 6 | 7 | namespace refactor::kernel { 8 | 9 | struct TransposeCuda final : public Kernel { 10 | TransposeInfo info; 11 | 12 | TransposeCuda(TransposeInfo) noexcept; 13 | 14 | static KernelBox build(TransposeInfo) noexcept; 15 | static size_t typeId() noexcept; 16 | 17 | size_t kernelTypeId() const noexcept final; 18 | std::string_view description() const noexcept final; 19 | #ifdef USE_CUDA 20 | RoutineWorkspace lower(Resources &) const noexcept final; 21 | #endif 22 | }; 23 | 24 | }// namespace refactor::kernel 25 | 26 | #endif// 
KERNEL_TRANSPOSE_CUDA_KERNEL_HH 27 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/where/cpu_kernel.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_WHERE_CPU_HH 2 | #define KERNEL_WHERE_CPU_HH 3 | 4 | #include "kernel/attributes/broadcaster.h" 5 | #include "kernel/collectors/where.h" 6 | #include "kernel/tensor.h" 7 | 8 | namespace refactor::kernel { 9 | 10 | struct WhereCpu final : public Kernel { 11 | DataType dataType; 12 | Broadcaster broadcaster; 13 | 14 | WhereCpu(DataType, Broadcaster) noexcept; 15 | 16 | static KernelBox build(TensorRefs const &) noexcept; 17 | static size_t typeId() noexcept; 18 | 19 | size_t kernelTypeId() const noexcept final; 20 | std::string_view description() const noexcept final; 21 | RoutineWorkspace lower(Resources &) const noexcept final; 22 | }; 23 | }// namespace refactor::kernel 24 | 25 | #endif// KERNEL_WHERE_CPU_HH 26 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/where/where_cuda.cc: -------------------------------------------------------------------------------- 1 | #include "where_cuda.hh" 2 | 3 | namespace refactor::kernel { 4 | using K = WhereCuda; 5 | 6 | K::WhereCuda(DataType dataType_, Broadcaster b) noexcept 7 | : Kernel(), 8 | dataType(dataType_), 9 | broadcaster(std::move(b)) {} 10 | 11 | auto K::build(TensorRefs const &inputs) noexcept -> KernelBox { 12 | #ifndef USE_CUDA 13 | return nullptr; 14 | #endif 15 | return std::make_unique(inputs[1].get().dataType, Broadcaster(inputs)); 16 | } 17 | auto K::typeId() noexcept -> size_t { 18 | static uint8_t ID = 1; 19 | return reinterpret_cast(&ID); 20 | } 21 | 22 | auto K::kernelTypeId() const noexcept -> size_t { 23 | return typeId(); 24 | } 25 | auto K::description() const noexcept -> std::string_view { 26 | return "Performing where operation using CUDA"; 27 | } 28 | 29 | }// namespace 
refactor::kernel 30 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/where/where_cuda.cu: -------------------------------------------------------------------------------- 1 | #include "kernel/cuda/where.cuh" 2 | #include "where_cuda.hh" 3 | #include 4 | 5 | namespace refactor::kernel { 6 | using namespace runtime; 7 | 8 | auto WhereCuda::lower(Resources &) const noexcept -> RoutineWorkspace { 9 | return [strides = thrust::device_vector(broadcaster.strides.begin(), broadcaster.strides.end()), 10 | params = cuda::ThreadsDistributer()(broadcaster.outputsCount), 11 | eleSize = static_cast(dataType.size())](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { 12 | cuda::launchWhere( 13 | params, 14 | strides.data().get(), 15 | inputs[0], 16 | inputs[1], 17 | inputs[2], 18 | outputs[0], 19 | strides.size() / 4, 20 | eleSize); 21 | }; 22 | } 23 | 24 | }// namespace refactor::kernel 25 | -------------------------------------------------------------------------------- /src/04kernel/src/kernels/where/where_cuda.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_WHERE_CUDA_HH 2 | #define KERNEL_WHERE_CUDA_HH 3 | 4 | #include "kernel/attributes/broadcaster.h" 5 | #include "kernel/collectors/where.h" 6 | #include "kernel/tensor.h" 7 | 8 | namespace refactor::kernel { 9 | 10 | struct WhereCuda final : public Kernel { 11 | DataType dataType; 12 | Broadcaster broadcaster; 13 | 14 | WhereCuda(DataType, Broadcaster) noexcept; 15 | 16 | static KernelBox build(TensorRefs const &) noexcept; 17 | static size_t typeId() noexcept; 18 | 19 | size_t kernelTypeId() const noexcept final; 20 | std::string_view description() const noexcept final; 21 | #ifdef USE_CUDA 22 | RoutineWorkspace lower(Resources &) const noexcept final; 23 | #endif 24 | }; 25 | 26 | }// namespace refactor::kernel 27 | 28 | #endif// KERNEL_WHERE_CUDA_HH 29 | 
-------------------------------------------------------------------------------- /src/04kernel/src/utilities/cuda/cudnn_context.cu: -------------------------------------------------------------------------------- 1 | #include "cudnn_context.hh" 2 | #include "cudnn_functions.h" 3 | 4 | namespace refactor::kernel::cudnn { 5 | 6 | CudnnContext::CudnnContext() : runtime::Resource() { 7 | CUDNN_ASSERT(cudnnCreate(&handle)); 8 | } 9 | CudnnContext::~CudnnContext() { 10 | CUDNN_ASSERT(cudnnDestroy(handle)); 11 | } 12 | 13 | auto CudnnContext::typeId() noexcept -> size_t { 14 | static uint8_t ID = 1; 15 | return reinterpret_cast(&ID); 16 | } 17 | auto CudnnContext::build() -> runtime::ResourceBox { 18 | return std::make_unique(); 19 | } 20 | 21 | auto CudnnContext::resourceTypeId() const noexcept -> size_t { 22 | return typeId(); 23 | } 24 | auto CudnnContext::description() const noexcept -> std::string_view { 25 | return "CudnnContext"; 26 | } 27 | 28 | }// namespace refactor::kernel::cudnn 29 | -------------------------------------------------------------------------------- /src/04kernel/src/utilities/cuda/cudnn_context.hh: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_CUDNN_CONTEXT_HH 2 | #define KERNEL_CUDNN_CONTEXT_HH 3 | 4 | #include "runtime/resource.h" 5 | #include 6 | 7 | namespace refactor::kernel::cudnn { 8 | 9 | struct CudnnContext final : public runtime::Resource { 10 | cudnnHandle_t handle; 11 | 12 | CudnnContext(); 13 | ~CudnnContext(); 14 | CudnnContext(CudnnContext const &) noexcept = delete; 15 | CudnnContext(CudnnContext &&) noexcept = delete; 16 | 17 | static size_t typeId() noexcept; 18 | static runtime::ResourceBox build(); 19 | 20 | size_t resourceTypeId() const noexcept final; 21 | std::string_view description() const noexcept final; 22 | }; 23 | 24 | }// namespace refactor::kernel::cudnn 25 | 26 | #endif// KERNEL_CUDNN_CONTEXT_HH 27 | 
-------------------------------------------------------------------------------- /src/04kernel/test/attributes/test_expand_info.cpp: -------------------------------------------------------------------------------- 1 | #include "kernel/attributes/expand_info.h" 2 | #include 3 | 4 | using namespace refactor; 5 | using namespace kernel; 6 | 7 | TEST(kernel, ExpandInfo) { 8 | auto input = Tensor::share(DataType::F32, {3, 4, 1, 6}), 9 | output = Tensor::share(DataType::F32, {2, 3, 4, 5, 6}); 10 | 11 | ExpandInfo info(*input, *output); 12 | EXPECT_EQ(info.blockSize, 24); 13 | EXPECT_EQ(info.blockCount, 120); 14 | EXPECT_EQ(info.strides, (std::vector{{0, 60}, {1, 5}, {0, 1}})); 15 | 16 | auto reformed = info.reform(16); 17 | EXPECT_EQ(reformed.blockSize, 8); 18 | EXPECT_EQ(reformed.blockCount, 360); 19 | EXPECT_EQ(reformed.strides, (std::vector{{0, 180}, {3, 15}, {0, 3}, {1, 1}})); 20 | } 21 | -------------------------------------------------------------------------------- /src/04kernel/test/attributes/test_gather_info.cpp: -------------------------------------------------------------------------------- 1 | #include "kernel/attributes/gather_info.h" 2 | #include 3 | 4 | using namespace refactor; 5 | using namespace kernel; 6 | 7 | TEST(kernel, GatherInfo) { 8 | auto data = Tensor::share(DataType::F32, {2, 3, 10, 7, 8}); 9 | auto indices = Tensor::share(DataType::I64, {4, 5, 6}); 10 | GatherInfo info(2, *data, *indices); 11 | EXPECT_EQ(info.prefix, 2 * 3); 12 | EXPECT_EQ(info.postfix, 7 * 8 * DataType(DataType::F32).size()); 13 | EXPECT_EQ(info.midSizeI, 10); 14 | EXPECT_EQ(info.midSizeO, 4 * 5 * 6); 15 | } 16 | -------------------------------------------------------------------------------- /src/04kernel/test/attributes/test_scatter_nd_info.cpp: -------------------------------------------------------------------------------- 1 | #include "kernel/attributes/scatter_nd_info.h" 2 | #include 3 | 4 | using namespace refactor; 5 | using namespace kernel; 6 | 7 | TEST(kernel, 
ScatterNDInfo) { 8 | auto data = Tensor::share(DataType::F32, {2, 3, 5, 7}); 9 | auto indices = Tensor::share(DataType::I64, {3, 5, 2}); 10 | ScatterNDInfo info(*data, *indices); 11 | EXPECT_EQ(info.prefix, 3 * 5); 12 | EXPECT_EQ(info.strides, (decltype(info.strides){3, 1})); 13 | EXPECT_EQ(info.blockSize, data->dataType.size() * 5 * 7); 14 | } 15 | -------------------------------------------------------------------------------- /src/04kernel/test/generator/test_cuda.cpp: -------------------------------------------------------------------------------- 1 | #ifdef USE_CUDA 2 | 3 | #include "../../src/generator/nvrtc_repo.h" 4 | #include 5 | 6 | using namespace refactor; 7 | using namespace kernel; 8 | 9 | constexpr static const char *code = R"~( 10 | extern "C" __global__ void kernel() { 11 | printf("Hello World from GPU!\n"); 12 | } 13 | )~"; 14 | 15 | TEST(generator, nvrtc) { 16 | auto handler = nvrtc::Handler::compile("helloWorld.cu", code, "kernel"); 17 | handler->launch( 18 | 1, 1, 1, 19 | 1, 1, 1, 20 | 0, nullptr); 21 | } 22 | 23 | #endif// USE_CUDA 24 | -------------------------------------------------------------------------------- /src/05computation/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12 FATAL_ERROR) 2 | project(computation VERSION 0.0.0 LANGUAGES CXX) 3 | message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION}) 4 | 5 | file(GLOB_RECURSE COMPUTATION_SRC src/*.cc src/*.cpp) 6 | add_library(computation STATIC ${COMPUTATION_SRC}) 7 | target_link_libraries(computation PUBLIC kernel) 8 | target_include_directories(computation PUBLIC include) 9 | 10 | file(GLOB_RECURSE COMPUTATION_TEST test/*.cpp) 11 | if(COMPUTATION_TEST) 12 | add_executable(computation_test ${COMPUTATION_TEST}) 13 | add_test(computation_test computation_test) 14 | target_link_libraries(computation_test computation GTest::gtest_main Backward::Object) 15 | endif() 16 | 
-------------------------------------------------------------------------------- /src/05computation/include/computation/operators/all_reduce.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_ALL_REDUCE_H 2 | #define COMPUTATION_ALL_REDUCE_H 3 | 4 | #include "../operator.h" 5 | #include "kernel/attributes/communication.h" 6 | 7 | namespace refactor::computation { 8 | 9 | struct AllReduce final : public Operator { 10 | kernel::AllReduceType type; 11 | 12 | constexpr explicit AllReduce(kernel::AllReduceType type_) noexcept 13 | : Operator(), type(type_) {} 14 | 15 | static size_t typeId(kernel::AllReduceType) noexcept; 16 | size_t opTypeId() const noexcept final; 17 | std::string_view name() const noexcept final; 18 | kernel::CollectorBox candidateKernels(Target) const final; 19 | }; 20 | 21 | }// namespace refactor::computation 22 | 23 | #endif// COMPUTATION_ALL_REDUCE_H 24 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/attention.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_ATTENTION_H 2 | #define COMPUTATION_ATTENTION_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct Attention final : public Operator { 9 | dim_t maxSeqLen; 10 | 11 | constexpr Attention(decltype(maxSeqLen) maxSeqLen_) noexcept 12 | : Operator(), maxSeqLen(maxSeqLen_) {} 13 | 14 | static size_t typeId() noexcept; 15 | size_t opTypeId() const noexcept final; 16 | std::string_view name() const noexcept final; 17 | }; 18 | 19 | }// namespace refactor::computation 20 | 21 | #endif// COMPUTATION_ATTENTION_H 22 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/batch_normalization.h: -------------------------------------------------------------------------------- 1 | #ifndef 
COMPUTATION_BATCH_NORMALIZATION_H 2 | #define COMPUTATION_BATCH_NORMALIZATION_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct BatchNormalization final : public Operator { 9 | float epsilon; 10 | 11 | constexpr explicit BatchNormalization(float epsilon_) noexcept 12 | : Operator(), epsilon(epsilon_) {} 13 | 14 | static size_t typeId() noexcept; 15 | size_t opTypeId() const noexcept final; 16 | std::string_view name() const noexcept final; 17 | kernel::CollectorBox candidateKernels(Target) const final; 18 | std::string serialize() const noexcept final; 19 | }; 20 | 21 | }// namespace refactor::computation 22 | 23 | #endif// COMPUTATION_BATCH_NORMALIZATION_H 24 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/broadcast.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_BROADCAST_H 2 | #define COMPUTATION_BROADCAST_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct Broadcast final : public LayoutDependentOperator { 9 | constexpr Broadcast() noexcept : LayoutDependentOperator() {} 10 | 11 | static size_t typeId() noexcept; 12 | size_t opTypeId() const noexcept final; 13 | std::string_view name() const noexcept final; 14 | }; 15 | 16 | }// namespace refactor::computation 17 | 18 | #endif// COMPUTATION_BROADCAST_H 19 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/cast.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_CAST_H 2 | #define COMPUTATION_CAST_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct Cast final : public Operator { 9 | 10 | constexpr Cast() noexcept = default; 11 | 12 | static size_t typeId() noexcept; 13 | size_t opTypeId() const noexcept final; 14 | 
std::string_view name() const noexcept final; 15 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 16 | std::string serialize() const noexcept final; 17 | }; 18 | 19 | }// namespace refactor::computation 20 | 21 | #endif// COMPUTATION_CAST_H 22 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/clip.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_CLIP_H 2 | #define COMPUTATION_CLIP_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct Clip final : public Operator { 9 | 10 | Clip() = default; 11 | 12 | static size_t typeId() noexcept; 13 | size_t opTypeId() const noexcept final; 14 | std::string_view name() const noexcept final; 15 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 16 | std::string serialize() const noexcept final; 17 | }; 18 | 19 | }// namespace refactor::computation 20 | 21 | #endif// COMPUTATION_CLIP_H 22 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/compair.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_COMPAIR_H 2 | #define COMPUTATION_COMPAIR_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | enum class CompairType { 9 | EQ, 10 | NE, 11 | LT, 12 | LE, 13 | GT, 14 | GE, 15 | }; 16 | 17 | struct Compair final : public Operator { 18 | CompairType type; 19 | 20 | constexpr explicit Compair(CompairType type_) noexcept 21 | : Operator(), type(type_) {} 22 | 23 | static size_t typeId(CompairType) noexcept; 24 | size_t opTypeId() const noexcept final; 25 | std::string_view name() const noexcept final; 26 | }; 27 | 28 | }// namespace refactor::computation 29 | 30 | #endif// COMPUTATION_COMPAIR_H 31 | 
-------------------------------------------------------------------------------- /src/05computation/include/computation/operators/concat.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_CONCAT_H 2 | #define COMPUTATION_CONCAT_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct Concat final : public AxisRankOperator { 9 | constexpr Concat(uint32_t axis, uint32_t rank) noexcept 10 | : AxisRankOperator(axis, rank) {} 11 | 12 | static size_t typeId() noexcept; 13 | size_t opTypeId() const noexcept final; 14 | std::string_view name() const noexcept final; 15 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 16 | }; 17 | 18 | }// namespace refactor::computation 19 | 20 | #endif// COMPUTATION_CONCAT_H 21 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/conv.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_CONV_H 2 | #define COMPUTATION_CONV_H 3 | 4 | #include "../operator.h" 5 | #include "kernel/attributes/pool_attributes.h" 6 | 7 | namespace refactor::computation { 8 | using kernel::PoolAttributes; 9 | 10 | struct Conv final : public Operator { 11 | PoolAttributes attributes; 12 | 13 | explicit Conv(PoolAttributes) noexcept; 14 | 15 | static size_t typeId() noexcept; 16 | size_t opTypeId() const noexcept final; 17 | std::string_view name() const noexcept final; 18 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 19 | std::string serialize() const noexcept final; 20 | }; 21 | 22 | }// namespace refactor::computation 23 | 24 | #endif// COMPUTATION_CONV_H 25 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/cum_sum.h: -------------------------------------------------------------------------------- 1 | #ifndef 
COMPUTATION_CUM_SUM_H 2 | #define COMPUTATION_CUM_SUM_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct CumSum final : public Operator { 9 | bool exclusive, reverse; 10 | 11 | constexpr CumSum(bool exclusive_, bool reverse_) noexcept 12 | : Operator(), exclusive(exclusive_), reverse(reverse_) {} 13 | 14 | static size_t typeId() noexcept; 15 | size_t opTypeId() const noexcept final; 16 | std::string_view name() const noexcept final; 17 | }; 18 | 19 | }// namespace refactor::computation 20 | 21 | #endif// COMPUTATION_CUM_SUM_H 22 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/dequantize_linear.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_DEQUANTIZE_LINEAR_H 2 | #define COMPUTATION_DEQUANTIZE_LINEAR_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct DequantizeLinear final : public Operator { 9 | 10 | constexpr DequantizeLinear() noexcept = default; 11 | 12 | static size_t typeId() noexcept; 13 | size_t opTypeId() const noexcept final; 14 | std::string_view name() const noexcept final; 15 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 16 | std::string serialize() const noexcept final; 17 | }; 18 | 19 | }// namespace refactor::computation 20 | 21 | #endif// COMPUTATION_DEQUANTIZE_LINEAR_H 22 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/dynamic_quantize_linear.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_DYNAMIC_QUANTIZE_LINEAR_H 2 | #define COMPUTATION_DYNAMIC_QUANTIZE_LINEAR_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct DynamicQuantizeLinear final : public Operator { 9 | 10 | constexpr DynamicQuantizeLinear() noexcept = 
default; 11 | 12 | static size_t typeId() noexcept; 13 | size_t opTypeId() const noexcept final; 14 | std::string_view name() const noexcept final; 15 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 16 | std::string serialize() const noexcept final; 17 | }; 18 | 19 | }// namespace refactor::computation 20 | 21 | #endif// COMPUTATION_DYNAMIC_QUANTIZE_LINEAR_H 22 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/gather.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_GATHER_H 2 | #define COMPUTATION_GATHER_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct Gather final : public AxisRankOperator { 9 | constexpr Gather(uint32_t axis, uint32_t rank) noexcept 10 | : AxisRankOperator(axis, rank) {} 11 | 12 | static size_t typeId() noexcept; 13 | size_t opTypeId() const noexcept final; 14 | std::string_view name() const noexcept final; 15 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 16 | }; 17 | 18 | }// namespace refactor::computation 19 | 20 | #endif// COMPUTATION_GATHER_H 21 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/gather_elements.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_GATHER_ELEMENTS_H 2 | #define COMPUTATION_GATHER_ELEMENTS_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct GatherElements final : public AxisRankOperator { 9 | constexpr GatherElements(uint32_t axis, uint32_t rank) noexcept 10 | : AxisRankOperator(axis, rank) {} 11 | 12 | static size_t typeId() noexcept; 13 | size_t opTypeId() const noexcept final; 14 | std::string_view name() const noexcept final; 15 | }; 16 | 17 | }// namespace refactor::computation 18 | 19 | #endif// 
COMPUTATION_GATHER_ELEMENTS_H 20 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/global_pool.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_GLOBAL_POOL_H 2 | #define COMPUTATION_GLOBAL_POOL_H 3 | 4 | #include "../operator.h" 5 | #include "kernel/collectors/global_pool.h" 6 | 7 | namespace refactor::computation { 8 | using kernel::PoolType; 9 | 10 | struct GlobalPool final : public Operator { 11 | PoolType type; 12 | 13 | constexpr GlobalPool(PoolType type_) noexcept 14 | : Operator(), type(type_) {} 15 | 16 | static size_t typeId(PoolType) noexcept; 17 | size_t opTypeId() const noexcept final; 18 | std::string_view name() const noexcept final; 19 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 20 | std::string serialize() const noexcept final; 21 | }; 22 | 23 | }// namespace refactor::computation 24 | 25 | #endif// COMPUTATION_GLOBAL_POOL_H 26 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/hard_sigmoid.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_HARD_SIGMOID_H 2 | #define COMPUTATION_HARD_SIGMOID_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct HardSigmoid final : public Operator { 9 | float alpha, beta; 10 | 11 | constexpr HardSigmoid(float alpha_, float beta_) noexcept 12 | : Operator(), alpha(alpha_), beta(beta_){}; 13 | 14 | static size_t typeId() noexcept; 15 | size_t opTypeId() const noexcept final; 16 | std::string_view name() const noexcept final; 17 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 18 | std::string serialize() const noexcept final; 19 | }; 20 | 21 | }// namespace refactor::computation 22 | 23 | #endif// COMPUTATION_HARD_SIGMOID_H 24 | 
-------------------------------------------------------------------------------- /src/05computation/include/computation/operators/identity.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_IDENTITY_H 2 | #define COMPUTATION_IDENTITY_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct Identity final : public Operator { 9 | 10 | constexpr Identity() noexcept = default; 11 | 12 | static size_t typeId() noexcept; 13 | size_t opTypeId() const noexcept final; 14 | std::string_view name() const noexcept final; 15 | bool isIdentity() const noexcept final; 16 | }; 17 | 18 | }// namespace refactor::computation 19 | 20 | #endif// COMPUTATION_IDENTITY_H 21 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/mat_mul.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_MAT_MUL_H 2 | #define COMPUTATION_MAT_MUL_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct MatMul final : public LayoutDependentOperator { 9 | float alpha, beta; 10 | bool transA, transB; 11 | 12 | constexpr MatMul(float alpha_, float beta_, bool transA_, bool transB_) noexcept 13 | : LayoutDependentOperator(), 14 | alpha(alpha_), 15 | beta(beta_), 16 | transA(transA_), 17 | transB(transB_) {} 18 | 19 | static size_t typeId() noexcept; 20 | size_t opTypeId() const noexcept final; 21 | std::string_view name() const noexcept final; 22 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 23 | std::string serialize() const noexcept final; 24 | }; 25 | 26 | }// namespace refactor::computation 27 | 28 | #endif// COMPUTATION_MAT_MUL_H 29 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/mat_mul_integer.h: 
-------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_MAT_MUL_INTEGER_H 2 | #define COMPUTATION_MAT_MUL_INTEGER_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct MatMulInteger final : public LayoutDependentOperator { 9 | 10 | constexpr MatMulInteger() noexcept = default; 11 | 12 | static size_t typeId() noexcept; 13 | size_t opTypeId() const noexcept final; 14 | std::string_view name() const noexcept final; 15 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 16 | std::string serialize() const noexcept final; 17 | }; 18 | 19 | }// namespace refactor::computation 20 | 21 | #endif// #ifndef COMPUTATION_MAT_MUL_INTEGER_H 22 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/pad.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_PAD_H 2 | #define COMPUTATION_PAD_H 3 | 4 | #include "../operator.h" 5 | #include "kernel/collectors/pad.h" 6 | 7 | namespace refactor::computation { 8 | using kernel::PadType; 9 | using Dimensions = kernel::PadDimension; 10 | 11 | struct Pad final : public LayoutDependentOperator { 12 | Dimensions dims; 13 | PadType mode; 14 | 15 | Pad(decltype(dims), PadType) noexcept; 16 | 17 | static size_t typeId() noexcept; 18 | size_t opTypeId() const noexcept final; 19 | std::string_view name() const noexcept final; 20 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 21 | std::string serialize() const noexcept final; 22 | }; 23 | 24 | }// namespace refactor::computation 25 | 26 | #endif// COMPUTATION_PAD_H 27 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/pool.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_POOL_H 2 | #define COMPUTATION_POOL_H 
3 | 4 | #include "../operator.h" 5 | #include "kernel/collectors/pool.h" 6 | 7 | namespace refactor::computation { 8 | using kernel::KernelShape; 9 | using kernel::PoolAttributes; 10 | using kernel::PoolType; 11 | 12 | struct Pool final : public Operator { 13 | PoolType type; 14 | bool ceil; 15 | KernelShape kernelShape; 16 | PoolAttributes attributes; 17 | 18 | Pool(PoolType, bool, decltype(kernelShape), PoolAttributes) noexcept; 19 | 20 | static size_t typeId(PoolType) noexcept; 21 | size_t opTypeId() const noexcept final; 22 | std::string_view name() const noexcept final; 23 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 24 | std::string serialize() const noexcept final; 25 | }; 26 | 27 | }// namespace refactor::computation 28 | 29 | #endif// COMPUTATION_POOL_H 30 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/reshape.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_RESHAPE_H 2 | #define COMPUTATION_RESHAPE_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct Reshape final : public LayoutDependentOperator { 9 | constexpr Reshape() noexcept : LayoutDependentOperator() {} 10 | 11 | static size_t typeId() noexcept; 12 | size_t opTypeId() const noexcept final; 13 | std::string_view name() const noexcept final; 14 | bool isIdentity() const noexcept final; 15 | }; 16 | 17 | }// namespace refactor::computation 18 | 19 | #endif// COMPUTATION_RESHAPE_H 20 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/rms_normalization.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_RMS_NORMALIZATION_H 2 | #define COMPUTATION_RMS_NORMALIZATION_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct 
RmsNormalization final : public Operator { 9 | float epsilon; 10 | 11 | constexpr explicit RmsNormalization(float epsilon_) noexcept 12 | : Operator(), epsilon(epsilon_) {} 13 | 14 | static size_t typeId() noexcept; 15 | size_t opTypeId() const noexcept final; 16 | std::string_view name() const noexcept final; 17 | kernel::CollectorBox candidateKernels(Target) const final; 18 | std::string serialize() const noexcept final; 19 | }; 20 | 21 | }// namespace refactor::computation 22 | 23 | #endif// COMPUTATION_RMS_NORMALIZATION_H 24 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/scatter_nd.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_SCATTER_ND_H 2 | #define COMPUTATION_SCATTER_ND_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct ScatterND final : public Operator { 9 | 10 | ScatterND() = default; 11 | 12 | static size_t typeId() noexcept; 13 | size_t opTypeId() const noexcept final; 14 | std::string_view name() const noexcept final; 15 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 16 | std::string serialize() const noexcept final; 17 | }; 18 | 19 | }// namespace refactor::computation 20 | 21 | #endif// COMPUTATION_SCATTER_ND_H 22 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/select.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_SELECT_H 2 | #define COMPUTATION_SELECT_H 3 | 4 | #include "../operator.h" 5 | #include "kernel/collectors/select.h" 6 | 7 | namespace refactor::computation { 8 | using kernel::SelectType; 9 | 10 | struct Select final : public Operator { 11 | SelectType type; 12 | 13 | constexpr explicit Select(SelectType type_) noexcept 14 | : Operator(), type(type_) {} 15 | 16 | static size_t 
typeId(SelectType) noexcept; 17 | size_t opTypeId() const noexcept final; 18 | std::string_view name() const noexcept final; 19 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 20 | }; 21 | 22 | }// namespace refactor::computation 23 | 24 | #endif// COMPUTATION_SELECT_H 25 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/simple_binary.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_SIMPLE_BINARY_H 2 | #define COMPUTATION_SIMPLE_BINARY_H 3 | 4 | #include "../operator.h" 5 | #include "kernel/collectors/simple_binary.h" 6 | 7 | namespace refactor::computation { 8 | using kernel::SimpleBinaryType; 9 | 10 | struct SimpleBinary final : public Operator { 11 | SimpleBinaryType type; 12 | 13 | constexpr explicit SimpleBinary(SimpleBinaryType type_) noexcept 14 | : Operator(), type(type_) {} 15 | 16 | static size_t typeId(SimpleBinaryType) noexcept; 17 | size_t opTypeId() const noexcept final; 18 | std::string_view name() const noexcept final; 19 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 20 | std::string serialize() const noexcept final; 21 | }; 22 | 23 | }// namespace refactor::computation 24 | 25 | #endif// COMPUTATION_SIMPLE_BINARY_H 26 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/simple_unary.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_SIMPLE_UNARY_H 2 | #define COMPUTATION_SIMPLE_UNARY_H 3 | 4 | #include "../operator.h" 5 | #include "kernel/collectors/simple_unary.h" 6 | 7 | namespace refactor::computation { 8 | using kernel::SimpleUnaryType; 9 | 10 | struct SimpleUnary final : public Operator { 11 | SimpleUnaryType type; 12 | 13 | constexpr explicit SimpleUnary(SimpleUnaryType type_) noexcept 14 | : Operator(), type(type_) {} 15 | 
16 | static size_t typeId(SimpleUnaryType) noexcept; 17 | size_t opTypeId() const noexcept final; 18 | std::string_view name() const noexcept final; 19 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 20 | std::string serialize() const noexcept final; 21 | }; 22 | 23 | }// namespace refactor::computation 24 | 25 | #endif// COMPUTATION_SIMPLE_UNARY_H 26 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/slice.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_SLICE_H 2 | #define COMPUTATION_SLICE_H 3 | 4 | #include "../operator.h" 5 | #include "kernel/collectors/slice.h" 6 | 7 | namespace refactor::computation { 8 | using Dimensions = kernel::Dimensions; 9 | 10 | struct Slice final : public LayoutDependentOperator { 11 | Dimensions dims; 12 | 13 | explicit Slice(Dimensions) noexcept; 14 | 15 | static size_t typeId() noexcept; 16 | size_t opTypeId() const noexcept final; 17 | std::string_view name() const noexcept final; 18 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 19 | std::string serialize() const noexcept final; 20 | }; 21 | 22 | }// namespace refactor::computation 23 | 24 | #endif// COMPUTATION_SLICE_H 25 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/softmax.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_SOFTMAX_H 2 | #define COMPUTATION_SOFTMAX_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct Softmax final : public AxisRankOperator { 9 | constexpr Softmax(uint32_t axis, uint32_t rank) noexcept 10 | : AxisRankOperator(axis, rank) {} 11 | 12 | static size_t typeId() noexcept; 13 | size_t opTypeId() const noexcept final; 14 | std::string_view name() const noexcept final; 15 | 
kernel::CollectorBox candidateKernels(Target) const noexcept final; 16 | }; 17 | 18 | }// namespace refactor::computation 19 | 20 | #endif// COMPUTATION_SOFTMAX_H 21 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/split.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_SPLIT_H 2 | #define COMPUTATION_SPLIT_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct Split final : public AxisRankOperator { 9 | constexpr Split(uint32_t axis, uint32_t rank) noexcept 10 | : AxisRankOperator(axis, rank) {} 11 | 12 | static size_t typeId() noexcept; 13 | size_t opTypeId() const noexcept final; 14 | std::string_view name() const noexcept final; 15 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 16 | }; 17 | 18 | }// namespace refactor::computation 19 | 20 | #endif// COMPUTATION_SPLIT_H 21 | -------------------------------------------------------------------------------- /src/05computation/include/computation/operators/where.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTATION_WHERE_H 2 | #define COMPUTATION_WHERE_H 3 | 4 | #include "../operator.h" 5 | 6 | namespace refactor::computation { 7 | 8 | struct Where final : public Operator { 9 | 10 | constexpr Where() noexcept = default; 11 | 12 | static size_t typeId() noexcept; 13 | size_t opTypeId() const noexcept final; 14 | std::string_view name() const noexcept final; 15 | kernel::CollectorBox candidateKernels(Target) const noexcept final; 16 | std::string serialize() const noexcept final; 17 | }; 18 | 19 | }// namespace refactor::computation 20 | 21 | #endif// COMPUTATION_WHERE_H 22 | -------------------------------------------------------------------------------- /src/05computation/src/operators/attention.cc: 
-------------------------------------------------------------------------------- 1 | #include "computation/operators/attention.h" 2 | 3 | namespace refactor::computation { 4 | using Op = Attention; 5 | 6 | auto Op::typeId() noexcept -> size_t { 7 | static uint8_t ID = 1; 8 | return reinterpret_cast(&ID); 9 | } 10 | auto Op::opTypeId() const noexcept -> size_t { return typeId(); } 11 | auto Op::name() const noexcept -> std::string_view { return "Attention"; } 12 | 13 | }// namespace refactor::computation 14 | -------------------------------------------------------------------------------- /src/05computation/src/operators/broadcast.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/broadcast.h" 2 | 3 | namespace refactor::computation { 4 | 5 | size_t Broadcast::typeId() noexcept { 6 | static uint8_t ID = 1; 7 | return reinterpret_cast(&ID); 8 | } 9 | size_t Broadcast::opTypeId() const noexcept { return typeId(); } 10 | std::string_view Broadcast::name() const noexcept { return "Broadcast"; } 11 | 12 | }// namespace refactor::computation 13 | -------------------------------------------------------------------------------- /src/05computation/src/operators/cast.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/cast.h" 2 | #include "kernel/collectors/cast.h" 3 | 4 | namespace refactor::computation { 5 | using Op = Cast; 6 | 7 | size_t Op::typeId() noexcept { 8 | static uint8_t ID = 1; 9 | return reinterpret_cast(&ID); 10 | } 11 | size_t Op::opTypeId() const noexcept { return typeId(); } 12 | std::string_view Op::name() const noexcept { return "Cast"; } 13 | auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { 14 | using Collector_ = kernel::CastCollector; 15 | return std::make_unique(target); 16 | } 17 | auto Op::serialize() const noexcept -> std::string { 18 | return "Cast()"; 19 | } 20 | 21 | }// 
namespace refactor::computation 22 | -------------------------------------------------------------------------------- /src/05computation/src/operators/clip.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/clip.h" 2 | #include "kernel/collectors/clip.h" 3 | 4 | namespace refactor::computation { 5 | using Op = Clip; 6 | 7 | auto Op::typeId() noexcept -> size_t { 8 | static uint8_t ID = 1; 9 | return reinterpret_cast(&ID); 10 | } 11 | auto Op::opTypeId() const noexcept -> size_t { return typeId(); } 12 | auto Op::name() const noexcept -> std::string_view { return "Clip"; } 13 | auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { 14 | using Collector_ = kernel::ClipCollector; 15 | return std::make_unique(target); 16 | } 17 | auto Op::serialize() const noexcept -> std::string { 18 | return "Clip()"; 19 | } 20 | 21 | }// namespace refactor::computation 22 | -------------------------------------------------------------------------------- /src/05computation/src/operators/concat.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/concat.h" 2 | #include "kernel/collectors/concat.h" 3 | 4 | namespace refactor::computation { 5 | using Op = Concat; 6 | 7 | auto Op::typeId() noexcept -> size_t { 8 | static uint8_t ID = 1; 9 | return reinterpret_cast(&ID); 10 | } 11 | auto Op::opTypeId() const noexcept -> size_t { return typeId(); } 12 | auto Op::name() const noexcept -> std::string_view { return "Concat"; } 13 | auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { 14 | using Collector_ = kernel::ConcatCollector; 15 | return std::make_unique(target, axis); 16 | } 17 | 18 | }// namespace refactor::computation 19 | -------------------------------------------------------------------------------- /src/05computation/src/operators/cum_sum.cc: 
-------------------------------------------------------------------------------- 1 | #include "computation/operators/cum_sum.h" 2 | 3 | namespace refactor::computation { 4 | 5 | size_t CumSum::typeId() noexcept { 6 | static uint8_t ID = 1; 7 | return reinterpret_cast(&ID); 8 | } 9 | size_t CumSum::opTypeId() const noexcept { return typeId(); } 10 | std::string_view CumSum::name() const noexcept { return "CumSum"; } 11 | 12 | }// namespace refactor::computation 13 | -------------------------------------------------------------------------------- /src/05computation/src/operators/dequantize_linear.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/dequantize_linear.h" 2 | #include "kernel/collectors/dequantize_linear.h" 3 | 4 | namespace refactor::computation { 5 | using Op = DequantizeLinear; 6 | 7 | size_t Op::typeId() noexcept { 8 | static uint8_t ID = 1; 9 | return reinterpret_cast(&ID); 10 | } 11 | size_t Op::opTypeId() const noexcept { return typeId(); } 12 | std::string_view Op::name() const noexcept { return "DequantizeLinear"; } 13 | auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { 14 | using Collector = kernel::DequantizeLinearCollector; 15 | return std::make_unique(target); 16 | } 17 | auto Op::serialize() const noexcept -> std::string { 18 | return "DequantizeLinear()"; 19 | } 20 | 21 | }// namespace refactor::computation 22 | -------------------------------------------------------------------------------- /src/05computation/src/operators/dynamic_quantize_linear.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/dynamic_quantize_linear.h" 2 | #include "kernel/collectors/dynamic_quantize_linear.h" 3 | 4 | namespace refactor::computation { 5 | using Op = DynamicQuantizeLinear; 6 | 7 | size_t Op::typeId() noexcept { 8 | static uint8_t ID = 1; 9 | return reinterpret_cast(&ID); 10 | } 11 | 
size_t Op::opTypeId() const noexcept { return typeId(); } 12 | std::string_view Op::name() const noexcept { return "DynamicQuantizeLinear"; } 13 | auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { 14 | using Collector = kernel::DynamicQuantizeLinearCollector; 15 | return std::make_unique(target); 16 | } 17 | auto Op::serialize() const noexcept -> std::string { 18 | return "DynamicQuantizeLinear()"; 19 | } 20 | 21 | }// namespace refactor::computation 22 | -------------------------------------------------------------------------------- /src/05computation/src/operators/gather.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/gather.h" 2 | #include "kernel/collectors/gather.h" 3 | 4 | namespace refactor::computation { 5 | using Op = Gather; 6 | 7 | auto Op::typeId() noexcept -> size_t { 8 | static uint8_t ID = 1; 9 | return reinterpret_cast(&ID); 10 | } 11 | auto Op::opTypeId() const noexcept -> size_t { return typeId(); } 12 | auto Op::name() const noexcept -> std::string_view { return "Gather"; } 13 | auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { 14 | using Collector_ = kernel::GatherCollector; 15 | return std::make_unique(target, axis); 16 | } 17 | 18 | }// namespace refactor::computation 19 | -------------------------------------------------------------------------------- /src/05computation/src/operators/gather_elements.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/gather_elements.h" 2 | 3 | namespace refactor::computation { 4 | 5 | size_t GatherElements::typeId() noexcept { 6 | static uint8_t ID = 1; 7 | return reinterpret_cast(&ID); 8 | } 9 | size_t GatherElements::opTypeId() const noexcept { return typeId(); } 10 | std::string_view GatherElements::name() const noexcept { return "GatherElements"; } 11 | 12 | }// namespace refactor::computation 13 | 
-------------------------------------------------------------------------------- /src/05computation/src/operators/hard_sigmoid.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/hard_sigmoid.h" 2 | #include "kernel/collectors/hard_sigmoid.h" 3 | 4 | namespace refactor::computation { 5 | using Op = HardSigmoid; 6 | 7 | auto Op::typeId() noexcept -> size_t { 8 | static uint8_t ID = 1; 9 | return reinterpret_cast(&ID); 10 | } 11 | auto Op::opTypeId() const noexcept -> size_t { return typeId(); } 12 | auto Op::name() const noexcept -> std::string_view { return "HardSigmoid"; } 13 | 14 | auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { 15 | using Collector_ = kernel::HardSigmoidCollector; 16 | return std::make_unique(target, alpha, beta); 17 | } 18 | auto Op::serialize() const noexcept -> std::string { 19 | return fmt::format("{}()", name()); 20 | } 21 | 22 | }// namespace refactor::computation 23 | 24 | -------------------------------------------------------------------------------- /src/05computation/src/operators/identity.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/identity.h" 2 | 3 | namespace refactor::computation { 4 | 5 | size_t Identity::typeId() noexcept { 6 | static uint8_t ID = 1; 7 | return reinterpret_cast(&ID); 8 | } 9 | size_t Identity::opTypeId() const noexcept { return typeId(); } 10 | std::string_view Identity::name() const noexcept { return "Identity"; } 11 | bool Identity::isIdentity() const noexcept { return true; } 12 | 13 | }// namespace refactor::computation 14 | -------------------------------------------------------------------------------- /src/05computation/src/operators/mat_mul_integer.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/mat_mul_integer.h" 2 | #include 
"kernel/collectors/mat_mul_integer.h" 3 | 4 | namespace refactor::computation { 5 | using Op = MatMulInteger; 6 | 7 | auto Op::typeId() noexcept -> size_t { 8 | static uint8_t ID = 1; 9 | return reinterpret_cast(&ID); 10 | } 11 | auto Op::opTypeId() const noexcept -> size_t { return typeId(); } 12 | auto Op::name() const noexcept -> std::string_view { return "MatMulInteger"; } 13 | auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { 14 | return std::make_unique(target); 15 | } 16 | auto Op::serialize() const noexcept -> std::string { 17 | return "MatMulInteger()"; 18 | } 19 | 20 | }// namespace refactor::computation 21 | -------------------------------------------------------------------------------- /src/05computation/src/operators/reshape.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/reshape.h" 2 | 3 | namespace refactor::computation { 4 | 5 | size_t Reshape::typeId() noexcept { 6 | static uint8_t ID = 1; 7 | return reinterpret_cast(&ID); 8 | } 9 | size_t Reshape::opTypeId() const noexcept { return typeId(); } 10 | std::string_view Reshape::name() const noexcept { return "Reshape"; } 11 | bool Reshape::isIdentity() const noexcept { return true; } 12 | 13 | }// namespace refactor::computation 14 | -------------------------------------------------------------------------------- /src/05computation/src/operators/scatter_nd.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/scatter_nd.h" 2 | #include "kernel/collectors/scatter_nd.h" 3 | 4 | namespace refactor::computation { 5 | using Op = ScatterND; 6 | 7 | auto Op::typeId() noexcept -> size_t { 8 | static uint8_t ID = 1; 9 | return reinterpret_cast(&ID); 10 | } 11 | auto Op::opTypeId() const noexcept -> size_t { return typeId(); } 12 | auto Op::name() const noexcept -> std::string_view { return "ScatterND"; } 13 | auto 
Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { 14 | using Collector_ = kernel::ScatterNDCollector; 15 | return std::make_unique(target); 16 | } 17 | auto Op::serialize() const noexcept -> std::string { 18 | return "ScatterND()"; 19 | } 20 | 21 | }// namespace refactor::computation 22 | -------------------------------------------------------------------------------- /src/05computation/src/operators/softmax.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/softmax.h" 2 | #include "kernel/collectors/softmax.h" 3 | 4 | namespace refactor::computation { 5 | 6 | size_t Softmax::typeId() noexcept { 7 | static uint8_t ID = 1; 8 | return reinterpret_cast(&ID); 9 | } 10 | size_t Softmax::opTypeId() const noexcept { return typeId(); } 11 | std::string_view Softmax::name() const noexcept { return "Softmax"; } 12 | auto Softmax::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { 13 | using Collector_ = kernel::SoftmaxCollector; 14 | return std::make_unique(target, axis); 15 | } 16 | 17 | }// namespace refactor::computation 18 | -------------------------------------------------------------------------------- /src/05computation/src/operators/split.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/split.h" 2 | #include "kernel/collectors/split.h" 3 | 4 | namespace refactor::computation { 5 | using Op = Split; 6 | 7 | auto Op::typeId() noexcept -> size_t { 8 | static uint8_t ID = 1; 9 | return reinterpret_cast(&ID); 10 | } 11 | auto Op::opTypeId() const noexcept -> size_t { return typeId(); } 12 | auto Op::name() const noexcept -> std::string_view { return "Split"; } 13 | auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { 14 | using Collector_ = kernel::SplitCollector; 15 | return std::make_unique(target, axis); 16 | } 17 | 18 | }// namespace 
refactor::computation 19 | -------------------------------------------------------------------------------- /src/05computation/src/operators/where.cc: -------------------------------------------------------------------------------- 1 | #include "computation/operators/where.h" 2 | #include "kernel/collectors/where.h" 3 | 4 | namespace refactor::computation { 5 | using Op = Where; 6 | 7 | size_t Op::typeId() noexcept { 8 | static uint8_t ID = 1; 9 | return reinterpret_cast(&ID); 10 | } 11 | size_t Op::opTypeId() const noexcept { return typeId(); } 12 | std::string_view Op::name() const noexcept { return "Where"; } 13 | auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { 14 | using Collector_ = kernel::WhereCollector; 15 | return std::make_unique(target); 16 | } 17 | auto Op::serialize() const noexcept -> std::string { 18 | return "Where()"; 19 | } 20 | 21 | }// namespace refactor::computation 22 | -------------------------------------------------------------------------------- /src/06frontend/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12 FATAL_ERROR) 2 | project(frontend VERSION 0.0.0 LANGUAGES CXX) 3 | message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION}) 4 | 5 | file(GLOB_RECURSE FRONTEND_SRC src/*.cc src/*.cpp) 6 | add_library(frontend STATIC ${FRONTEND_SRC}) 7 | target_link_libraries(frontend PUBLIC computation) 8 | target_include_directories(frontend PUBLIC include) 9 | 10 | file(GLOB_RECURSE FRONTEND_TEST test/*.cpp) 11 | if(FRONTEND_TEST) 12 | add_executable(frontend_test ${FRONTEND_TEST}) 13 | add_test(frontend_test frontend_test) 14 | target_link_libraries(frontend_test frontend GTest::gtest_main Backward::Object) 15 | endif() 16 | -------------------------------------------------------------------------------- /src/06frontend/README.md: -------------------------------------------------------------------------------- 1 
| # 前端图表示 2 | 3 | 前端图表示的功能是消除动态性,包括 3 种来源的动态性: 4 | 5 | 1. 来自变量形状的动态性; 6 | 2. 来自形状旁路的动态性,形状旁路指的是拓扑通向 Reshape、Expand、Squeeze、Unsqueeze 和 Slice 的子图; 7 | 3. 来自形状占位符的动态性,即 Reshape 中的 0、-1 和其他算子中常见的倒数轴表示; 8 | -------------------------------------------------------------------------------- /src/07onnx/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12 FATAL_ERROR) 2 | project(onnx VERSION 0.0.0 LANGUAGES CXX) 3 | message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION}) 4 | 5 | file(GLOB_RECURSE ONNX_SRC src/*.cc src/*.cpp) 6 | add_library(onnx STATIC ${ONNX_SRC}) 7 | target_link_libraries(onnx PUBLIC frontend) 8 | target_include_directories(onnx PUBLIC include) 9 | 10 | file(GLOB_RECURSE ONNX_TEST test/*.cpp) 11 | if(ONNX_TEST) 12 | add_executable(onnx_test ${ONNX_TEST}) 13 | add_test(onnx_test onnx_test) 14 | target_link_libraries(onnx_test onnx GTest::gtest_main Backward::Object) 15 | endif() 16 | -------------------------------------------------------------------------------- /src/07onnx/README.md: -------------------------------------------------------------------------------- 1 |  -------------------------------------------------------------------------------- /src/07onnx/include/onnx/operators.h: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_OPERATORS_H 2 | #define ONNX_OPERATORS_H 3 | 4 | namespace refactor::onnx { 5 | 6 | void register_(); 7 | 8 | }// namespace refactor::onnx 9 | 10 | #endif// ONNX_OPERATORS_H 11 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/batch_normalization.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_BATCH_NORMALIZATION_HH 2 | #define ONNX_BATCH_NORMALIZATION_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | 
struct BatchNormalization final : public Operator { 10 | bool trainingMode; 11 | float epsilon; 12 | 13 | BatchNormalization(bool, float); 14 | 15 | static OpBox build(ModelContext const &, std::string_view, Attributes); 16 | static size_t typeId(); 17 | 18 | size_t opTypeId() const final; 19 | std::string_view opTypeName() const final; 20 | InferResult infer(TensorRefs, InferOptions const &) const final; 21 | computation::OpBox lower(TensorRefs) const final; 22 | }; 23 | 24 | }// namespace refactor::onnx 25 | 26 | #endif// ONNX_BATCH_NORMALIZATION_HH 27 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/cast.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_CAST_HH 2 | #define ONNX_CAST_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Cast final : public Operator { 10 | DataType to; 11 | 12 | explicit Cast(DataType); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_CAST_HH 26 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/clip.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_CLIP_HH 2 | #define ONNX_CLIP_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Clip final : public Operator { 10 | 11 | Clip() = default; 12 | 13 | static OpBox build(ModelContext const &, std::string_view, Attributes); 14 | static size_t typeId(); 15 | 16 | size_t opTypeId() const final; 
17 | std::string_view opTypeName() const final; 18 | InferResult infer(TensorRefs, InferOptions const &) const final; 19 | computation::OpBox lower(TensorRefs) const final; 20 | }; 21 | 22 | }// namespace refactor::onnx 23 | 24 | #endif// ONNX_CLIP_HH 25 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/common.h: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_INFER_H 2 | #define ONNX_INFER_H 3 | 4 | #include "common.h" 5 | #include "frontend/operator.h" 6 | #include 7 | 8 | namespace refactor::onnx { 9 | using namespace frontend; 10 | 11 | using OptionalInts = std::optional; 12 | using OptionalIntsRef = std::optional>; 13 | 14 | constexpr Int StandardOpsetVersion = 18; 15 | 16 | /// @brief 池化形状推断。 17 | /// @param data 输入张量的形状。 18 | /// @param kernel kernel 的形状。 19 | /// @param dilations 空洞参数。 20 | /// @param pads 扩张参数。 21 | /// @param strides 跳步参数。 22 | /// @return 池化后的形状。 23 | ShapeResult pool(SmallInts<4> const &data, 24 | Ints const &kernel, 25 | OptionalIntsRef const &dilations, 26 | OptionalIntsRef const &pads, 27 | OptionalIntsRef const &strides); 28 | 29 | }// namespace refactor::onnx 30 | 31 | #endif// ONNX_INFER_H 32 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/compair.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_COMPAIR_HH 2 | #define ONNX_COMPAIR_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | enum class CompairType { 10 | EQ, 11 | GT, 12 | GE, 13 | LT, 14 | LE, 15 | }; 16 | 17 | struct Compair final : public Operator { 18 | CompairType type; 19 | 20 | explicit Compair(CompairType); 21 | 22 | static OpBox build(ModelContext const &, std::string_view, Attributes); 23 | static size_t typeId(CompairType); 24 | 25 | size_t opTypeId() const final; 26 | 
std::string_view opTypeName() const final; 27 | InferResult infer(TensorRefs, InferOptions const &) const final; 28 | computation::OpBox lower(TensorRefs) const final; 29 | }; 30 | 31 | }// namespace refactor::onnx 32 | 33 | #endif// ONNX_COMPAIR_HH 34 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/concat.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_CONCAT_HH 2 | #define ONNX_CONCAT_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Concat final : public Operator { 10 | Int axis; 11 | 12 | explicit Concat(int64_t); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_CONCAT_HH 26 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/constant.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_CONSTANT_HH 2 | #define ONNX_CONSTANT_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Constant final : public Operator { 10 | Attribute value; 11 | 12 | explicit Constant(Attribute); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | }; 21 | 22 | }// namespace refactor::onnx 23 | 24 | #endif// ONNX_CONSTANT_HH 25 | 
-------------------------------------------------------------------------------- /src/07onnx/src/operators/constant_of_shape.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_CONSTANT_OF_SHAPE_HH 2 | #define ONNX_CONSTANT_OF_SHAPE_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct ConstantOfShape final : public Operator { 10 | Tensor_ value; 11 | 12 | explicit ConstantOfShape(Tensor_); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InputVec valueDependentInputs() const final; 20 | InferResult infer(TensorRefs, InferOptions const &) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_CONSTANT_OF_SHAPE_HH 26 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/conv.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_CONV_HH 2 | #define ONNX_CONV_HH 3 | 4 | #include "common.h" 5 | #include "frontend/operator.h" 6 | 7 | namespace refactor::onnx { 8 | using namespace frontend; 9 | 10 | struct Conv final : public Operator { 11 | OptionalInts dilations, pads, strides; 12 | 13 | Conv(OptionalInts dilations, 14 | OptionalInts pads, 15 | OptionalInts strides); 16 | 17 | static OpBox build(ModelContext const &, std::string_view, Attributes); 18 | static size_t typeId(); 19 | 20 | size_t opTypeId() const final; 21 | std::string_view opTypeName() const final; 22 | InferResult infer(TensorRefs, InferOptions const &) const final; 23 | computation::OpBox lower(TensorRefs) const final; 24 | }; 25 | 26 | }// namespace refactor::onnx 27 | 28 | #endif// ONNX_CONV_HH 29 | -------------------------------------------------------------------------------- 
/src/07onnx/src/operators/cum_sum.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_CUM_SUM_HH 2 | #define ONNX_CUM_SUM_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct CumSum final : public Operator { 10 | bool exclusive, reverse; 11 | 12 | CumSum(bool exclusive, bool reverse); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_CUM_SUM_HH 26 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/dequantize_linear.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_DEQUANTIZE_LINEAR_HH 2 | #define ONNX_DEQUANTIZE_LINEAR_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct DequantizeLinear final : public Operator { 10 | Int axis; 11 | 12 | explicit DequantizeLinear(Int) noexcept; 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_DEQUANTIZE_LINEAR_HH 26 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/dynamic_quantize_linear.hh: -------------------------------------------------------------------------------- 1 | #ifndef 
ONNX_DYNAMIC_QUANTIZE_LINEAR_HH 2 | #define ONNX_DYNAMIC_QUANTIZE_LINEAR_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct DynamicQuantizeLinear final : public Operator { 10 | 11 | DynamicQuantizeLinear() = default; 12 | 13 | static OpBox build(ModelContext const &, std::string_view, Attributes); 14 | static size_t typeId(); 15 | 16 | size_t opTypeId() const final; 17 | std::string_view opTypeName() const final; 18 | InferResult infer(TensorRefs, InferOptions const &) const final; 19 | computation::OpBox lower(TensorRefs) const final; 20 | }; 21 | 22 | }// namespace refactor::onnx 23 | 24 | #endif// ONNX_DYNAMIC_QUANTIZE_LINEAR_HH 25 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/einsum.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_EINSUM_HH 2 | #define ONNX_EINSUM_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Einsum final : public Operator { 10 | std::string equation; 11 | 12 | Einsum(std::string); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_EINSUM_HH 26 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/expand.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_EXPAND_HH 2 | #define ONNX_EXPAND_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Expand final : public Operator { 10 | 
11 | constexpr Expand() noexcept = default; 12 | 13 | static OpBox build(ModelContext const &, std::string_view, Attributes); 14 | static size_t typeId(); 15 | 16 | size_t opTypeId() const final; 17 | std::string_view opTypeName() const final; 18 | InputVec valueDependentInputs() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_EXPAND_HH 26 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/flatten.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_FLATTEN_HH 2 | #define ONNX_FLATTEN_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Flatten final : public Operator { 10 | Int axis; 11 | 12 | explicit Flatten(Int); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InputVec valueDependentInputs() const final; 20 | InferResult infer(TensorRefs, InferOptions const &) const final; 21 | computation::OpBox lower(TensorRefs) const final; 22 | }; 23 | 24 | }// namespace refactor::onnx 25 | 26 | #endif// ONNX_FLATTEN_HH 27 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/gather.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_GATHER_HH 2 | #define ONNX_GATHER_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Gather final : public Operator { 10 | Int axis; 11 | 12 | explicit Gather(Int); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | 
size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_GATHER_HH 26 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/gather_elements.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_GATHER_ELEMENTS_HH 2 | #define ONNX_GATHER_ELEMENTS_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct GatherElements final : public Operator { 10 | Int axis; 11 | 12 | explicit GatherElements(Int); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_GATHER_ELEMENTS_HH 26 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/gemm.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_GEMM_HH 2 | #define ONNX_GEMM_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Gemm final : public Operator { 10 | Float alpha, beta; 11 | bool transA, transB; 12 | 13 | Gemm(Float, Float, bool, bool); 14 | 15 | static OpBox build(ModelContext const &, std::string_view, Attributes); 16 | static size_t typeId(); 17 | 18 | size_t opTypeId() const final; 19 | std::string_view opTypeName() const final; 20 | InferResult infer(TensorRefs, InferOptions const &) const final; 21 | computation::OpBox 
lower(TensorRefs) const final; 22 | }; 23 | 24 | }// namespace refactor::onnx 25 | 26 | #endif// ONNX_GEMM_HH 27 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/global_pool.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_GLOBAL_POOL_HH 2 | #define ONNX_GLOBAL_POOL_HH 3 | 4 | #include "frontend/operator.h" 5 | #include "pool_type.h" 6 | 7 | namespace refactor::onnx { 8 | using namespace frontend; 9 | 10 | struct GlobalPool final : public Operator { 11 | PoolType type; 12 | 13 | explicit GlobalPool(PoolType); 14 | 15 | static OpBox build(ModelContext const &, std::string_view, Attributes); 16 | static size_t typeId(PoolType); 17 | 18 | size_t opTypeId() const final; 19 | std::string_view opTypeName() const final; 20 | InferResult infer(TensorRefs, InferOptions const &) const final; 21 | computation::OpBox lower(TensorRefs) const final; 22 | }; 23 | 24 | }// namespace refactor::onnx 25 | 26 | #endif// ONNX_GLOBAL_POOL_HH 27 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/hard_sigmoid.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_HARD_SIGMOID_HH 2 | #define ONNX_HARD_SIGMOID_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct HardSigmoid final : public Operator { 10 | Float alpha, beta; 11 | 12 | explicit HardSigmoid(Float, Float); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_HARD_SIGMOID_HH 26 | 
-------------------------------------------------------------------------------- /src/07onnx/src/operators/mat_mul.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_MAT_MUL_HH 2 | #define ONNX_MAT_MUL_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct MatMul final : public Operator { 10 | 11 | constexpr MatMul() noexcept = default; 12 | 13 | static OpBox build(ModelContext const &, std::string_view, Attributes); 14 | static size_t typeId(); 15 | 16 | size_t opTypeId() const final; 17 | std::string_view opTypeName() const final; 18 | InferResult infer(TensorRefs, InferOptions const &) const final; 19 | computation::OpBox lower(TensorRefs) const final; 20 | }; 21 | 22 | }// namespace refactor::onnx 23 | 24 | #endif// ONNX_MAT_MUL_HH 25 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/mat_mul_integer.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_MAT_MUL_INTEGER_HH 2 | #define ONNX_MAT_MUL_INTEGER_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct MatMulInteger final : public Operator { 10 | 11 | constexpr MatMulInteger() noexcept = default; 12 | 13 | static OpBox build(ModelContext const &, std::string_view, Attributes); 14 | static size_t typeId(); 15 | 16 | size_t opTypeId() const final; 17 | std::string_view opTypeName() const final; 18 | InferResult infer(TensorRefs, InferOptions const &) const final; 19 | computation::OpBox lower(TensorRefs) const final; 20 | }; 21 | 22 | }// namespace refactor::onnx 23 | 24 | #endif// ONNX_MAT_MUL_INTEGER_HH 25 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/pad.hh: -------------------------------------------------------------------------------- 1 | #ifndef 
ONNX_PAD_HH 2 | #define ONNX_PAD_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | enum class PadMode { 10 | Constant, 11 | Reflect, 12 | Edge, 13 | Wrap, 14 | }; 15 | 16 | struct Pad final : public Operator { 17 | PadMode mode; 18 | 19 | Pad(PadMode); 20 | 21 | static OpBox build(ModelContext const &, std::string_view, Attributes); 22 | static size_t typeId(); 23 | size_t opTypeId() const final; 24 | std::string_view opTypeName() const final; 25 | InferResult infer(TensorRefs, InferOptions const &) const final; 26 | computation::OpBox lower(TensorRefs) const final; 27 | }; 28 | 29 | }// namespace refactor::onnx 30 | 31 | #endif// ONNX_PAD_HH 32 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/pool_type.h: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_POOL_TYPE_H 2 | #define ONNX_POOL_TYPE_H 3 | 4 | namespace refactor::onnx { 5 | 6 | enum class PoolType { 7 | Average, 8 | Lp, 9 | Max, 10 | }; 11 | 12 | }// namespace refactor::onnx 13 | 14 | #endif// ONNX_POOL_TYPE_H 15 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/range.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_RANGE_HH 2 | #define ONNX_RANGE_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Range final : public Operator { 10 | 11 | constexpr Range() noexcept = default; 12 | 13 | static OpBox build(ModelContext const &, std::string_view, Attributes); 14 | static size_t typeId(); 15 | 16 | size_t opTypeId() const final; 17 | std::string_view opTypeName() const final; 18 | InputVec valueDependentInputs() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | }; 21 | 22 | }// namespace refactor::onnx 23 | 24 | #endif// 
ONNX_RANGE_HH 25 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/reshape.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_RESHAPE_HH 2 | #define ONNX_RESHAPE_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Reshape final : public Operator { 10 | bool allowzero; 11 | 12 | explicit Reshape(bool); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InputVec valueDependentInputs() const final; 20 | InferResult infer(TensorRefs, InferOptions const &) const final; 21 | computation::OpBox lower(TensorRefs) const final; 22 | }; 23 | 24 | }// namespace refactor::onnx 25 | 26 | #endif// ONNX_RESHAPE_HH 27 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/scatter_nd.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_SCATTER_ND_HH 2 | #define ONNX_SCATTER_ND_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct ScatterND final : public Operator { 10 | 11 | ScatterND() = default; 12 | 13 | static OpBox build(ModelContext const &, std::string_view, Attributes); 14 | static size_t typeId(); 15 | 16 | size_t opTypeId() const final; 17 | std::string_view opTypeName() const final; 18 | InferResult infer(TensorRefs, InferOptions const &) const final; 19 | computation::OpBox lower(TensorRefs) const final; 20 | }; 21 | 22 | }// namespace refactor::onnx 23 | 24 | #endif// ONNX_SCATTER_ND_HH 25 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/select.hh: 
-------------------------------------------------------------------------------- 1 | #ifndef ONNX_SELECT_HH 2 | #define ONNX_SELECT_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | enum class SelectType { 10 | Max, 11 | Min, 12 | }; 13 | 14 | struct Select final : public Operator { 15 | SelectType type; 16 | 17 | explicit Select(SelectType); 18 | 19 | static OpBox build(ModelContext const &, std::string_view, Attributes); 20 | static size_t typeId(SelectType); 21 | 22 | size_t opTypeId() const final; 23 | std::string_view opTypeName() const final; 24 | InferResult infer(TensorRefs, InferOptions const &) const final; 25 | computation::OpBox lower(TensorRefs) const final; 26 | }; 27 | 28 | }// namespace refactor::onnx 29 | 30 | #endif// ONNX_SELECT_HH 31 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/shape.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_SHAPE_HH 2 | #define ONNX_SHAPE_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Shape final : public Operator { 10 | Int start; 11 | std::optional end; 12 | 13 | Shape(Int, std::optional); 14 | 15 | static OpBox build(ModelContext const &, std::string_view, Attributes); 16 | static size_t typeId(); 17 | 18 | size_t opTypeId() const final; 19 | std::string_view opTypeName() const final; 20 | InferResult infer(TensorRefs, InferOptions const &) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_SHAPE_HH 26 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/slice.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_SLICE_HH 2 | #define ONNX_SLICE_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using 
namespace frontend; 8 | 9 | struct Slice final : public Operator { 10 | 11 | constexpr Slice() noexcept = default; 12 | 13 | static OpBox build(ModelContext const &, std::string_view, Attributes); 14 | static size_t typeId(); 15 | 16 | size_t opTypeId() const final; 17 | std::string_view opTypeName() const final; 18 | InputVec valueDependentInputs() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_SLICE_HH 26 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/softmax.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_SOFTMAX_HH 2 | #define ONNX_SOFTMAX_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Softmax final : public Operator { 10 | Int axis; 11 | 12 | explicit Softmax(Int); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_SOFTMAX_HH 26 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/split.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_SPLIT_HH 2 | #define ONNX_SPLIT_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Split final : public Operator { 10 | Int axis, numOutputs; 11 | 12 | Split(Int, Int); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t 
typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InputVec valueDependentInputs() const final; 20 | InferResult infer(TensorRefs, InferOptions const &) const final; 21 | computation::OpBox lower(TensorRefs) const final; 22 | }; 23 | 24 | }// namespace refactor::onnx 25 | 26 | #endif// ONNX_SPLIT_HH 27 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/squeeze.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_SQUEEZE_HH 2 | #define ONNX_SQUEEZE_HH 3 | 4 | #include "frontend/operator.h" 5 | #include <optional> 6 | 7 | namespace refactor::onnx { 8 | using namespace frontend; 9 | 10 | struct Squeeze final : public Operator { 11 | std::optional<std::vector<Int>> axes; 12 | 13 | explicit Squeeze(decltype(axes)); 14 | 15 | static OpBox build(ModelContext const &, std::string_view, Attributes); 16 | static size_t typeId(); 17 | 18 | size_t opTypeId() const final; 19 | std::string_view opTypeName() const final; 20 | InputVec valueDependentInputs() const final; 21 | InferResult infer(TensorRefs, InferOptions const &) const final; 22 | computation::OpBox lower(TensorRefs) const final; 23 | }; 24 | 25 | }// namespace refactor::onnx 26 | 27 | #endif// ONNX_SQUEEZE_HH 28 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/tile.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_TILE_HH 2 | #define ONNX_TILE_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Tile final : public Operator { 10 | 11 | constexpr Tile() noexcept = default; 12 | 13 | static OpBox build(ModelContext const &, std::string_view, Attributes); 14 | static size_t typeId(); 15 | 16 | size_t opTypeId() const final; 17 | std::string_view opTypeName() const final; 18 | InputVec valueDependentInputs() 
const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_TILE_HH 26 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/transpose.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_TRANSPOSE_HH 2 | #define ONNX_TRANSPOSE_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Transpose final : public Operator { 10 | Ints perm; 11 | 12 | explicit Transpose(Ints); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::onnx 24 | 25 | #endif// ONNX_TRANSPOSE_HH 26 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/unsqueeze.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_UNSQUEEZE_HH 2 | #define ONNX_UNSQUEEZE_HH 3 | 4 | #include "frontend/operator.h" 5 | #include 6 | 7 | namespace refactor::onnx { 8 | using namespace frontend; 9 | 10 | struct Unsqueeze final : public Operator { 11 | std::optional axes; 12 | 13 | explicit Unsqueeze(decltype(axes)); 14 | 15 | static OpBox build(ModelContext const &, std::string_view, Attributes); 16 | static size_t typeId(); 17 | 18 | size_t opTypeId() const final; 19 | std::string_view opTypeName() const final; 20 | InputVec valueDependentInputs() const final; 21 | InferResult infer(TensorRefs, InferOptions const &) const final; 22 | computation::OpBox lower(TensorRefs) const final; 23 | }; 24 | 25 | 
}// namespace refactor::onnx 26 | 27 | #endif// ONNX_UNSQUEEZE_HH 28 | -------------------------------------------------------------------------------- /src/07onnx/src/operators/where.hh: -------------------------------------------------------------------------------- 1 | #ifndef ONNX_WHERE_HH 2 | #define ONNX_WHERE_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::onnx { 7 | using namespace frontend; 8 | 9 | struct Where final : public Operator { 10 | 11 | constexpr Where() noexcept = default; 12 | 13 | static OpBox build(ModelContext const &, std::string_view, Attributes); 14 | static size_t typeId(); 15 | 16 | size_t opTypeId() const final; 17 | std::string_view opTypeName() const final; 18 | InferResult infer(TensorRefs, InferOptions const &) const final; 19 | computation::OpBox lower(TensorRefs) const final; 20 | }; 21 | 22 | }// namespace refactor::onnx 23 | 24 | #endif// ONNX_WHERE_HH 25 | -------------------------------------------------------------------------------- /src/07onnx/test/test_clip.cpp: -------------------------------------------------------------------------------- 1 | #include "../src/operators/clip.hh" 2 | #include "onnx/operators.h" 3 | #include 4 | 5 | using namespace refactor; 6 | using namespace onnx; 7 | 8 | TEST(infer, Clip) { 9 | onnx::register_(); 10 | auto edges = Edges{ 11 | {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3)}, {}), ""}, 12 | }; 13 | count_t inputs[]{0}; 14 | auto infered = Clip().infer(TensorRefs(edges, inputs), {true}); 15 | ASSERT_TRUE(infered.isOk()); 16 | auto outputs = std::move(infered.unwrap()); 17 | ASSERT_EQ(outputs.size(), 1); 18 | auto y = std::move(outputs[0]); 19 | ASSERT_EQ(y->dataType, DataType::F32); 20 | ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)})); 21 | } 22 | -------------------------------------------------------------------------------- /src/07onnx/test/test_concat.cpp: -------------------------------------------------------------------------------- 1 | 
#include "../src/operators/concat.hh" 2 | #include "onnx/operators.h" 3 | #include 4 | 5 | using namespace refactor; 6 | using namespace frontend; 7 | using namespace onnx; 8 | 9 | TEST(infer, Concat) { 10 | onnx::register_(); 11 | auto edges = Edges{ 12 | {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3)}, {}), ""}, 13 | {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(2)}, {}), ""}, 14 | {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(5)}, {}), ""}, 15 | }; 16 | count_t inputs[]{0, 1, 2}; 17 | auto infered = Concat(1).infer(TensorRefs(edges, inputs), {true}); 18 | ASSERT_TRUE(infered.isOk()); 19 | auto outputs = std::move(infered.unwrap()); 20 | ASSERT_EQ(outputs.size(), 1); 21 | auto y = std::move(outputs[0]); 22 | ASSERT_EQ(y->dataType, DataType::F32); 23 | ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(10)})); 24 | } 25 | -------------------------------------------------------------------------------- /src/07onnx/test/test_einsum.cpp: -------------------------------------------------------------------------------- 1 | #include "../src/operators/einsum.hh" 2 | #include "onnx/operators.h" 3 | #include 4 | 5 | using namespace refactor; 6 | using namespace frontend; 7 | using namespace onnx; 8 | 9 | TEST(infer, Einsum) { 10 | onnx::register_(); 11 | auto edges = Edges{ 12 | {Tensor::share(DataType::F32, Shape{DimExpr(3), DimExpr(2), DimExpr(5)}, {}), ""}, 13 | {Tensor::share(DataType::F32, Shape{DimExpr(3), DimExpr(5), DimExpr(4)}, {}), ""}, 14 | }; 15 | count_t inputs[]{0, 1}; 16 | auto infered = Einsum("bik,bkj->bij").infer(TensorRefs(edges, inputs), {true}); 17 | ASSERT_TRUE(infered.isOk()); 18 | auto outputs = std::move(infered.unwrap()); 19 | ASSERT_EQ(outputs.size(), 1); 20 | auto y = std::move(outputs[0]); 21 | ASSERT_EQ(y->dataType, DataType::F32); 22 | ASSERT_EQ(y->shape, (Shape{DimExpr(3), DimExpr(2), DimExpr(4)})); 23 | } 24 | -------------------------------------------------------------------------------- 
/src/07onnx/test/test_hard_sigmoid.cpp: -------------------------------------------------------------------------------- 1 | #include "../src/operators/hard_sigmoid.hh" 2 | #include "onnx/operators.h" 3 | #include 4 | 5 | using namespace refactor; 6 | using namespace onnx; 7 | 8 | TEST(infer, HardSigmoid) { 9 | onnx::register_(); 10 | 11 | auto edges = Edges{ 12 | {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3)}, {}), ""}, 13 | }; 14 | count_t inputs[]{0}; 15 | auto infered = HardSigmoid(0.2f, 0.5f).infer(TensorRefs(edges, inputs), {false}); 16 | ASSERT_TRUE(infered.isOk()); 17 | auto outputs = std::move(infered.unwrap()); 18 | ASSERT_EQ(outputs.size(), 1); 19 | auto y = std::move(outputs[0]); 20 | ASSERT_EQ(y->dataType, DataType::F32); 21 | ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)})); 22 | } 23 | 24 | -------------------------------------------------------------------------------- /src/07onnx/test/test_unsqueeze.cpp: -------------------------------------------------------------------------------- 1 | #include "../src/operators/unsqueeze.hh" 2 | #include "onnx/operators.h" 3 | #include 4 | 5 | using namespace refactor; 6 | using namespace frontend; 7 | using namespace onnx; 8 | 9 | TEST(infer, Unsqueeze) { 10 | onnx::register_(); 11 | auto edges = Edges{ 12 | {Tensor::share(DataType::F32, Shape{DimExpr(3), DimExpr(5)}, {}), ""}, 13 | {Tensor::share(DataType::I64, Shape{DimExpr(2)}, {}), ""}, 14 | }; 15 | int64_t axes[]{2, 0}; 16 | std::memcpy(edges[1].tensor->malloc(), axes, sizeof(axes)); 17 | count_t inputs[]{0, 1}; 18 | auto infered = Unsqueeze(std::nullopt).infer(TensorRefs(edges, inputs), {true}); 19 | ASSERT_TRUE(infered.isOk()); 20 | auto outputs = std::move(infered.unwrap()); 21 | ASSERT_EQ(outputs.size(), 1); 22 | auto y = std::move(outputs[0]); 23 | ASSERT_EQ(y->dataType, DataType::F32); 24 | ASSERT_EQ(y->shape, (Shape{DimExpr(1), DimExpr(3), DimExpr(1), DimExpr(5)})); 25 | } 26 | 
-------------------------------------------------------------------------------- /src/08-01llm/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12 FATAL_ERROR) 2 | project(llm VERSION 0.0.0 LANGUAGES CXX) 3 | message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION}) 4 | 5 | file(GLOB_RECURSE LLM_SRC src/*.cc src/*.cpp) 6 | add_library(llm STATIC ${LLM_SRC}) 7 | target_link_libraries(llm PUBLIC frontend) 8 | target_include_directories(llm PUBLIC include) 9 | 10 | file(GLOB_RECURSE LLM_TEST test/*.cpp) 11 | if(LLM_TEST) 12 | add_executable(llm_test ${LLM_TEST}) 13 | add_test(llm_test llm_test) 14 | target_link_libraries(llm_test llm GTest::gtest_main Backward::Object) 15 | endif() 16 | -------------------------------------------------------------------------------- /src/08-01llm/include/llm/operators.h: -------------------------------------------------------------------------------- 1 | #ifndef LLM_OPERATORS_H 2 | #define LLM_OPERATORS_H 3 | 4 | namespace refactor::llm { 5 | 6 | void register_(); 7 | 8 | }// namespace refactor::llm 9 | 10 | #endif// LLM_OPERATORS_H 11 | -------------------------------------------------------------------------------- /src/08-01llm/src/operators.cpp: -------------------------------------------------------------------------------- 1 | #include "llm/operators.h" 2 | #include "operators/attention.hh" 3 | #include "operators/mat_mul.hh" 4 | #include "operators/rms_normalization.hh" 5 | 6 | namespace refactor::llm { 7 | using namespace frontend; 8 | 9 | void register_() { 10 | #define REGISTER(NAME, CLASS) Operator::register_("llm::" #NAME) 11 | // clang-format off 12 | REGISTER(Attention , Attention ); 13 | REGISTER(RmsNormalization, RmsNormalization); 14 | REGISTER(MatMul , MatMul ); 15 | // clang-format on 16 | #undef REGISTER 17 | } 18 | 19 | }// namespace refactor::llm 20 | 
-------------------------------------------------------------------------------- /src/08-01llm/src/operators/attention.hh: -------------------------------------------------------------------------------- 1 | #ifndef LLM_RMS_ATTENTION_HH 2 | #define LLM_RMS_ATTENTION_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::llm { 7 | using namespace frontend; 8 | 9 | struct Attention final : public Operator { 10 | dim_t maxSeqLen; 11 | 12 | explicit Attention(decltype(maxSeqLen)); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::llm 24 | 25 | #endif// LLM_RMS_ATTENTION_HH 26 | -------------------------------------------------------------------------------- /src/08-01llm/src/operators/common.h: -------------------------------------------------------------------------------- 1 | #ifndef LLM_COMMON_H 2 | #define LLM_COMMON_H 3 | 4 | #include "common.h" 5 | 6 | #define EXPECT_SIZE(N) \ 7 | if (inputs.size() != (N)) { \ 8 | return Err(InferError(ERROR_MSG("Input size error"))); \ 9 | } 10 | 11 | #define EXPECT_VAL(DIM, VAL) \ 12 | int64_t VAL; \ 13 | if ((DIM).hasValue()) { \ 14 | VAL = (DIM).value(); \ 15 | } else { \ 16 | return Err(InferError(UnknownVariable{(DIM.variable()->name)})); \ 17 | } 18 | 19 | #endif// LLM_COMMON_H 20 | -------------------------------------------------------------------------------- /src/08-01llm/src/operators/mat_mul.hh: -------------------------------------------------------------------------------- 1 | #ifndef LLM_MAT_MUL_HH 2 | #define LLM_MAT_MUL_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::llm { 7 | using namespace frontend; 8 | 9 | struct MatMul final : public Operator { 10 | bool 
transA, transB; 11 | 12 | MatMul(decltype(transA), decltype(transB)); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::llm 24 | 25 | #endif// LLM_MAT_MUL_HH 26 | -------------------------------------------------------------------------------- /src/08-01llm/src/operators/rms_normalization.hh: -------------------------------------------------------------------------------- 1 | #ifndef LLM_RMS_NORMALIZATION_HH 2 | #define LLM_RMS_NORMALIZATION_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::llm { 7 | using namespace frontend; 8 | 9 | struct RmsNormalization final : public Operator { 10 | float epsilon; 11 | 12 | explicit RmsNormalization(decltype(epsilon)); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | computation::OpBox lower(TensorRefs) const final; 21 | }; 22 | 23 | }// namespace refactor::llm 24 | 25 | #endif// LLM_RMS_NORMALIZATION_HH 26 | -------------------------------------------------------------------------------- /src/08-01llm/test/test_rms_normalization.cpp: -------------------------------------------------------------------------------- 1 | #include "../src/operators/rms_normalization.hh" 2 | #include "llm/operators.h" 3 | #include 4 | 5 | using namespace refactor; 6 | using namespace llm; 7 | 8 | TEST(infer, RmsNormalization) { 9 | llm::register_(); 10 | auto edges = Edges{ 11 | {Tensor::share(DataType::F32, Shape{DimExpr(7), DimExpr(2), DimExpr(3)}, {}), ""}, 12 | {Tensor::share(DataType::F32, 
Shape{DimExpr(3)}, {}), ""}, 13 | }; 14 | count_t inputs[]{0, 1}; 15 | auto infered = RmsNormalization(1e-6).infer(TensorRefs(edges, inputs), {true}); 16 | ASSERT_TRUE(infered.isOk()); 17 | auto outputs = std::move(infered.unwrap()); 18 | ASSERT_EQ(outputs.size(), 1); 19 | auto y = std::move(outputs[0]); 20 | ASSERT_EQ(y->dataType, DataType::F32); 21 | ASSERT_EQ(y->shape, (Shape{DimExpr(7), DimExpr(2), DimExpr(3)})); 22 | } 23 | -------------------------------------------------------------------------------- /src/08communication/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12 FATAL_ERROR) 2 | project(communication VERSION 0.0.0 LANGUAGES CXX) 3 | message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION}) 4 | 5 | file(GLOB_RECURSE COMMUNICATION_SRC src/*.cc src/*.cpp) 6 | add_library(communication STATIC ${COMMUNICATION_SRC}) 7 | target_link_libraries(communication PUBLIC frontend) 8 | target_include_directories(communication PUBLIC include) 9 | 10 | file(GLOB_RECURSE COMMUNICATION_TEST test/*.cpp) 11 | if(COMMUNICATION_TEST) 12 | add_executable(communication_test ${COMMUNICATION_TEST}) 13 | add_test(communication_test communication_test) 14 | target_link_libraries(communication_test communication GTest::gtest_main Backward::Object) 15 | endif() 16 | -------------------------------------------------------------------------------- /src/08communication/README.md: -------------------------------------------------------------------------------- 1 |  -------------------------------------------------------------------------------- /src/08communication/include/communication/operators.h: -------------------------------------------------------------------------------- 1 | #ifndef COMMUNICATION_OPERATORS_H 2 | #define COMMUNICATION_OPERATORS_H 3 | 4 | namespace refactor::communication { 5 | 6 | void register_(); 7 | 8 | }// namespace refactor::communication 9 | 10 | 
#endif// COMMUNICATION_OPERATORS_H 11 | -------------------------------------------------------------------------------- /src/08communication/src/operators.cpp: -------------------------------------------------------------------------------- 1 | #include "communication/operators.h" 2 | #include "operators/all_gather.hh" 3 | #include "operators/all_reduce.hh" 4 | 5 | namespace refactor::communication { 6 | using namespace frontend; 7 | 8 | void register_() { 9 | #define REGISTER(NAME, CLASS) Operator::register_("onnx::" #NAME) 10 | // clang-format off 11 | REGISTER(AllReduceAvg , AllReduce); 12 | REGISTER(AllReduceSum , AllReduce); 13 | REGISTER(AllReduceMin , AllReduce); 14 | REGISTER(AllReduceMax , AllReduce); 15 | REGISTER(AllReduceProd, AllReduce); 16 | REGISTER(AllGather , AllGather); 17 | // clang-format on 18 | #undef REGISTER 19 | } 20 | 21 | }// namespace refactor::communication 22 | -------------------------------------------------------------------------------- /src/08communication/src/operators/all_gather.hh: -------------------------------------------------------------------------------- 1 | #ifndef COMMUNICATION_ALL_GATHER_HH 2 | #define COMMUNICATION_ALL_GATHER_HH 3 | 4 | #include "frontend/operator.h" 5 | 6 | namespace refactor::communication { 7 | using namespace frontend; 8 | 9 | struct AllGather final : public Operator { 10 | Int nranks; 11 | 12 | explicit AllGather(Int); 13 | 14 | static OpBox build(ModelContext const &, std::string_view, Attributes); 15 | static size_t typeId(); 16 | 17 | size_t opTypeId() const final; 18 | std::string_view opTypeName() const final; 19 | InferResult infer(TensorRefs, InferOptions const &) const final; 20 | }; 21 | 22 | }// namespace refactor::communication 23 | 24 | #endif// COMMUNICATION_ALL_GATHER_HH 25 | -------------------------------------------------------------------------------- /src/08communication/src/operators/all_reduce.hh: 
-------------------------------------------------------------------------------- 1 | #ifndef COMMUNICATION_ALL_REDUCE_HH 2 | #define COMMUNICATION_ALL_REDUCE_HH 3 | 4 | #include "frontend/operator.h" 5 | #include "kernel/attributes/communication.h" 6 | 7 | namespace refactor::communication { 8 | using namespace frontend; 9 | 10 | struct AllReduce final : public Operator { 11 | kernel::AllReduceType type; 12 | 13 | AllReduce(kernel::AllReduceType); 14 | 15 | static OpBox build(ModelContext const &, std::string_view, Attributes); 16 | static size_t typeId(kernel::AllReduceType); 17 | 18 | size_t opTypeId() const final; 19 | std::string_view opTypeName() const final; 20 | InferResult infer(TensorRefs, InferOptions const &) const final; 21 | computation::OpBox lower(TensorRefs) const final; 22 | }; 23 | 24 | }// namespace refactor::communication 25 | 26 | #endif// COMMUNICATION_ALL_REDUCE_HH 27 | -------------------------------------------------------------------------------- /src/08communication/src/operators/common.h: -------------------------------------------------------------------------------- 1 | #ifndef COMMUNICATION_COMMON_H 2 | #define COMMUNICATION_COMMON_H 3 | 4 | #include "common.h" 5 | 6 | #define EXPECT_SIZE(N) \ 7 | if (inputs.size() != (N)) { \ 8 | return Err(InferError(ERROR_MSG("Input size error"))); \ 9 | } 10 | 11 | #define EXPECT_VAL(DIM, VAL) \ 12 | int64_t VAL; \ 13 | if ((DIM).hasValue()) { \ 14 | VAL = (DIM).value(); \ 15 | } else { \ 16 | return Err(InferError(UnknownVariable{(DIM.variable()->name)})); \ 17 | } 18 | 19 | #endif// COMMUNICATION_COMMON_H 20 | -------------------------------------------------------------------------------- /src/09python_ffi/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12 FATAL_ERROR) 2 | project(python_ffi VERSION 0.0.0 LANGUAGES CXX) 3 | message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION}) 4 | 5 | 
include_directories(pybind11/include) 6 | add_subdirectory(pybind11) 7 | 8 | file(GLOB_RECURSE PYFFI_SRC src/*.cc src/*.cpp) 9 | pybind11_add_module(python_ffi SHARED ${PYFFI_SRC}) 10 | target_link_libraries(python_ffi PRIVATE onnx llm communication) 11 | target_include_directories(python_ffi PRIVATE include) 12 | 13 | # EXAMPLE_VERSION_INFO is defined by setup.py and passed into the C++ code as a 14 | # define (VERSION_INFO) here. 15 | # target_compile_definitions(python_ffi 16 | # PRIVATE VERSION_INFO=${EXAMPLE_VERSION_INFO}) 17 | -------------------------------------------------------------------------------- /src/09python_ffi/README.md: -------------------------------------------------------------------------------- 1 | # Python FFI with pybind11 2 | 3 | Using [pybind11 2.11.1](https://github.com/pybind/pybind11/releases/tag/v2.11.1). 4 | 5 | See for project example. 6 | 7 | ## pass 控制符号 8 | 9 | - `ce`: constant elimination,前端常量折叠; 10 | - `lp`: layout permutation,张量布局从 NCHW 变换到 NHWC; 11 | -------------------------------------------------------------------------------- /src/09python_ffi/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "refactor_graph" 7 | version = "0.0.0" 8 | authors = [{ name = "YdrMaster", email = "ydrml@hotmail.com" }] 9 | description = "Python frontend of RefactorGraph" 10 | readme = "README.md" 11 | requires-python = ">=3.7" 12 | keywords = ["ai-compiler"] 13 | license = { text = "Apache" } 14 | classifiers = ["Programming Language :: Python :: 3"] 15 | dependencies = ["onnx"] 16 | 17 | [tool.setuptools.packages.find] 18 | where = ["src"] 19 | 20 | [tool.setuptools.package-data] 21 | pyinfinitensor = ["*.so"] 22 | -------------------------------------------------------------------------------- /src/09python_ffi/src/functions.h: 
-------------------------------------------------------------------------------- 1 | #ifndef PYTHON_FFI_FUNCTIONS_H 2 | #define PYTHON_FFI_FUNCTIONS_H 3 | 4 | #include "common.h" 5 | #include "frontend/tensor.h" 6 | #include 7 | 8 | namespace refactor::python_ffi { 9 | using DimVec = std::vector>; 10 | 11 | DataType parseNumpyDType(pybind11::dtype const &); 12 | pybind11::dtype buildNumpyDType(DataType); 13 | frontend::Shape dimVec2Shape(DimVec const &); 14 | 15 | }// namespace refactor::python_ffi 16 | 17 | #endif// PYTHON_FFI_FUNCTIONS_H 18 | -------------------------------------------------------------------------------- /src/09python_ffi/src/refactor_graph/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.extend(__path__) 4 | 5 | import python_ffi 6 | 7 | print("import backend: ", python_ffi) 8 | -------------------------------------------------------------------------------- /utilities/.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [alias] 2 | project = "run --release --" 3 | -------------------------------------------------------------------------------- /utilities/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /utilities/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "utilities" 3 | version = "0.0.0" 4 | edition = "2021" 5 | authors = ["YdrMaster "] 6 | 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | 9 | [dependencies] 10 | clap = { version = "4.4", features = ["derive"] } 11 | --------------------------------------------------------------------------------