├── .cargo └── config.toml ├── .gitattributes ├── .github └── workflows │ └── ci.yml ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── README.md ├── crates ├── altius_py │ ├── .gitignore │ ├── Cargo.toml │ ├── altius_py │ │ └── __init__.py │ ├── bert.py │ ├── deeplab.py │ ├── deit.py │ ├── export-bert.sh │ ├── export-fugumt.sh │ ├── export-gpt2.sh │ ├── export-tinystories.sh │ ├── export_vit.py │ ├── fastvit.py │ ├── fcn.py │ ├── fuse_attn.py │ ├── gpt2.py │ ├── mandelbrot.py │ ├── mobilenet.py │ ├── pyproject.toml │ ├── real-esrgan.py │ ├── resnet50.py │ ├── show-ort-profile.py │ ├── src │ │ └── lib.rs │ ├── test.sh │ ├── tests │ │ ├── test.rs │ │ ├── test_ops_bin.py │ │ ├── test_ops_concat.py │ │ ├── test_ops_conv.py │ │ ├── test_ops_elemwise.py │ │ ├── test_ops_gather.py │ │ ├── test_ops_gemm.py │ │ ├── test_ops_matmul.py │ │ ├── test_ops_norm.py │ │ ├── test_ops_pool.py │ │ ├── test_ops_reduce.py │ │ ├── test_ops_resize.py │ │ ├── test_ops_transpose.py │ │ └── test_ops_where.py │ ├── translation.py │ ├── uv.lock │ ├── vit.py │ └── yolov5.py ├── core │ ├── Cargo.toml │ ├── build.rs │ └── src │ │ ├── analysis │ │ ├── mod.rs │ │ └── shape.rs │ │ ├── dim.rs │ │ ├── fixed_dim.rs │ │ ├── flops.rs │ │ ├── graph.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ ├── node.rs │ │ ├── onnx │ │ ├── load.rs │ │ ├── mod.rs │ │ ├── onnx.proto │ │ └── save.rs │ │ ├── op.rs │ │ ├── optimize │ │ ├── conv_act_fusion.rs │ │ ├── elemwise_fusion.rs │ │ ├── fast_gelu_fusion.rs │ │ ├── gelu_fusion.rs │ │ ├── identity_elim.rs │ │ ├── layer_norm_fusion.rs │ │ ├── mod.rs │ │ └── transpose_fusion.rs │ │ ├── snapshots │ │ ├── altius_core__model__mnist_model.snap │ │ ├── altius_core__tensor__dump_bool_tensor.snap │ │ ├── altius_core__tensor__dump_f32_tensor.snap │ │ ├── altius_core__tensor__dump_i32_tensor.snap │ │ └── altius_core__tensor__dump_i64_tensor.snap │ │ ├── tensor.rs │ │ └── value.rs ├── session │ ├── Cargo.toml │ ├── src │ │ ├── lib.rs │ │ └── plan.rs │ └── tests │ │ └── ort.rs 
├── session_clang │ ├── Cargo.toml │ ├── examples │ │ ├── deit_cpu.rs │ │ ├── mnist_cpu.rs │ │ ├── mobilenet_cpu.rs │ │ └── vit_cpu.rs │ ├── src │ │ ├── builder.rs │ │ ├── lib.rs │ │ ├── session.rs │ │ └── translator.rs │ └── tests │ │ ├── ops_bin.rs │ │ └── ops_conv.rs ├── session_interpreter │ ├── Cargo.toml │ ├── benches │ │ └── interpreter.rs │ ├── examples │ │ ├── deit.rs │ │ ├── infer.rs │ │ ├── mnist.rs │ │ ├── mobilenet.rs │ │ └── vit.rs │ ├── src │ │ ├── builder.rs │ │ ├── conv2d.rs │ │ ├── fast_math.rs │ │ ├── gemm.rs │ │ ├── lib.rs │ │ ├── session.rs │ │ └── thread.rs │ └── tests │ │ ├── mobilenet.rs │ │ └── op_bin.rs └── wasm │ ├── .gitignore │ ├── Cargo.toml │ ├── package.json │ ├── src │ ├── index.tsx │ └── lib.rs │ ├── static │ ├── index.css │ └── index.html │ ├── tsconfig.json │ ├── webpack.config.ts │ └── yarn.lock ├── models ├── MNIST_test.txt ├── download.sh └── imagenet_classes.txt ├── rust-toolchain.toml └── snippets ├── coreml ├── mobilenet.py └── requirements.txt ├── cuda ├── Makefile ├── cuda-gemm-act.cu └── main.c ├── float.c ├── onnx_float16.py ├── q.cc ├── sgemm ├── .gitignore ├── Makefile ├── gemm.deit.cc └── main.cc ├── softmax.c ├── softmax.cc └── test_nchwc.py /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | rustflags = ["-Ctarget-cpu=native"] 3 | 4 | [env] 5 | RUST_LOG = "debug" 6 | CC = "clang" 7 | CXX = "clang++" 8 | GOMP_CPU_AFFINITY='0-7' 9 | MACOSX_DEPLOYMENT_TARGET = "14.0" 10 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | core/examples/MNIST_test.txt filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 
[ '*' ] 6 | pull_request: 7 | branches: [ '*' ] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | env: 14 | CARGO_TERM_COLOR: always 15 | 16 | jobs: 17 | Linux: 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v4 21 | - uses: actions/cache@v4 22 | with: 23 | path: | 24 | ~/.cargo/registry 25 | ~/.cargo/git 26 | target 27 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 28 | - name: Download large files 29 | working-directory: ./models 30 | run: ./download.sh CI 31 | - name: Add llvm-tools-preview 32 | run: rustup component add llvm-tools-preview 33 | - name: Install grcov 34 | run: cargo install grcov 35 | - name: Install dependencies 36 | run: sudo apt install libomp-dev xz-utils 37 | - name: Install uv 38 | uses: astral-sh/setup-uv@v6 39 | - name: Free up disk 40 | run: sudo rm -rf /usr/local/lib/android || true 41 | - name: Test 42 | run: | 43 | cargo test --release 44 | ALTIUS_ENABLE_CLIF=1 cargo test --release 45 | env: 46 | RUSTFLAGS: -Cinstrument-coverage 47 | LLVM_PROFILE_FILE: coverage-%p-%m.profraw 48 | - name: Run examples 49 | run: | 50 | (cd crates/altius_py && uv run python deit.py) 51 | # (cd crates/altius_py && uv run python resnet50.py) 52 | (cd crates/altius_py && uv run python export_vit.py) 53 | cargo run --release --example mnist 54 | cargo run --release --example mobilenet 55 | cargo run --release --example deit 56 | cargo run --release --example mnist_cpu 57 | cargo run --release --example mobilenet_cpu 58 | cargo run --release --example deit_cpu 59 | cargo run --release --example vit_cpu 60 | cargo run --release --example vit 61 | cargo run --release --example infer -- ./models/mnist-8.onnx 62 | env: 63 | RUSTFLAGS: -Cinstrument-coverage 64 | LLVM_PROFILE_FILE: coverage-%p-%m.profraw 65 | - name: Submit coverage 66 | run: | 67 | mkdir -p /tmp/cov/ 68 | cp -rf ./target/release/* /tmp/cov/ 69 | grcov . --binary-path /tmp/cov/ -s . 
-t cobertura --branch --ignore-not-existing --ignore "*cargo*" -o coverage.xml 70 | bash <(curl -s https://codecov.io/bash) 71 | env: 72 | RUSTFLAGS: -Cinstrument-coverage 73 | LLVM_PROFILE_FILE: coverage-%p-%m.profraw 74 | 75 | macOS: 76 | runs-on: macos-14 77 | steps: 78 | - uses: actions/checkout@v4 79 | - uses: actions/cache@v4 80 | with: 81 | path: | 82 | ~/.cargo/registry 83 | ~/.cargo/git 84 | target 85 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 86 | - name: Download large files 87 | working-directory: ./models 88 | run: ./download.sh CI 89 | - name: Install dependencies 90 | run: | 91 | brew install llvm libomp 92 | echo "PATH=$(brew --prefix llvm)/bin:${PATH}" >> $GITHUB_ENV 93 | echo "CPPFLAGS=-I$(brew --prefix libomp)/include" >> $GITHUB_ENV 94 | echo "LDFLAGS=-L$(brew --prefix libomp)/lib" >> $GITHUB_ENV 95 | - name: Install uv 96 | uses: astral-sh/setup-uv@v6 97 | - name: Setup Python environment 98 | working-directory: ./crates/altius_py 99 | run: uv sync 100 | - name: Test 101 | run: | 102 | cargo test --release 103 | ALTIUS_ENABLE_CLIF=1 cargo test --release 104 | env: 105 | RUSTFLAGS: "-C target-cpu=apple-m1" 106 | PYO3_PYTHON: ${{ github.workspace }}/crates/altius_py/.venv/bin/python 107 | - name: Run examples 108 | run: | 109 | (cd crates/altius_py && uv run python deit.py) 110 | (cd crates/altius_py && uv run python export_vit.py) 111 | cargo run --release --example mnist 112 | cargo run --release --example mobilenet 113 | cargo run --release --example deit 114 | cargo run --release --example mnist_cpu 115 | cargo run --release --example mobilenet_cpu 116 | cargo run --release --example deit_cpu 117 | cargo run --release --example vit_cpu 118 | cargo run --release --example vit 119 | cargo run --release --example infer -- ./models/mnist-8.onnx 120 | env: 121 | RUSTFLAGS: "-C target-cpu=apple-m1" 122 | PYO3_PYTHON: ${{ github.workspace }}/crates/altius_py/.venv/bin/python 123 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /pkg 3 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "crates/core", 4 | "crates/session", 5 | "crates/session_clang", 6 | "crates/session_interpreter", 7 | "crates/altius_py", 8 | "crates/wasm" 9 | ] 10 | resolver = "2" 11 | 12 | [workspace.dependencies] 13 | thiserror = "^1.0.31" 14 | log = "^0.4.17" 15 | rustc-hash = "^1.1.0" 16 | cranelift = "^0.111.0" 17 | cranelift-module = "^0.111.0" 18 | cranelift-object = "^0.111.0" 19 | cranelift-codegen = "^0.111.0" 20 | ndarray = "^0.15.6" 21 | 22 | [profile.release] 23 | opt-level = 3 24 | overflow-checks = false 25 | codegen-units = 8 26 | debug = true 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 uint256_t 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

Altius

3 | 4 | CI 5 | 6 | 7 | Coverage 8 | 9 |
10 | Small ONNX inference runtime written in Rust. 11 |
12 | Feel free to create 13 | 14 | issues 15 | 16 | and 17 | 18 | discussions! 19 | 20 |
21 | 22 | # Requirements 23 | 24 | - cargo 25 | - uv 26 | 27 | # Run 28 | 29 | ```sh 30 | # Download models. 31 | (cd models && ./download.sh) 32 | # Download minimum models. 33 | # (cd models && ./download.sh CI) 34 | 35 | # Run examples. 36 | # {mnist, mobilenet, deit, vit} are available. 37 | # You can specify the number of threads for computation by editing the code. 38 | cargo run --release --example mnist 39 | cargo run --release --example mobilenet 40 | cargo run --release --example deit 41 | cargo run --release --example vit 42 | 43 | # Experimental CPU backend (generates C code) 44 | cargo run --release --example mnist_cpu -- --iters 10 45 | cargo run --release --example mobilenet_cpu -- --iters 10 --profile 46 | cargo run --release --example deit_cpu -- --iters 10 --threads 8 --profile 47 | ``` 48 | 49 | # Run from WebAssembly 50 | 51 | Currently, mobilenet v3 runs on web browsers. 52 | 53 | ```sh 54 | cd crates/wasm 55 | cargo install wasm-pack 56 | wasm-pack build --target web 57 | yarn 58 | yarn serve 59 | ``` 60 | 61 | # Run from Python 62 | 63 | ```sh 64 | cd ./crates/altius_py 65 | uv sync 66 | uv run maturin develop -r 67 | uv run python mobilenet.py 68 | ``` 69 | -------------------------------------------------------------------------------- /crates/altius_py/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | .pytest_cache/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | .venv/ 14 | env/ 15 | bin/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | include/ 26 | man/ 27 | venv/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | pip-selfcheck.json 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 |
.cache 42 | nosetests.xml 43 | coverage.xml 44 | 45 | # Translations 46 | *.mo 47 | 48 | # Mr Developer 49 | .mr.developer.cfg 50 | .project 51 | .pydevproject 52 | 53 | # Rope 54 | .ropeproject 55 | 56 | # Django stuff: 57 | *.log 58 | *.pot 59 | 60 | .DS_Store 61 | 62 | # Sphinx documentation 63 | docs/_build/ 64 | 65 | # PyCharm 66 | .idea/ 67 | 68 | # VSCode 69 | .vscode/ 70 | 71 | # Pyenv 72 | .python-version 73 | 74 | # Environments 75 | .env 76 | .venv 77 | env/ 78 | venv/ 79 | ENV/ 80 | env.bak/ 81 | venv.bak/ 82 | -------------------------------------------------------------------------------- /crates/altius_py/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "altius_py" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [lib] 7 | name = "altius_py" 8 | crate-type = ["cdylib"] 9 | test = false 10 | 11 | [dependencies] 12 | altius-core = { path = "../core" } 13 | altius_session = { path = "../session" } 14 | altius_session_clang = { path = "../session_clang" } 15 | altius_session_interpreter = { path = "../session_interpreter" } 16 | pyo3 = { version = "^0.20.0", features = ["extension-module"] } 17 | pyo3-log = "^0.9.0" 18 | numpy = "^0.20.0" 19 | 20 | [dev-dependencies] 21 | cargo-util = "^0.2.1" 22 | 23 | [features] 24 | default = ["cblas"] 25 | matrixmultiply-threading = [ "altius_session_interpreter/matrixmultiply-threading" ] 26 | cuda = [ "altius_session_interpreter/cuda" ] 27 | heavy-log = [ "altius_session_interpreter/heavy-log" ] 28 | cblas = [] 29 | -------------------------------------------------------------------------------- /crates/altius_py/altius_py/__init__.py: -------------------------------------------------------------------------------- 1 | from .altius_py import load, session 2 | 3 | 4 | class InferenceSession: 5 | """ 6 | ``InferenceSession`` is the class used to run a model. 
7 | """ 8 | 9 | def __init__(self, model_path, enable_profile=False, intra_op_num_threads=1, backend="interpreter"): 10 | self.model_path = model_path 11 | self.model = load(model_path) 12 | self.session = session(self.model, enable_profile, intra_op_num_threads, backend) 13 | 14 | def run(self, output, input): 15 | """ 16 | Compute the predictions. 17 | 18 | Args: 19 | output (Optional[list[str]]): Name of the outputs, but must be None for now. 20 | input (dict[str, numpy.ndarray]): Dictionary ``{ input_name: input_value }``. 21 | 22 | Returns: 23 | list[numpy.ndarray]: Output values. 24 | """ 25 | 26 | assert output is None 27 | return self.session.run(input) 28 | -------------------------------------------------------------------------------- /crates/altius_py/bert.py: -------------------------------------------------------------------------------- 1 | # python -m transformers.onnx --model=bert-base-cased --feature=masked-lm ./a 2 | 3 | import time 4 | import logging 5 | import os 6 | import sys 7 | 8 | from transformers import AutoTokenizer, BertTokenizer 9 | import onnxruntime as ort 10 | import numpy as np 11 | import altius_py 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | 15 | tokenizer = BertTokenizer.from_pretrained("bert-base-cased", mask_token="[MASK]") 16 | 17 | if not os.path.exists("../../models/bert.onnx"): 18 | print("Run ../../models/download.sh to download ../../models/bert.onnx") 19 | sys.exit(0) 20 | 21 | # session = ort.InferenceSession("../../models/bert.onnx") 22 | session = altius_py.InferenceSession( 23 | "../../models/bert.onnx", intra_op_num_threads=8, enable_profile=True 24 | ) 25 | 26 | # msg = "Paris is the [MASK] city of France" 27 | # msg = "Deep [MASK] network has been widely used" 28 | msg = "We usually use a [MASK] to input characters to a computer" 29 | # msg = "The number [MASK] is famous as the ultimate answer of everything" 30 | mask_pos = msg.split().index("[MASK]") + 1 31 | print(f"Masked sentence (up to 20 tokens): 
{msg}") 32 | 33 | inputs = tokenizer(msg, return_tensors="np") 34 | for name in ["input_ids", "attention_mask", "token_type_ids"]: 35 | input = np.zeros((1, 20), dtype=np.int64) 36 | input[0, : inputs[name].shape[1]] = inputs[name] 37 | inputs[name] = input 38 | 39 | repeat = 10 # TIPS: First run is usually slow. 40 | for _ in range(repeat): 41 | start = time.time() 42 | outputs = session.run(None, dict(inputs)) 43 | end = time.time() 44 | print(f"Inference time: {end - start}") 45 | 46 | ids = np.argsort(-outputs[0][0, mask_pos])[:5] 47 | for i, tok in enumerate(tokenizer.convert_ids_to_tokens(ids.tolist())): 48 | print(f"Top{i+1}: {msg.replace('[MASK]', tok.upper())}") 49 | -------------------------------------------------------------------------------- /crates/altius_py/deeplab.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import random 4 | import logging 5 | from itertools import cycle 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision.transforms.functional import to_pil_image 10 | from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights 11 | from torchvision import transforms 12 | from matplotlib import colors as mcolors 13 | from PIL import Image 14 | 15 | import onnxruntime as ort 16 | import altius_py 17 | 18 | 19 | def main(): 20 | logging.basicConfig(level=logging.INFO) 21 | 22 | path = "../../models/cat.png" 23 | image = Image.open(path).resize((520, 520)) 24 | 25 | weights = FCN_ResNet50_Weights.DEFAULT 26 | preprocess = weights.transforms() 27 | input = np.ascontiguousarray((preprocess(image).unsqueeze(0))) 28 | 29 | # sess_options = ort.SessionOptions() 30 | # # sess_options.intra_op_num_threads = 1 31 | # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL 32 | # sess = ort.InferenceSession( 33 | # "../../models/deeplab_mobilenetv3.onnx", sess_options=sess_options 34 | # ) 35 | sess = altius_py.InferenceSession( 
36 | "../../models/deeplab_mobilenetv3.onnx", enable_profile=True 37 | ) 38 | 39 | inputs = {"input.1": input} 40 | 41 | start = time.time() 42 | output = sess.run(None, inputs)[0] 43 | print(f"Inference elapsed: {time.time() - start}") 44 | 45 | prediction = torch.tensor(output) 46 | normalized_masks = prediction.softmax(dim=1) 47 | class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])} 48 | colors = cycle(mcolors.BASE_COLORS.values()) 49 | color_like = lambda input: [torch.full_like(input, c) for c in next(colors)] 50 | 51 | for klass, idx in class_to_idx.items(): 52 | if klass == "__background__": 53 | continue 54 | 55 | mask = normalized_masks[0, idx] 56 | if torch.max(mask) < 0.2: 57 | # No objects of this class 58 | continue 59 | 60 | mask_img = to_pil_image( 61 | torch.stack(color_like(mask) + [mask * 0.5]), 62 | mode="RGBA", 63 | ) 64 | image = Image.alpha_composite(image.convert("RGBA"), mask_img) 65 | 66 | image.save("masked.png") 67 | image.show() 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | -------------------------------------------------------------------------------- /crates/altius_py/deit.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | import os 4 | 5 | import onnx 6 | import altius_py 7 | import torch 8 | 9 | from PIL import Image 10 | from torchvision import transforms 11 | 12 | 13 | def main(): 14 | logging.basicConfig(level=logging.INFO) 15 | 16 | image = Image.open("../../models/cat.png") 17 | labels = open("../../models/imagenet_classes.txt").readlines() 18 | 19 | preprocess = transforms.Compose( 20 | [ 21 | transforms.Resize(224), 22 | transforms.ToTensor(), 23 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 24 | ] 25 | ) 26 | input = preprocess(image) 27 | input = input.unsqueeze(0).numpy() 28 | 29 | onnx_path = "../../models/deit.onnx" 30 | 31 | if not os.path.exists(onnx_path): 32 | import onnxsim 33 | 
from transformers import ViTImageProcessor, ViTForImageClassification 34 | 35 | model = ViTForImageClassification.from_pretrained( 36 | "facebook/deit-small-patch16-224" 37 | ) 38 | torch.onnx.export(model, torch.randn(1, 3, 224, 224), onnx_path) 39 | simplified_model, success = onnxsim.simplify(onnx_path) 40 | assert success 41 | onnx.save(simplified_model, onnx_path) 42 | 43 | altius_model = altius_py.InferenceSession( 44 | onnx_path, intra_op_num_threads=1, enable_profile=True 45 | ) 46 | 47 | with torch.no_grad(): 48 | for i in range(1): 49 | output = altius_model.run(None, {"input.1": input}) 50 | pred = torch.tensor(output).reshape((-1,)).argsort().numpy()[::-1][:5] 51 | top5 = [labels[i].strip() for i in pred] 52 | print(top5) 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /crates/altius_py/export-bert.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eux 2 | 3 | EXPORTER_VENV=.exporter.venv 4 | 5 | if [ ! -d ${EXPORTER_VENV} ]; then 6 | python3 -m venv ${EXPORTER_VENV} 7 | source ${EXPORTER_VENV}/bin/activate 8 | pip install -U pip 9 | pip install onnx onnxruntime onnxsim optimum==1.16.2 10 | fi 11 | 12 | source ${EXPORTER_VENV}/bin/activate 13 | 14 | DIR=bert-onnx 15 | 16 | python -m optimum.exporters.onnx --model "bert-base-uncased" --task fill-mask --opset 14 ${DIR} 17 | 18 | ONNXSIM_FIXED_POINT_ITERS=1000 \ 19 | onnxsim ./${DIR}/model.onnx ./${DIR}/model.onnx --overwrite-input-shape input_ids:1,100 attention_mask:1,100 token_type_ids:1,100 20 | 21 | printf "\e[1;32mExported in ${DIR}\e[0m\n" 22 | -------------------------------------------------------------------------------- /crates/altius_py/export-fugumt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eux 2 | 3 | EXPORTER_VENV=.exporter.venv 4 | 5 | if [ ! 
-d ${EXPORTER_VENV} ]; then 6 | python3 -m venv ${EXPORTER_VENV} 7 | source ${EXPORTER_VENV}/bin/activate 8 | pip install -U pip 9 | pip install onnx onnxruntime optimum==1.16.2 10 | fi 11 | 12 | source ${EXPORTER_VENV}/bin/activate 13 | 14 | DIR=fugumt-en-ja 15 | 16 | python -m optimum.exporters.onnx --model "staka/${DIR}" ${DIR} 17 | 18 | onnxsim ./${DIR}/encoder_model.onnx ./${DIR}/encoder_model.onnx --overwrite-input-shape input_ids:1,100 attention_mask:1,100 19 | onnxsim ./${DIR}/decoder_model.onnx ./${DIR}/decoder_model.onnx --overwrite-input-shape encoder_attention_mask:1,100 input_ids:1,100 encoder_hidden_states:1,100,512 20 | 21 | onnxsim ./${DIR}/decoder_model.onnx ./${DIR}/decoder_model.onnx --unused-output \ 22 | present.0.encoder.key present.1.encoder.key present.2.encoder.key present.3.encoder.key present.4.encoder.key present.5.encoder.key \ 23 | present.0.encoder.value present.1.encoder.value present.2.encoder.value present.3.encoder.value present.4.encoder.value present.5.encoder.value \ 24 | present.0.decoder.key present.1.decoder.key present.2.decoder.key present.3.decoder.key present.4.decoder.key present.5.decoder.key \ 25 | present.0.decoder.value present.1.decoder.value present.2.decoder.value present.3.decoder.value present.4.decoder.value present.5.decoder.value 26 | 27 | printf "\e[1;32mExported in ${DIR}\e[0m\n" 28 | -------------------------------------------------------------------------------- /crates/altius_py/export-gpt2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eux 2 | 3 | EXPORTER_VENV=.exporter.venv 4 | 5 | if [ ! 
-d ${EXPORTER_VENV} ]; then 6 | python3 -m venv ${EXPORTER_VENV} 7 | source ${EXPORTER_VENV}/bin/activate 8 | pip install -U pip 9 | pip install onnx onnxruntime onnxsim optimum==1.16.2 10 | fi 11 | 12 | source ${EXPORTER_VENV}/bin/activate 13 | 14 | DIR=gpt2-onnx 15 | 16 | python -m optimum.exporters.onnx --model "gpt2" --task text-generation --opset 14 ${DIR} 17 | 18 | ONNXSIM_FIXED_POINT_ITERS=1000 \ 19 | onnxsim ./${DIR}/model.onnx ./${DIR}/model.onnx --overwrite-input-shape input_ids:1,100 attention_mask:1,100 position_ids:1,100 20 | 21 | printf "\e[1;32mExported in ${DIR}\e[0m\n" 22 | -------------------------------------------------------------------------------- /crates/altius_py/export-tinystories.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh -eux 2 | 3 | DIR=TinyStories-33M 4 | 5 | optimum-cli export onnx \ 6 | -m 'roneneldan/TinyStories-33M' \ 7 | --opset 13 \ 8 | --task causal-lm \ 9 | $DIR 10 | 11 | onnxsim \ 12 | $DIR/decoder_model.onnx $DIR/decoder_model.onnxsim.onnx \ 13 | --overwrite-input-shape input_ids:1,100 attention_mask:1,100 14 | 15 | -------------------------------------------------------------------------------- /crates/altius_py/export_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision, torchvision 3 | 4 | import onnx, onnx.checker, onnx.shape_inference 5 | import onnxsim 6 | from torchvision.models import ViT_B_16_Weights 7 | 8 | model = torchvision.models.vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1) 9 | model.eval() 10 | 11 | path = "../../models/vit_b_16.onnx" 12 | torch.onnx.export(model, torch.randn(1, 3, 224, 224), path, opset_version=14) 13 | 14 | model, ok = onnxsim.simplify(path) 15 | assert ok 16 | 17 | onnx.checker.check_model(model) 18 | model = onnx.shape_inference.infer_shapes(model) 19 | onnx.save(model, path) 20 | 
-------------------------------------------------------------------------------- /crates/altius_py/fastvit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | from urllib.request import urlopen 5 | from PIL import Image 6 | 7 | import onnx 8 | import onnxsim 9 | 10 | import torch 11 | import timm 12 | import onnxruntime as ort 13 | import altius_py as alt 14 | 15 | 16 | def main(): 17 | img = Image.open( 18 | urlopen( 19 | "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png" 20 | ) 21 | ) 22 | 23 | model = timm.create_model("fastvit_s12.apple_in1k", pretrained=True) 24 | model = model.eval() 25 | 26 | path = "../../models/fastvit.onnx" 27 | if not os.path.exists(path): 28 | torch.onnx.export( 29 | model, 30 | torch.randn(1, 3, 256, 256), 31 | path, 32 | input_names=["input"], 33 | output_names=["output"], 34 | opset_version=12, 35 | ) 36 | _model, check = onnxsim.simplify(onnx.load(path)) 37 | assert check, "Failed to simplify model" 38 | onnx.save(_model, path) 39 | 40 | data_config = timm.data.resolve_model_data_config(model) 41 | transforms = timm.data.create_transform(**data_config, is_training=False) 42 | 43 | # For altius 44 | os.environ["GOMP_WAIT_POLICY"] = "ACTIVE" 45 | os.environ["GOMP_CPU_AFFINITY"] = "0-7" 46 | 47 | # ort_sess = ort.InferenceSession("fastvit.onnx", providers=["CPUExecutionProvider"]) 48 | alt_sess = alt.InferenceSession(path, backend="cpu", intra_op_num_threads=8) 49 | 50 | with open("../../models/imagenet_classes.txt") as f: 51 | class_names = [line.strip() for line in f.readlines()] 52 | class_idx_to_label = {i: class_names[i] for i in range(len(class_names))} 53 | 54 | output = alt_sess.run( 55 | None, 56 | {"input": transforms(img).unsqueeze(0).numpy()}, 57 | )[0] 58 | 59 | top5_probabilities, top5_class_indices = torch.topk( 60 | torch.tensor(output).softmax(dim=1) * 100, k=5 61 | ) 62 | print(f"top 5 probs: 
{top5_probabilities}") 63 | print( 64 | f"top 5 labels: {[class_idx_to_label[idx] for idx in top5_class_indices.squeeze(0).tolist()]}" 65 | ) 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /crates/altius_py/fcn.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import random 4 | import logging 5 | from itertools import cycle 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision.transforms.functional import to_pil_image 10 | from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights 11 | from torchvision import transforms 12 | from matplotlib import colors as mcolors 13 | from PIL import Image 14 | 15 | import onnxruntime as ort 16 | import altius_py 17 | 18 | 19 | def main(): 20 | logging.basicConfig(level=logging.INFO) 21 | 22 | path = "../../models/cat.png" 23 | image = Image.open(path).resize((520, 520)) 24 | 25 | weights = FCN_ResNet50_Weights.DEFAULT 26 | preprocess = weights.transforms() 27 | input = np.ascontiguousarray((preprocess(image).unsqueeze(0))) 28 | 29 | # sess_options = ort.SessionOptions() 30 | # sess_options.intra_op_num_threads = 1 31 | # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL 32 | # sess = ort.InferenceSession( 33 | # "../../models/fcn-resnet50.onnx", sess_options=sess_options 34 | # ) 35 | sess = altius_py.InferenceSession("../../models/fcn-resnet50.onnx") 36 | 37 | inputs = {"input.1": input} 38 | 39 | start = time.time() 40 | output = sess.run(None, inputs)[0] 41 | print(f"Inference elapsed: {time.time() - start}") 42 | 43 | prediction = torch.tensor(output) 44 | normalized_masks = prediction.softmax(dim=1) 45 | class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])} 46 | colors = cycle(mcolors.BASE_COLORS.values()) 47 | color_like = lambda input: [torch.full_like(input, c) for c in 
next(colors)] 48 | 49 | for klass, idx in class_to_idx.items(): 50 | if klass == "__background__": 51 | continue 52 | 53 | mask = normalized_masks[0, idx] 54 | if torch.max(mask) < 0.2: 55 | # No objects of this class 56 | continue 57 | 58 | mask_img = to_pil_image( 59 | torch.stack(color_like(mask) + [mask * 0.5]), 60 | mode="RGBA", 61 | ) 62 | image = Image.alpha_composite(image.convert("RGBA"), mask_img) 63 | 64 | image.save("masked.png") 65 | image.show() 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /crates/altius_py/fuse_attn.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | from typing import DefaultDict, Dict, List, Optional, Set, Tuple 4 | 5 | import onnx 6 | from onnx import ModelProto, NodeProto, helper 7 | import onnxruntime as ort 8 | 9 | # from onnxruntime.transformers.onnx_model import OnnxModel 10 | # from onnxruntime.transformers.fusion_attention import FusionAttention 11 | # from onnxruntime.transformers.fusion_attention import AttentionMask 12 | 13 | # class DeiT(OnnxModel): 14 | # def __init__(self, model: ModelProto): 15 | # super().__init__(model) 16 | # self.attn_mask = AttentionMask(self) 17 | # self.attn_fusion = FusionAttention(self, 384, 6, self.attn_mask) 18 | # 19 | # def fuse(self): 20 | # self.attn_fusion.apply() 21 | 22 | # Map every value name to the list of nodes that read it as an input. 23 | def create_value_to_users( 24 | model: onnx.ModelProto, 25 | ) -> DefaultDict[str, List[onnx.NodeProto]]: 26 | value_to_users = collections.defaultdict(list) 27 | for node in model.graph.node: 28 | for value in node.input: 29 | value_to_users[value].append(node) 30 | return value_to_users 31 | 32 | # Starting at `root`, match LayerNormalization -> 3x MatMul -> 3x Add (outputs named attention/query|key|value) -> ... -> the Reshape named "attention/attention/Reshape_3", and return a com.microsoft MultiHeadAttention node producing that Reshape's output. Nodes walked after Q/K/V are recorded in `visited` so fuse() drops them. Returns None when the pattern does not match; raises AssertionError if the closing Reshape is never reached. 33 | def fuse_mha( 34 | root: NodeProto, 35 | visited: Set[str], 36 | value_to_users: DefaultDict[str, List[NodeProto]], 37 | ) -> Optional[NodeProto]: 38 | if root.op_type != "LayerNormalization": 39 | return None 40 | 41 | ln = root 42 | if
len(value_to_users[ln.output[0]]) != 3: 43 | return None 44 | 45 | mm1, mm2, mm3 = value_to_users[ln.output[0]] 46 | if mm1.op_type != "MatMul" or mm2.op_type != "MatMul" or mm3.op_type != "MatMul": 47 | return None 48 | # Each Q/K/V projection MatMul must feed exactly one Add (presumably the bias add). 49 | 50 | 51 | add1, add2, add3 = ( 52 | value_to_users[mm1.output[0]], 53 | value_to_users[mm2.output[0]], 54 | value_to_users[mm3.output[0]], 55 | ) 56 | if len(add1) != 1 or len(add2) != 1 or len(add3) != 1: 57 | return None 58 | add1, add2, add3 = add1[0], add2[0], add3[0] 59 | if add1.op_type != "Add" or add2.op_type != "Add" or add3.op_type != "Add": 60 | return None 61 | 62 | key = None 63 | query = None 64 | value = None 65 | for out in [add1.output[0], add2.output[0], add3.output[0]]: 66 | if "attention/key" in out: 67 | key = out 68 | elif "attention/query" in out: 69 | query = out 70 | elif "attention/value" in out: 71 | value = out 72 | if key is None or query is None or value is None: 73 | return None 74 | 75 | que = collections.deque() # O(1) popleft for the breadth-first walk below 76 | que.extend(value_to_users[key]) 77 | que.extend(value_to_users[query]) 78 | que.extend(value_to_users[value]) 79 | exit_reshape_node = None 80 | while que: 81 | node = que.popleft() 82 | visited.add(node.name) 83 | if node.op_type == "Reshape" and "attention/attention/Reshape_3" in node.name: 84 | print(node.name) 85 | exit_reshape_node = node 86 | break 87 | users = value_to_users[node.output[0]] 88 | que.extend(u for u in users if u.name not in visited) # already-walked nodes need not be queued again 89 | assert exit_reshape_node is not None 90 | 91 | num_heads = 6 # NOTE(review): hard-coded; matches FusionAttention(self, 384, 6, ...) in the commented-out DeiT class — confirm before using on other models. 92 | mha_node = helper.make_node( 93 | "MultiHeadAttention", 94 | inputs=[query, key, value], 95 | outputs=[exit_reshape_node.output[0]], 96 | name=f"MultiHeadAttention@{ln.name}", 97 | ) 98 | mha_node.domain = "com.microsoft" 99 | mha_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) 100 | 101 | print("Fusing MHA") 102 | 103 | return mha_node 104 | 105 | # Return `nodes` ordered by their position in model.graph.node, keyed on each node's first output, so a fused node takes the slot of the original node whose output it now produces. 106 | def topo_sort( 107 | model: onnx.ModelProto, nodes:
List[onnx.NodeProto] 108 | ) -> List[onnx.NodeProto]: 109 | node_to_order = {} 110 | for i, n in enumerate(model.graph.node): 111 | node_to_order[n.output[0]] = i 112 | 113 | order_and_nodes = [] 114 | for n in nodes: 115 | order_and_nodes.append((node_to_order[n.output[0]], n)) 116 | 117 | order_and_nodes.sort(key=lambda x: x[0]) 118 | return [n for _, n in order_and_nodes] 119 | 120 | # Return a copy of `model` in which every matched attention subgraph is replaced by a MultiHeadAttention node; registers the com.microsoft opset and validates the result with onnx.checker. 121 | def fuse(model: ModelProto) -> ModelProto: 122 | users = create_value_to_users(model) 123 | new_model = copy.deepcopy(model) 124 | del new_model.graph.node[:] 125 | 126 | visited: Set[str] = set() 127 | nodes = [] 128 | for node in model.graph.node: 129 | if node.name in visited: 130 | continue 131 | 132 | nodes.append(node) 133 | 134 | mha = fuse_mha(node, visited, users) 135 | if mha is not None: 136 | nodes.append(mha) 137 | 138 | sorted_nodes = topo_sort(model, nodes) 139 | for node in sorted_nodes: 140 | new_model.graph.node.add().CopyFrom(node) 141 | 142 | new_model.opset_import.append(helper.make_opsetid("com.microsoft", 1)) 143 | 144 | onnx.checker.check_model(new_model) 145 | 146 | return new_model 147 | 148 | # Fuse ../../models/deit.onnx and write the result to ./fused_deit.onnx. 149 | def main(): 150 | model = onnx.load("../../models/deit.onnx") 151 | # deit = DeiT(copy.deepcopy(model)) 152 | # deit.fuse() 153 | 154 | new_model = fuse(model) 155 | 156 | onnx.save(new_model, "./fused_deit.onnx") 157 | 158 | 159 | if __name__ == "__main__": 160 | main() 161 | -------------------------------------------------------------------------------- /crates/altius_py/gpt2.py: -------------------------------------------------------------------------------- 1 | # python -m transformers.onnx --model=gpt2 --feature=causal-lm ./a 2 | 3 | import time 4 | import logging 5 | import os 6 | import sys 7 | 8 | from transformers import AutoTokenizer, BertTokenizer, top_k_top_p_filtering 9 | from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel, GPT2Tokenizer 10 | import onnxruntime as ort 11 | import numpy as np 12 | import altius_py 13 | import
torch 14 | from torch.nn import functional as F 15 | 16 | 17 | logging.basicConfig(level=logging.INFO) 18 | 19 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 20 | sess = altius_py.InferenceSession( 21 | "./gpt2-onnx/model.onnx", intra_op_num_threads=16, enable_profile=True, backend="cpu" 22 | ) 23 | # sess = ort.InferenceSession("./gpt2-onnx/model.onnx", providers=["CPUExecutionProvider"]) 24 | 25 | torch.manual_seed(42) 26 | # Each step re-tokenizes `text`, zero-pads the inputs to max_tokens, samples one more token (top-k 50), and stops once the prompt fills the window. 27 | max_tokens = 100 28 | text = "Rust is a multi-paradigm, general-purpose programming language. Rust emphasizes performance," 29 | for _ in range(1000): 30 | inputs = tokenizer(text, return_tensors="np") 31 | seq_len = inputs["input_ids"].shape[1] # named seq_len so the builtin len() is not shadowed 32 | 33 | if seq_len >= max_tokens: 34 | break 35 | 36 | for name in ["input_ids", "attention_mask"]: 37 | padded = np.zeros((1, max_tokens), dtype=np.int64) 38 | padded[0, : inputs[name].shape[1]] = inputs[name] 39 | inputs[name] = padded 40 | 41 | inputs["position_ids"] = np.arange(max_tokens, dtype=np.int64).reshape((1, -1)) 42 | 43 | outputs = sess.run(None, dict(inputs)) 44 | # Logits at the last real (unpadded) position predict the next token. 45 | next_token_logits = outputs[0][:, seq_len - 1, :] 46 | 47 | filtered_next_token_logits = top_k_top_p_filtering( 48 | torch.tensor(next_token_logits), top_k=50, top_p=1.0 49 | ) 50 | probs = F.softmax(filtered_next_token_logits, dim=-1) 51 | next_token = torch.multinomial(probs, num_samples=1) 52 | generated = torch.cat([torch.tensor(inputs["input_ids"][0, :seq_len]), next_token[0]]) 53 | resulting_string = tokenizer.decode(generated.tolist()) 54 | print(resulting_string) 55 | text = resulting_string 56 | -------------------------------------------------------------------------------- /crates/altius_py/mandelbrot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | 4 | import torch.nn as nn 5 | import onnxsim 6 | import onnx 7 | import torch 8 | import onnxruntime as ort 9 | import altius_py 10 | # Output image size in pixels and the region of the complex plane it covers. 11 | W = 320 * 3 12 | H = 240 * 3 13 | XMIN = -2.4 14 | XMAX = 1.2
15 | YMIN = -1.2 16 | YMAX = 1.2 17 | 18 | # One escape-time iteration as an nn.Module, so it can be exported to ONNX once and then run repeatedly by the __main__ loop. 19 | class Mandelbrot(nn.Module): 20 | def forward(self, k, zx, zy): 21 | """Apply z <- z**2 + c once; where |z|**2 > 4 (checked before the step), k is set in place to max(k) + 1.""" 22 | # c spans [XMIN, XMAX] x [YMIN, YMAX]; with "ij" indexing cx and cy have shape (W, H). 23 | x = torch.linspace(XMIN, XMAX, W, dtype=torch.float32) 24 | y = torch.linspace(YMIN, YMAX, H, dtype=torch.float32) 25 | cx, cy = torch.meshgrid([x, y], indexing="ij") 26 | cx = cx.to(torch.float32) 27 | cy = cy.to(torch.float32) 28 | 29 | zx2 = zx**2 30 | zy2 = zy**2 31 | inf = (zx2 + zy2) > 4 32 | kmax = torch.max(k) 33 | k[inf] = kmax + 1 34 | zxn = zx2 - zy2 + cx 35 | zyn = 2 * zx * zy + cy 36 | return k, zxn, zyn 37 | 38 | 39 | if __name__ == "__main__": 40 | model = Mandelbrot() 41 | 42 | zx = torch.zeros(W * H, dtype=torch.float32).reshape(W, H) 43 | zy = torch.zeros(W * H, dtype=torch.float32).reshape(W, H) 44 | k = torch.zeros(W * H, dtype=torch.float32).reshape(W, H) 45 | path = "/tmp/mandelbrot.onnx" 46 | 47 | if not os.path.exists(path): 48 | torch.onnx.export( 49 | model, 50 | {"k": k, "zx": zx, "zy": zy}, 51 | path, 52 | input_names=["k", "zx", "zy"], 53 | opset_version=12, 54 | ) 55 | simplified, ok = onnxsim.simplify(onnx.load(path)) 56 | assert ok 57 | onnx.save(simplified, path) 58 | 59 | model = ort.InferenceSession(path, providers=["CPUExecutionProvider"]) 60 | # model = altius_py.InferenceSession(path) 61 | 62 | k = k.numpy() 63 | zx = zx.numpy() 64 | zy = zy.numpy() 65 | 66 | for _ in range(100): 67 | k, zxn, zyn = model.run(None, {"k": k, "zx": zx, "zy": zy}) 68 | zx = zxn 69 | zy = zyn 70 | 71 | mandelbrot = k.T 72 | 73 | plt.figure(figsize=(3.200, 2.400), dpi=1000) 74 | img = plt.imshow(mandelbrot) 75 | img.set_cmap("hot") 76 | plt.axis("off") 77 | # plt.savefig("mandel.png", dpi=100) 78 | plt.show() 79 | -------------------------------------------------------------------------------- /crates/altius_py/mobilenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | 5 | from matplotlib import pyplot as plt 6 | from PIL import Image 7 |
import numpy as np 8 | 9 | from torchvision import transforms 10 | import onnxruntime as ort 11 | import altius_py 12 | 13 | # Classify ../../models/cat.png with MobileNetV3 through altius and print the top-5 ImageNet labels. 14 | def main(): 15 | labels = open("../../models/imagenet_classes.txt").readlines() 16 | image = Image.open("../../models/cat.png").convert("RGB") 17 | # Force 3 channels: Normalize below has 3-entry mean/std, so an RGBA or greyscale PNG would break it. 18 | preprocess = transforms.Compose( 19 | [ 20 | transforms.Resize(256), 21 | transforms.CenterCrop(224), 22 | transforms.ToTensor(), 23 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 24 | ] 25 | ) 26 | input = preprocess(image) 27 | input = input.unsqueeze(0).numpy() 28 | 29 | sess = altius_py.InferenceSession("../../models/mobilenetv3.onnx") 30 | 31 | inputs = {"input": input} 32 | output = sess.run(None, inputs)[0][0] 33 | output = np.argsort(output)[::-1][:5] 34 | output = [labels[i].strip() for i in output] 35 | print(f"top5: {output}") 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /crates/altius_py/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "altius_py" 3 | version = "0.1.0" 4 | authors = [ 5 | { name = "maekawatoshiki" } 6 | ] 7 | dependencies = [ 8 | "pillow>=9.2.0", 9 | "matplotlib>=3.5.3", 10 | "maturin>=1.1.0", 11 | "onnx>=1.15.0", 12 | "transformers==4.41.2", 13 | "onnxsim==0.4.17", 14 | "numpy>=1.25.2", 15 | "onnxruntime>=1.15.1", 16 | "pytest>=7.4.4", 17 | "pytest-xdist>=3.3.1", 18 | "torch>=2.2.2", 19 | "onnxscript>=0.1.0.dev20240227", 20 | "pip>=24.0", 21 | "optimum>=1.18.1", 22 | "torchvision>=0.17.2", 23 | "tabulate>=0.9.0", 24 | "timm>=0.9.16", 25 | "packaging>=24.1", 26 | "einops>=0.8.0", 27 | "fpzip>=1.2.4", 28 | "zfpy>=1.0.0", 29 | "onnxruntime-genai>=0.3.0; sys_platform == 'linux'", 30 | "huggingface>=0.0.1", 31 | "netron>=7.8.3", 32 | "patchelf>=0.17.2.1; sys_platform == 'linux'", 33 | ] 34 | requires-python = "==3.12.*" 35 | 36 | [build-system] 37 | requires = ["maturin>=1.1.0"] 38
| build-backend = "maturin" 39 | -------------------------------------------------------------------------------- /crates/altius_py/real-esrgan.py: -------------------------------------------------------------------------------- 1 | import altius_py 2 | import time 3 | import numpy as np 4 | from PIL import Image 5 | from torchvision import transforms 6 | import os, random 7 | from matplotlib import pyplot as plt 8 | import onnxruntime as ort 9 | import logging 10 | from torchvision.transforms.functional import to_pil_image 11 | import torch 12 | 13 | # Upscale a 256x256 center crop of cat.png with Real-ESRGAN through altius, print the latency, and save/show the result as a.png. 14 | def main(): 15 | logging.basicConfig(level=logging.INFO) 16 | image = Image.open("../../models/cat.png").convert("RGB") 17 | 18 | preprocess = transforms.Compose( 19 | [ 20 | transforms.Resize(256), 21 | transforms.CenterCrop(256), 22 | transforms.ToTensor(), 23 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 24 | ] 25 | ) 26 | input = preprocess(image) 27 | input = input.unsqueeze(0).numpy() 28 | print(input.shape) 29 | 30 | path = "../../models/realesrgan_256x256.onnx" 31 | sess = altius_py.InferenceSession( 32 | path, intra_op_num_threads=32, enable_profile=True 33 | ) 34 | # sess = ort.InferenceSession(path, providers=["CUDAExecutionProvider"]) 35 | # sess_options = ort.SessionOptions() 36 | # # sess_options.enable_profiling = True 37 | # sess_options.intra_op_num_threads = 16 38 | # sess_options.inter_op_num_threads = 1 39 | # sess = ort.InferenceSession( 40 | # path, 41 | # providers=["CPUExecutionProvider"], 42 | # # providers=["CUDAExecutionProvider"], 43 | # sess_options=sess_options, 44 | # ) 45 | 46 | inputs = {"input.1": input} 47 | start = time.time() 48 | output = sess.run(None, inputs)[0] 49 | print(f"elapsed: {time.time() - start}") 50 | 51 | # print(output.shape) 52 | # print(output.max()) 53 | # print(output.min()) 54 | img = to_pil_image(torch.tensor(output.clip(0, 1)).squeeze()) 55 | 56 | img.save("a.png") 57 | img.show() 58 | 59 | 60 | if __name__ ==
"__main__": 61 | main() 62 | -------------------------------------------------------------------------------- /crates/altius_py/resnet50.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | 4 | # os.environ["GOMP_CPU_AFFINITY"] = "0-7" 5 | # os.environ["OMP_WAIT_POLICY"] = "active" 6 | 7 | from PIL import Image 8 | import numpy as np 9 | 10 | import torch 11 | from torchvision import transforms 12 | from torchvision.models import resnet50 13 | 14 | import onnxruntime as ort 15 | import altius_py 16 | 17 | # Export ResNet-50 to ONNX on first use, then classify cat.png ten times through altius (or onnxruntime when use_ort is set), printing the latency and top-5 labels of each run. 18 | def main(): 19 | model_path = "../../models/resnet50.onnx" 20 | 21 | if not os.path.exists(model_path): 22 | with torch.no_grad(): 23 | model = resnet50(pretrained=True).eval() 24 | torch.onnx.export( 25 | model, 26 | torch.randn(1, 3, 224, 224, dtype=torch.float32), 27 | model_path, 28 | verbose=True, 29 | ) 30 | 31 | labels = open("../../models/imagenet_classes.txt").readlines() 32 | image = Image.open("../../models/cat.png").convert("RGB") 33 | # convert("RGB") above guarantees the 3 channels that the 3-entry mean/std in Normalize expect. 34 | preprocess = transforms.Compose( 35 | [ 36 | transforms.Resize(256), 37 | transforms.CenterCrop(224), 38 | transforms.ToTensor(), 39 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 40 | ] 41 | ) 42 | input = preprocess(image).unsqueeze(0).numpy() 43 | 44 | use_ort = False 45 | if use_ort: 46 | sess = ort.InferenceSession(model_path) 47 | else: 48 | sess = altius_py.InferenceSession( 49 | model_path, 50 | intra_op_num_threads=1, 51 | backend="cpu", 52 | ) 53 | 54 | inputs = {"input.1": input} 55 | for _ in range(10): 56 | start = time.time() 57 | output = sess.run(None, inputs)[0][0] 58 | print(f"Elapsed: {(time.time() - start) * 1000.0:.3f} [ms]") 59 | output = np.argsort(output)[::-1][:5] 60 | output = [labels[i].strip() for i in output] 61 | print(f"Top-5: {output}") 62 | 63 | 64 | if __name__ == "__main__": 65 | main() 66 | --------------------------------------------------------------------------------
/crates/altius_py/show-ort-profile.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from collections import defaultdict 4 | 5 | from tabulate import tabulate 6 | 7 | # Print a table of the total time per op (in ms) from an onnxruntime profile JSON, plus a *Total* row. 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("filepath", help="onnxruntime profile json", type=str) 11 | args = parser.parse_args() 12 | 13 | with open(args.filepath) as f: profile = json.load(f) 14 | durations = defaultdict(int) 15 | # Only events whose "args" carry an op_name are counted; their "dur" values are summed per op name. 16 | for elem in profile: 17 | elem_args = elem.get("args") 18 | if elem_args: 19 | op = elem_args.get("op_name") 20 | if op: 21 | dur = int(elem.get("dur")) 22 | durations[op] += dur 23 | # dur / 1000.0 is reported as milliseconds; sorting ascending normally leaves *Total* as the last row. 24 | table = [(op, dur / 1000.0) for op, dur in durations.items()] 25 | table.append(("*Total*", sum(dur for _, dur in table))) 26 | table = sorted(table, key=lambda x: x[1]) 27 | 28 | print(tabulate(table, tablefmt="simple_outline", headers=["Op", "Duration [ms]"])) 29 | 30 | 31 | if __name__ == "__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /crates/altius_py/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eux 2 | # Run the Python op tests in parallel; pass `build` to (re)build the extension with maturin first. 3 | if [ ! -d ".venv" ]; then 4 | uv sync 5 | fi 6 | 7 | export RUST_LOG=INFO 8 | # `${GITHUB_ACTIONS:-}` keeps the check below safe under `-u` (nounset) when running outside CI. 9 | if [ "${1:-nobuild}" = "build" ]; then 10 | if [ -z "${GITHUB_ACTIONS:-}" ]; then 11 | uv run maturin develop -r --target-dir ./target > /dev/null 12 | else 13 | uv run maturin develop -r > /dev/null 14 | fi 15 | fi 16 | 17 | unset GOMP_CPU_AFFINITY 18 | # nproc is GNU coreutils; fall back to sysctl where it is missing (e.g. macOS). 19 | n=$(nproc 2>/dev/null || sysctl -n hw.ncpu) 20 | uv run python -m pytest . -n $((n > 16 ?
16 : n)) 21 | -------------------------------------------------------------------------------- /crates/altius_py/tests/test.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | env, 3 | fs::read_dir, 4 | io, 5 | path::{Path, PathBuf}, 6 | }; 7 | 8 | use cargo_util::paths::mtime_recursive; 9 | // Run ./test.sh, passing `build` when .venv is missing or target/ is newer than .venv. 10 | #[test] 11 | fn run_python_tests() { 12 | // If build artifacts are modified, run `maturin develop -r` by passing `build` option to 13 | // `./test.sh`. 14 | let root = get_project_root().unwrap(); 15 | let target_mtime = mtime_recursive(&root.join("target/")).unwrap(); 16 | // TODO: Better not hard-code venv dir `.venv`. 17 | let build = mtime_recursive(Path::new(".venv")).map_or("build", |src_mtime| { 18 | if target_mtime > src_mtime { 19 | "build" 20 | } else { 21 | "" 22 | } 23 | }); 24 | assert!(std::process::Command::new("bash") 25 | .arg("./test.sh") 26 | .arg(build) 27 | .spawn() 28 | .unwrap() 29 | .wait() 30 | .unwrap() 31 | .success()) 32 | } 33 | // Walk up from the current directory to the first ancestor that contains Cargo.lock. 34 | #[cfg(test)] 35 | fn get_project_root() -> io::Result<PathBuf> { 36 | let path = env::current_dir()?; 37 | let path_ancestors = path.as_path().ancestors(); 38 | 39 | for p in path_ancestors { 40 | let has_cargo = read_dir(p)?.any(|p| p.unwrap().file_name() == *"Cargo.lock"); 41 | if has_cargo { 42 | return Ok(PathBuf::from(p)); 43 | } 44 | } 45 | 46 | Err(io::Error::new( 47 | io::ErrorKind::NotFound, 48 | "Cargo.lock not found", 49 | )) 50 | } 51 | -------------------------------------------------------------------------------- /crates/altius_py/tests/test_ops_concat.py: -------------------------------------------------------------------------------- 1 | import altius_py 2 | import onnxruntime as ort 3 | import onnx 4 | import tempfile 5 | import pytest 6 | import os 7 | import numpy as np 8 | from onnx import helper, ValueInfoProto, TensorProto 9 | 10 | 11 | def test_concat_1(): 12 | with tempfile.TemporaryDirectory() as tmpdir: 13 | op_concat( 14 | os.path.join(tmpdir,
"model.onnx"), 15 | [1, 1, 10], 16 | [1, 3, 10], 17 | [1, 4, 10], 18 | axis=1, 19 | ) 20 | 21 | # Build a one-node Concat(x, y) -> z model and compare each altius backend (interpreter, cpu) against onnxruntime on random inputs. 22 | def op_concat(filepath, shape_x, shape_y, shape_z, **kwargs): 23 | inputs = [ 24 | helper.make_tensor_value_info("x", TensorProto.FLOAT, shape_x), 25 | helper.make_tensor_value_info("y", TensorProto.FLOAT, shape_y), 26 | ] 27 | outputs = [helper.make_tensor_value_info("z", TensorProto.FLOAT, shape_z)] 28 | nodes = [helper.make_node("Concat", ["x", "y"], ["z"], **kwargs)] 29 | graph = helper.make_graph(nodes, "graph", inputs, outputs) 30 | model = helper.make_model(graph) 31 | 32 | onnx.save(model, filepath) 33 | ort_sess = ort.InferenceSession(filepath, providers=["CPUExecutionProvider"]) 34 | 35 | for backend in ["interpreter", "cpu"]: 36 | altius_sess = altius_py.InferenceSession(filepath, backend=backend) 37 | 38 | x = np.random.random_sample(shape_x).astype(np.float32) 39 | y = np.random.random_sample(shape_y).astype(np.float32) 40 | inputs = {"x": x, "y": y} 41 | expected = ort_sess.run(None, inputs) 42 | actual = altius_sess.run(None, inputs) 43 | 44 | for expected, actual in zip(expected, actual): 45 | assert np.allclose(expected, actual) 46 | -------------------------------------------------------------------------------- /crates/altius_py/tests/test_ops_conv.py: -------------------------------------------------------------------------------- 1 | import altius_py 2 | import onnxruntime as ort 3 | import onnx 4 | import tempfile 5 | import pytest 6 | import os 7 | import numpy as np 8 | from onnx import helper, ValueInfoProto, TensorProto 9 | 10 | 11 | @pytest.mark.parametrize("bias", [False, True]) 12 | def test_conv2d_1(bias): 13 | with tempfile.TemporaryDirectory() as tmpdir: 14 | op_conv2d( 15 | os.path.join(tmpdir, "model.onnx"), 16 | [1, 3, 224, 224], 17 | [16, 3, 3, 3], 18 | [1, 16, 112, 112], 19 | bias=bias, 20 | pads=[1, 1, 1, 1], 21 | strides=[2, 2], 22 | ) 23 | 24 | 25 | @pytest.mark.parametrize("bias", [False, True]) 26 | def test_conv2d_2(bias):
27 | with tempfile.TemporaryDirectory() as tmpdir: 28 | op_conv2d( 29 | os.path.join(tmpdir, "model.onnx"), 30 | [1, 16, 112, 112], 31 | [16, 1, 3, 3], 32 | [1, 16, 112, 112], 33 | bias=bias, 34 | group=16, 35 | pads=[1, 1, 1, 1], 36 | ) 37 | 38 | # NOTE(review): unlike the sibling op tests, this checks only altius's default backend (no interpreter/cpu loop) — confirm that is intended. 39 | def op_conv2d(filepath, shape_x, shape_w, shape_y, bias=False, **kwargs): 40 | inputs = [ 41 | helper.make_tensor_value_info("x", TensorProto.FLOAT, shape_x), 42 | helper.make_tensor_value_info("w", TensorProto.FLOAT, shape_w), 43 | ] 44 | if bias: 45 | inputs.append( 46 | helper.make_tensor_value_info("b", TensorProto.FLOAT, [shape_w[0]]) 47 | ) 48 | 49 | outputs = [helper.make_tensor_value_info("y", TensorProto.FLOAT, shape_y)] 50 | nodes = [ 51 | helper.make_node( 52 | "Conv", 53 | ["x", "w", "b"] if bias else ["x", "w"], 54 | ["y"], 55 | kernel_shape=[shape_w[2], shape_w[3]], 56 | **kwargs, 57 | ) 58 | ] 59 | graph = helper.make_graph(nodes, "graph", inputs, outputs) 60 | model = helper.make_model(graph) 61 | 62 | onnx.save(model, filepath) 63 | ort_sess = ort.InferenceSession(filepath, providers=["CPUExecutionProvider"]) 64 | altius_sess = altius_py.InferenceSession(filepath) 65 | 66 | x = np.random.random_sample(shape_x).astype(np.float32) 67 | w = np.random.random_sample(shape_w).astype(np.float32) 68 | b = np.random.random_sample(shape_w[0]).astype(np.float32) if bias else None 69 | inputs = {"x": x, "w": w, "b": b} if bias else {"x": x, "w": w} 70 | expected = ort_sess.run(None, inputs) 71 | actual = altius_sess.run(None, inputs) 72 | 73 | for expected, actual in zip(expected, actual): 74 | assert np.allclose(expected, actual) 75 | -------------------------------------------------------------------------------- /crates/altius_py/tests/test_ops_gather.py: -------------------------------------------------------------------------------- 1 | import altius_py 2 | import onnxruntime as ort 3 | import onnx 4 | import tempfile 5 | import pytest 6 | import os 7 | import numpy as np 8 | from onnx import helper,
ValueInfoProto, TensorProto 9 | 10 | 11 | def test_gather_1(): 12 | with tempfile.TemporaryDirectory() as tmpdir: 13 | op_gather( 14 | os.path.join(tmpdir, "model.onnx"), 15 | [1, 5, 10], 16 | 2, 17 | [1, 1, 10], 18 | axis=1, 19 | ) 20 | # NOTE(review): ONNX Gather's output rank is rank(x) + rank(indices) - 1, so the declared shape_z here ([1, 1, 10]) and in test_gather_2 ([2, 10]) differ from the inferred [1, 10] / [1, 2, 10] — confirm the declared output shape is meant to be ignored. 21 | 22 | def test_gather_2(): 23 | with tempfile.TemporaryDirectory() as tmpdir: 24 | op_gather( 25 | os.path.join(tmpdir, "model.onnx"), 26 | [5, 10], 27 | [1, 3], 28 | [2, 10], 29 | axis=0, 30 | ) 31 | 32 | # Build a one-node Gather(x, y) -> z model (scalar int index -> 0-D indices, list -> shape [1, len]) and compare each altius backend against onnxruntime. 33 | def op_gather(filepath, shape_x, indices, shape_z, **kwargs): 34 | shape_y = [] if isinstance(indices, int) else [1, len(indices)] 35 | inputs = [ 36 | helper.make_tensor_value_info("x", TensorProto.FLOAT, shape_x), 37 | helper.make_tensor_value_info("y", TensorProto.INT64, shape_y), 38 | ] 39 | outputs = [helper.make_tensor_value_info("z", TensorProto.FLOAT, shape_z)] 40 | nodes = [helper.make_node("Gather", ["x", "y"], ["z"], **kwargs)] 41 | graph = helper.make_graph(nodes, "graph", inputs, outputs) 42 | model = helper.make_model(graph) 43 | 44 | onnx.save(model, filepath) 45 | ort_sess = ort.InferenceSession(filepath, providers=["CPUExecutionProvider"]) 46 | 47 | for backend in ["interpreter", "cpu"]: 48 | altius_sess = altius_py.InferenceSession(filepath, backend=backend) 49 | 50 | x = np.random.random_sample(shape_x).astype(np.float32) 51 | y = np.array(indices).astype(np.int64).reshape(shape_y) 52 | inputs = {"x": x, "y": y} 53 | expected = ort_sess.run(None, inputs) 54 | actual = altius_sess.run(None, inputs) 55 | 56 | for expected, actual in zip(expected, actual): 57 | assert np.allclose(expected, actual) 58 | -------------------------------------------------------------------------------- /crates/altius_py/tests/test_ops_gemm.py: -------------------------------------------------------------------------------- 1 | import altius_py 2 | import onnxruntime as ort 3 | import onnx 4 | import tempfile 5 | import pytest 6 | import os 7 | import numpy as np 8 | from onnx import helper, ValueInfoProto, TensorProto 9 | 10 |
11 | def test_gemm_1(): 12 | with tempfile.TemporaryDirectory() as tmpdir: 13 | op_gemm(os.path.join(tmpdir, "model.onnx"), [5, 10], [10, 15], [5, 15], [15]) 14 | 15 | 16 | def test_gemm_2(): 17 | with tempfile.TemporaryDirectory() as tmpdir: 18 | op_gemm(os.path.join(tmpdir, "model.onnx"), [5, 10], [10, 15], [5, 15], [5, 15]) 19 | 20 | 21 | def test_gemm_3(): 22 | with tempfile.TemporaryDirectory() as tmpdir: 23 | op_gemm(os.path.join(tmpdir, "model.onnx"), [5, 10], [10, 15], [5, 15]) 24 | 25 | 26 | # TODO 27 | # def test_gemm_4(): 28 | # with tempfile.TemporaryDirectory() as tmpdir: 29 | # op_gemm( 30 | # os.path.join(tmpdir, "model.onnx"), [3, 5, 10], [3, 10, 15], [3, 5, 15] 31 | # ) 32 | # 33 | # 34 | # def test_gemm_5(): 35 | # with tempfile.TemporaryDirectory() as tmpdir: 36 | # op_gemm(os.path.join(tmpdir, "model.onnx"), [1, 5, 10], [10, 15], [1, 5, 15]) 37 | 38 | # Build a one-node Gemm(a, b[, c]) -> y model with default attributes (c is wired in only when shape_c is given) and compare each altius backend against onnxruntime. 39 | def op_gemm(filepath, shape_a, shape_b, shape_y, shape_c=None): 40 | inputs = [ 41 | helper.make_tensor_value_info("a", TensorProto.FLOAT, shape_a), 42 | helper.make_tensor_value_info("b", TensorProto.FLOAT, shape_b), 43 | ] 44 | if shape_c: 45 | inputs.append(helper.make_tensor_value_info("c", TensorProto.FLOAT, shape_c)) 46 | outputs = [helper.make_tensor_value_info("y", TensorProto.FLOAT, shape_y)] 47 | nodes = [ 48 | helper.make_node( 49 | "Gemm", 50 | ["a", "b", "c"] if shape_c else ["a", "b"], 51 | ["y"], 52 | ) 53 | ] 54 | graph = helper.make_graph(nodes, "graph", inputs, outputs) 55 | model = helper.make_model(graph) 56 | 57 | onnx.save(model, filepath) 58 | ort_sess = ort.InferenceSession(filepath, providers=["CPUExecutionProvider"]) 59 | 60 | for backend in ["interpreter", "cpu"]: 61 | altius_sess = altius_py.InferenceSession(filepath, backend=backend) 62 | 63 | a = np.random.random_sample(shape_a).astype(np.float32) 64 | b = np.random.random_sample(shape_b).astype(np.float32) 65 | inputs = {} 66 | if shape_c: 67 | c = np.random.random_sample(shape_c).astype(np.float32) 68 | inputs =
{"a": a, "b": b, "c": c} 69 | else: 70 | inputs = {"a": a, "b": b} 71 | expected = ort_sess.run(None, inputs) 72 | actual = altius_sess.run(None, inputs) 73 | 74 | for expected, actual in zip(expected, actual): 75 | assert np.allclose(expected, actual) 76 | -------------------------------------------------------------------------------- /crates/altius_py/tests/test_ops_matmul.py: -------------------------------------------------------------------------------- 1 | import altius_py 2 | import onnxruntime as ort 3 | import onnx 4 | import tempfile 5 | import pytest 6 | import os 7 | import numpy as np 8 | from onnx import helper, ValueInfoProto, TensorProto 9 | 10 | 11 | def test_matmul_1(): 12 | with tempfile.TemporaryDirectory() as tmpdir: 13 | op_matmul(os.path.join(tmpdir, "model.onnx"), [5, 10], [10, 15], [5, 15]) 14 | 15 | 16 | def test_matmul_2(): 17 | with tempfile.TemporaryDirectory() as tmpdir: 18 | op_matmul(os.path.join(tmpdir, "model.onnx"), [3, 5, 10], [10, 15], [3, 5, 15]) 19 | 20 | 21 | def test_matmul_3(): 22 | with tempfile.TemporaryDirectory() as tmpdir: 23 | op_matmul( 24 | os.path.join(tmpdir, "model.onnx"), [3, 5, 10], [3, 10, 15], [3, 5, 15] 25 | ) 26 | 27 | 28 | def test_matmul_4(): 29 | with tempfile.TemporaryDirectory() as tmpdir: 30 | op_matmul(os.path.join(tmpdir, "model.onnx"), [1, 5, 10], [10, 15], [1, 5, 15]) 31 | # Build a one-node MatMul(x, y) -> z model and compare each altius backend against onnxruntime on random inputs. 32 | 33 | def op_matmul( 34 | filepath, 35 | shape_x, 36 | shape_y, 37 | shape_z, 38 | ): 39 | inputs = [ 40 | helper.make_tensor_value_info("x", TensorProto.FLOAT, shape_x), 41 | helper.make_tensor_value_info("y", TensorProto.FLOAT, shape_y), 42 | ] 43 | outputs = [helper.make_tensor_value_info("z", TensorProto.FLOAT, shape_z)] 44 | nodes = [ 45 | helper.make_node( 46 | "MatMul", 47 | ["x", "y"], 48 | ["z"], 49 | ) 50 | ] 51 | graph = helper.make_graph(nodes, "graph", inputs, outputs) 52 | model = helper.make_model(graph) 53 | 54 | onnx.save(model, filepath) 55 | ort_sess = ort.InferenceSession(filepath,
providers=["CPUExecutionProvider"]) 56 | 57 | for backend in ["interpreter", "cpu"]: 58 | altius_sess = altius_py.InferenceSession(filepath, backend=backend) 59 | 60 | x = np.random.random_sample(shape_x).astype(np.float32) 61 | y = np.random.random_sample(shape_y).astype(np.float32) 62 | inputs = {"x": x, "y": y} 63 | expected = ort_sess.run(None, inputs) 64 | actual = altius_sess.run(None, inputs) 65 | 66 | for expected, actual in zip(expected, actual): 67 | assert np.allclose(expected, actual) 68 | -------------------------------------------------------------------------------- /crates/altius_py/tests/test_ops_norm.py: -------------------------------------------------------------------------------- 1 | import altius_py 2 | import onnxruntime as ort 3 | import onnx 4 | import tempfile 5 | import pytest 6 | import os 7 | import numpy as np 8 | from onnx import helper, ValueInfoProto, TensorProto 9 | 10 | 11 | def test_batch_norm_1(): 12 | with tempfile.TemporaryDirectory() as tmpdir: 13 | op_batch_norm(os.path.join(tmpdir, "model.onnx"), [1, 20, 10, 10]) 14 | 15 | 16 | def test_layer_norm_1(): 17 | with tempfile.TemporaryDirectory() as tmpdir: 18 | op_layer_norm(os.path.join(tmpdir, "model.onnx"), [1, 20, 10]) 19 | 20 | # BatchNormalization on a 4-D input with per-channel (axis 1) scale/bias/mean/var; compared against onnxruntime with rtol=1e-4, atol=1e-5. 21 | def op_batch_norm(filepath, shape, **kwargs): 22 | assert len(shape) == 4 23 | inputs = [ 24 | helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), 25 | helper.make_tensor_value_info("scale", TensorProto.FLOAT, [shape[1]]), 26 | helper.make_tensor_value_info("bias", TensorProto.FLOAT, [shape[1]]), 27 | helper.make_tensor_value_info("mean", TensorProto.FLOAT, [shape[1]]), 28 | helper.make_tensor_value_info("var", TensorProto.FLOAT, [shape[1]]), 29 | ] 30 | outputs = [helper.make_tensor_value_info("z", TensorProto.FLOAT, shape)] 31 | nodes = [ 32 | helper.make_node( 33 | "BatchNormalization", ["x", "scale", "bias", "mean", "var"], ["z"], **kwargs 34 | ) 35 | ] 36 | graph = helper.make_graph(nodes, "graph", inputs, outputs) 37 | model =
helper.make_model(graph) 38 | 39 | onnx.save(model, filepath) 40 | ort_sess = ort.InferenceSession(filepath, providers=["CPUExecutionProvider"]) 41 | 42 | for backend in ["interpreter", "cpu"]: 43 | altius_sess = altius_py.InferenceSession(filepath, backend=backend) 44 | 45 | x = np.random.random_sample(shape).astype(np.float32) 46 | scale = np.random.random_sample(shape[1]).astype(np.float32) 47 | bias = np.random.random_sample(shape[1]).astype(np.float32) 48 | mean = np.random.random_sample(shape[1]).astype(np.float32) 49 | var = np.random.random_sample(shape[1]).astype(np.float32) 50 | inputs = {"x": x, "scale": scale, "bias": bias, "mean": mean, "var": var} 51 | expected = ort_sess.run(None, inputs) 52 | actual = altius_sess.run(None, inputs) 53 | 54 | for expected, actual in zip(expected, actual): 55 | assert np.allclose(expected, actual, rtol=1e-4, atol=1e-5) 56 | 57 | # LayerNormalization over the last axis (default attributes unless kwargs say otherwise); scale/bias have shape [1, ..., 1, shape[-1]]. 58 | def op_layer_norm(filepath, shape, **kwargs): 59 | shape_scale = [1] * (len(shape) - 1) + [shape[-1]] 60 | inputs = [ 61 | helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), 62 | helper.make_tensor_value_info("scale", TensorProto.FLOAT, shape_scale), 63 | helper.make_tensor_value_info("bias", TensorProto.FLOAT, shape_scale), 64 | ] 65 | outputs = [helper.make_tensor_value_info("z", TensorProto.FLOAT, shape)] 66 | nodes = [ 67 | helper.make_node("LayerNormalization", ["x", "scale", "bias"], ["z"], **kwargs) 68 | ] 69 | graph = helper.make_graph(nodes, "graph", inputs, outputs) 70 | model = helper.make_model(graph) 71 | 72 | onnx.save(model, filepath) 73 | ort_sess = ort.InferenceSession(filepath, providers=["CPUExecutionProvider"]) 74 | 75 | for backend in ["interpreter", "cpu"]: 76 | altius_sess = altius_py.InferenceSession(filepath, backend=backend) 77 | 78 | x = np.random.random_sample(shape).astype(np.float32) 79 | scale = np.random.random_sample(shape_scale).astype(np.float32) 80 | bias = np.random.random_sample(shape_scale).astype(np.float32) 81 | inputs = {"x": x, "scale":
scale, "bias": bias} 82 | expected = ort_sess.run(None, inputs) 83 | actual = altius_sess.run(None, inputs) 84 | 85 | for expected, actual in zip(expected, actual): 86 | assert np.allclose(expected, actual, rtol=1e-4, atol=1e-5) 87 | -------------------------------------------------------------------------------- /crates/altius_py/tests/test_ops_pool.py: -------------------------------------------------------------------------------- 1 | import altius_py 2 | import onnxruntime as ort 3 | import onnx 4 | import tempfile 5 | import pytest 6 | import os 7 | import numpy as np 8 | from onnx import helper, ValueInfoProto, TensorProto 9 | 10 | 11 | def test_maxpool_1(): 12 | with tempfile.TemporaryDirectory() as tmpdir: 13 | op_maxpool( 14 | os.path.join(tmpdir, "model.onnx"), 15 | [1, 3, 224, 224], 16 | [1, 3, 112, 112], 17 | kernel_shape=[2, 2], 18 | pads=[0, 0, 0, 0], 19 | strides=[2, 2], 20 | auto_pad="NOTSET", 21 | ) 22 | 23 | 24 | def test_maxpool_2(): 25 | with tempfile.TemporaryDirectory() as tmpdir: 26 | op_maxpool( 27 | os.path.join(tmpdir, "model.onnx"), 28 | [1, 256, 20, 20], 29 | [1, 256, 20, 20], 30 | kernel_shape=[5, 5], 31 | pads=[2, 2, 2, 2], 32 | strides=[1, 1], 33 | ceil_mode=0, 34 | ) 35 | 36 | # Build a one-node MaxPool(x) -> y model from kwargs (kernel_shape, pads, strides, ...) and compare each altius backend against onnxruntime. 37 | def op_maxpool(filepath, shape_x, shape_y, **kwargs): 38 | inputs = [helper.make_tensor_value_info("x", TensorProto.FLOAT, shape_x)] 39 | outputs = [helper.make_tensor_value_info("y", TensorProto.FLOAT, shape_y)] 40 | nodes = [ 41 | helper.make_node( 42 | "MaxPool", 43 | ["x"], 44 | ["y"], 45 | **kwargs, 46 | ) 47 | ] 48 | graph = helper.make_graph(nodes, "graph", inputs, outputs) 49 | model = helper.make_model(graph) 50 | 51 | onnx.save(model, filepath) 52 | ort_sess = ort.InferenceSession(filepath, providers=["CPUExecutionProvider"]) 53 | 54 | for backend in ["interpreter", "cpu"]: 55 | altius_sess = altius_py.InferenceSession(filepath, backend=backend) 56 | 57 | x = np.random.random_sample(shape_x).astype(np.float32) 58 | inputs = {"x": x} 59 | expected
= ort_sess.run(None, inputs) 60 | actual = altius_sess.run(None, inputs) 61 | 62 | for expected, actual in zip(expected, actual): 63 | assert np.allclose(expected, actual) 64 | -------------------------------------------------------------------------------- /crates/altius_py/tests/test_ops_reduce.py: -------------------------------------------------------------------------------- 1 | import altius_py 2 | import onnxruntime as ort 3 | import onnx 4 | import tempfile 5 | import os 6 | import numpy as np 7 | from onnx import helper, TensorProto 8 | 9 | 10 | def test_reduce_mean_1(): 11 | with tempfile.TemporaryDirectory() as tmpdir: 12 | op_reduce( 13 | os.path.join(tmpdir, "model.onnx"), 14 | "ReduceMean", 15 | [1, 50, 70], 16 | [1, 50, 1], 17 | axes=[-1], 18 | ) 19 | 20 | 21 | def test_reduce_mean_2(): 22 | with tempfile.TemporaryDirectory() as tmpdir: 23 | op_reduce( 24 | os.path.join(tmpdir, "model.onnx"), 25 | "ReduceMean", 26 | [8, 4, 5, 5], 27 | [8, 4, 1, 1], 28 | axes=[2, 3], 29 | backends=["cpu"], 30 | ) 31 | 32 | 33 | def test_reduce_max_1(): 34 | with tempfile.TemporaryDirectory() as tmpdir: 35 | op_reduce( 36 | os.path.join(tmpdir, "model.onnx"), 37 | "ReduceMax", 38 | [1, 50, 70], 39 | [], 40 | keepdims=0, 41 | ) 42 | # Build a one-node `op_type`(x) -> y reduce model at opset 13, validate it with onnx.checker, and compare each backend in `backends` (an immutable tuple default) against onnxruntime. 43 | 44 | def op_reduce( 45 | filepath, op_type, shape_x, shape_y, backends=("interpreter", "cpu"), **kwargs 46 | ): 47 | inputs = [helper.make_tensor_value_info("x", TensorProto.FLOAT, shape_x)] 48 | outputs = [helper.make_tensor_value_info("y", TensorProto.FLOAT, shape_y)] 49 | nodes = [helper.make_node(op_type, ["x"], ["y"], **kwargs)] 50 | graph = helper.make_graph(nodes, "graph", inputs, outputs) 51 | model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) 52 | 53 | onnx.checker.check_model(model) 54 | onnx.save(model, filepath) 55 | ort_sess = ort.InferenceSession(filepath, providers=["CPUExecutionProvider"]) 56 | 57 | for backend in backends: 58 | altius_sess = altius_py.InferenceSession(filepath,
backend=backend) 59 | 60 | x = np.random.random_sample(shape_x).astype(np.float32) 61 | inputs = {"x": x} 62 | expected = ort_sess.run(None, inputs) 63 | actual = altius_sess.run(None, inputs) 64 | 65 | for expected, actual in zip(expected, actual): 66 | assert np.allclose(expected, actual) 67 | -------------------------------------------------------------------------------- /crates/altius_py/tests/test_ops_resize.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import pytest 3 | import os 4 | 5 | import numpy as np 6 | 7 | import onnxruntime as ort 8 | import onnx 9 | from onnx import helper, ValueInfoProto, TensorProto, numpy_helper 10 | import altius_py 11 | 12 | 13 | def test_resize_1(): 14 | with tempfile.TemporaryDirectory() as tmpdir: 15 | op_resize( 16 | os.path.join(tmpdir, "model.onnx"), 17 | [1, 256, 20, 20], 18 | [1, 256, 40, 40], 19 | np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32), 20 | coordinate_transformation_mode="asymmetric", 21 | cubic_coeff_a=-0.75, 22 | mode="nearest", 23 | nearest_mode="floor", 24 | ) 25 | 26 | 27 | def op_resize(filepath, shape_x, shape_y, scales, **kwargs): 28 | inputs = [helper.make_tensor_value_info("x", TensorProto.FLOAT, shape_x)] 29 | outputs = [helper.make_tensor_value_info("y", TensorProto.FLOAT, shape_y)] 30 | nodes = [ 31 | helper.make_node( 32 | "Resize", 33 | ["x", "roi", "scales"], 34 | ["y"], 35 | **kwargs, 36 | ) 37 | ] 38 | 39 | roi = numpy_helper.from_array(np.array([], dtype=np.float32), name="roi") 40 | scales = numpy_helper.from_array(scales, name="scales") 41 | graph = helper.make_graph( 42 | nodes, "graph", inputs, outputs, initializer=[roi, scales] 43 | ) 44 | model = helper.make_model(graph) 45 | 46 | onnx.save(model, filepath) 47 | ort_sess = ort.InferenceSession(filepath, providers=["CPUExecutionProvider"]) 48 | altius_sess = altius_py.InferenceSession(filepath) 49 | 50 | x = np.random.random_sample(shape_x).astype(np.float32) 51 | inputs
= {"x": x} 52 | expected = ort_sess.run(None, inputs) 53 | actual = altius_sess.run(None, inputs) 54 | 55 | for expected, actual in zip(expected, actual): 56 | assert np.allclose(expected, actual) 57 | -------------------------------------------------------------------------------- /crates/altius_py/tests/test_ops_transpose.py: -------------------------------------------------------------------------------- 1 | import altius_py 2 | import onnxruntime as ort 3 | import onnx 4 | import tempfile 5 | import pytest 6 | import os 7 | import numpy as np 8 | from onnx import helper, ValueInfoProto, TensorProto 9 | 10 | 11 | def test_transpose_1(): 12 | with tempfile.TemporaryDirectory() as tmpdir: 13 | op_transpose( 14 | os.path.join(tmpdir, "model.onnx"), 15 | [50, 12, 64], 16 | [12, 64, 50], 17 | perm=[1, 2, 0], 18 | ) 19 | 20 | 21 | def test_transpose_2(): 22 | with tempfile.TemporaryDirectory() as tmpdir: 23 | op_transpose( 24 | os.path.join(tmpdir, "model.onnx"), 25 | [12, 64], 26 | [64, 12], 27 | perm=[1, 0], 28 | ) 29 | 30 | 31 | def test_transpose_3(): 32 | with tempfile.TemporaryDirectory() as tmpdir: 33 | op_transpose( 34 | os.path.join(tmpdir, "model.onnx"), 35 | [12, 64, 3, 5], 36 | [3, 64, 12, 5], 37 | perm=[2, 1, 0, 3], 38 | ) 39 | 40 | 41 | def test_transpose_4(): 42 | with tempfile.TemporaryDirectory() as tmpdir: 43 | op_transpose( 44 | os.path.join(tmpdir, "model.onnx"), 45 | [50, 12], 46 | [50, 12], 47 | perm=[0, 1], 48 | ) 49 | 50 | 51 | def op_transpose(filepath, shape_x, shape_y, **kwargs): 52 | inputs = [helper.make_tensor_value_info("x", TensorProto.FLOAT, shape_x)] 53 | outputs = [helper.make_tensor_value_info("y", TensorProto.FLOAT, shape_y)] 54 | nodes = [helper.make_node("Transpose", ["x"], ["y"], **kwargs)] 55 | graph = helper.make_graph(nodes, "graph", inputs, outputs) 56 | model = helper.make_model(graph) 57 | 58 | onnx.checker.check_model(model) 59 | onnx.save(model, filepath) 60 | ort_sess = ort.InferenceSession(filepath, 
providers=["CPUExecutionProvider"]) 61 | 62 | for backend in ["interpreter", "cpu"]: 63 | altius_sess = altius_py.InferenceSession(filepath, backend=backend) 64 | 65 | x = np.random.random_sample(shape_x).astype(np.float32) 66 | y = np.random.random_sample(shape_y).astype(np.float32) 67 | inputs = {"x": x} 68 | expected = ort_sess.run(None, inputs) 69 | actual = altius_sess.run(None, inputs) 70 | 71 | for expected, actual in zip(expected, actual): 72 | assert np.allclose(expected, actual) 73 | -------------------------------------------------------------------------------- /crates/altius_py/tests/test_ops_where.py: -------------------------------------------------------------------------------- 1 | import altius_py 2 | import onnxruntime as ort 3 | import onnx 4 | import tempfile 5 | import pytest 6 | import os 7 | import numpy as np 8 | from onnx import helper, ValueInfoProto, TensorProto 9 | 10 | 11 | def test_where_1(): 12 | with tempfile.TemporaryDirectory() as tmpdir: 13 | op_where( 14 | os.path.join(tmpdir, "model.onnx"), 15 | [1, 1, 10, 10], 16 | [1, 128, 10, 10], 17 | [1], 18 | ) 19 | 20 | 21 | def test_where_2(): 22 | with tempfile.TemporaryDirectory() as tmpdir: 23 | op_where( 24 | os.path.join(tmpdir, "model.onnx"), 25 | [1, 1, 1, 1], 26 | [1, 128, 1, 1], 27 | [1], 28 | ) 29 | 30 | 31 | def op_where(filepath, shape_c, shape_x, shape_y, **kwargs): 32 | inputs = [ 33 | helper.make_tensor_value_info("c", TensorProto.BOOL, shape_c), 34 | helper.make_tensor_value_info("x", TensorProto.FLOAT, shape_x), 35 | helper.make_tensor_value_info("y", TensorProto.FLOAT, shape_y), 36 | ] 37 | outputs = [helper.make_tensor_value_info("z", TensorProto.FLOAT, shape_x)] 38 | nodes = [helper.make_node("Where", ["c", "x", "y"], ["z"], **kwargs)] 39 | graph = helper.make_graph(nodes, "graph", inputs, outputs) 40 | model = helper.make_model(graph) 41 | 42 | onnx.checker.check_model(model) 43 | onnx.save(model, filepath) 44 | ort_sess = ort.InferenceSession(filepath, 
providers=["CPUExecutionProvider"]) 45 | altius_sess = altius_py.InferenceSession(filepath) 46 | 47 | c = np.random.choice(a=[False, True], size=shape_c) 48 | x = np.random.random_sample(shape_x).astype(np.float32) 49 | y = np.random.random_sample(shape_y).astype(np.float32) 50 | inputs = {"c": c, "x": x, "y": y} 51 | expected = ort_sess.run(None, inputs) 52 | actual = altius_sess.run(None, inputs) 53 | 54 | for expected, actual in zip(expected, actual): 55 | assert np.allclose(expected, actual) 56 | -------------------------------------------------------------------------------- /crates/altius_py/translation.py: -------------------------------------------------------------------------------- 1 | # python -m optimum.exporters.onnx --model "staka/fugumt-en-ja" --for-ort fugu 2 | 3 | import time 4 | import logging 5 | import os 6 | import sys 7 | 8 | from transformers import pipeline 9 | from transformers import MarianTokenizer 10 | import onnxruntime as ort 11 | import numpy as np 12 | 13 | import torch 14 | from torch.nn import functional as F 15 | 16 | 17 | def translate_baseline(text_en): 18 | fugu_translator = pipeline("translation", model="staka/fugumt-en-ja", device="cpu") 19 | result = fugu_translator(text_en)[0]["translation_text"] 20 | return result 21 | 22 | 23 | def translate_onnx(text_en): 24 | tokenizer = MarianTokenizer.from_pretrained("staka/fugumt-en-ja") 25 | 26 | use_altius = False 27 | if use_altius: 28 | import altius_py 29 | 30 | os.environ["GOMP_CPU_AFFINITY"] = "0-7" 31 | encoder = altius_py.InferenceSession( 32 | "./fugumt-en-ja/encoder_model.onnx", 33 | intra_op_num_threads=8, 34 | enable_profile=True, 35 | # backend="cpu" 36 | ) 37 | decoder = altius_py.InferenceSession( 38 | "./fugumt-en-ja/decoder_model.onnx", 39 | intra_op_num_threads=8, 40 | enable_profile=True, 41 | # backend="cpu" 42 | ) 43 | else: 44 | encoder = ort.InferenceSession( 45 | "./fugumt-en-ja/encoder_model.onnx", providers=["CPUExecutionProvider"] 46 | ) 47 | decoder = 
ort.InferenceSession( 48 | "./fugumt-en-ja/decoder_model.onnx", providers=["CPUExecutionProvider"] 49 | ) 50 | 51 | max_tokens = 100 52 | text = text_en 53 | text += "" 54 | 55 | inputs = tokenizer( 56 | text, 57 | return_tensors="np", 58 | padding=False, 59 | add_special_tokens=False, 60 | ) 61 | len_ = inputs["input_ids"].shape[1] 62 | 63 | assert len_ < max_tokens 64 | 65 | if len_ >= max_tokens: 66 | raise Exception("Too long") 67 | 68 | for name in ["input_ids", "attention_mask"]: 69 | input = np.zeros((1, max_tokens), dtype=np.int64) 70 | input[0, : inputs[name].shape[1]] = inputs[name] 71 | inputs[name] = input 72 | 73 | last_hidden_state = encoder.run(None, dict(inputs))[0] 74 | 75 | translated_text = "" 76 | for i in range(100): 77 | decoder_text = tokenizer( 78 | translated_text, 79 | return_tensors="np", 80 | padding=False, 81 | text_target="ja", 82 | add_special_tokens=False, 83 | ) 84 | len_ = decoder_text["input_ids"].shape[1] 85 | 86 | for name in ["input_ids", "attention_mask"]: 87 | input = np.zeros((1, max_tokens), dtype=np.int64) 88 | input[0, : decoder_text[name].shape[1]] = decoder_text[name] 89 | decoder_text[name] = input 90 | 91 | outputs = decoder.run( 92 | None, 93 | { 94 | "encoder_attention_mask": inputs["attention_mask"], 95 | "input_ids": decoder_text["input_ids"].reshape(1, -1), 96 | "encoder_hidden_states": last_hidden_state, 97 | }, 98 | ) 99 | 100 | if i >= len_: 101 | break 102 | 103 | next_token_logits = outputs[0][:, i, :32000] 104 | 105 | probs = F.softmax(torch.tensor(next_token_logits), dim=-1) 106 | ids = torch.argsort(-probs[0]) 107 | for i in ids: 108 | if i == 2: 109 | continue 110 | if i == tokenizer.pad_token_id: 111 | print("PAD!") 112 | continue 113 | id = i 114 | break 115 | resulting_string = tokenizer.decode( 116 | [id], 117 | skip_special_tokens=True, # clean_up_tokenization_spaces=False 118 | ) 119 | print(resulting_string) 120 | translated_text += resulting_string 121 | 122 | _, translated_text = 
translated_text.split("") 123 | 124 | return translated_text 125 | 126 | 127 | def main(): 128 | text = "Attention is all you need." 129 | 130 | baseline_result = translate_baseline(text) 131 | onnx_result = translate_onnx(text) 132 | print(f"baseline: {baseline_result}") 133 | print(f"onnx: {onnx_result}") 134 | 135 | 136 | if __name__ == "__main__": 137 | logging.basicConfig(level=logging.INFO) 138 | 139 | main() 140 | -------------------------------------------------------------------------------- /crates/altius_py/vit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import sys 4 | import logging 5 | import time 6 | 7 | import numpy as np 8 | from torchvision import transforms 9 | from PIL import Image 10 | 11 | import altius_py 12 | import onnxruntime as ort 13 | 14 | 15 | def main(): 16 | logging.basicConfig(level=logging.INFO) 17 | os.environ["OMP_PROC_BIND"] = "TRUE" 18 | os.environ["BLIS_NUM_THREADS"] = "1" # Increase this number 19 | 20 | labels = open("../../models/imagenet_classes.txt").readlines() 21 | image = Image.open("../../models/cat.png") 22 | 23 | preprocess = transforms.Compose( 24 | [ 25 | transforms.Resize(224), 26 | transforms.CenterCrop(224), 27 | transforms.ToTensor(), 28 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 29 | ] 30 | ) 31 | input = preprocess(image) 32 | input = input.unsqueeze(0).numpy() 33 | 34 | # opt = ort.SessionOptions() 35 | # # opt.intra_op_num_threads = 1 36 | # # opt.inter_op_num_threads = 1 37 | # sess = ort.InferenceSession("../../models/vit_b_16.onnx", sess_options=opt) 38 | sess = altius_py.InferenceSession("../../models/vit_b_16.onnx", True) 39 | 40 | inputs = {"x": input} 41 | start = time.time() 42 | output = sess.run(None, inputs)[0].reshape(1000) 43 | print(f"elapsed: {time.time() - start}") 44 | output = np.argsort(output)[::-1][:5] 45 | output = [labels[i].strip() for i in output] 46 | print(f"top5: 
{output}") 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /crates/core/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "altius-core" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | id-arena = "^2.2.1" 8 | rustc-hash = { workspace = true } 9 | prost = "^0.10" 10 | thiserror = { workspace = true } 11 | log = { workspace = true } 12 | rand = "^0.8.5" 13 | ndarray = { workspace = true } 14 | 15 | [build-dependencies] 16 | prost-build = "^0.10" 17 | 18 | [dev-dependencies] 19 | insta = "^1.14.1" 20 | -------------------------------------------------------------------------------- /crates/core/build.rs: -------------------------------------------------------------------------------- 1 | extern crate prost_build; 2 | 3 | fn main() { 4 | prost_build::compile_protos(&["src/onnx/onnx.proto"], &["src/"]).unwrap(); 5 | } 6 | -------------------------------------------------------------------------------- /crates/core/src/analysis/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod shape; 2 | -------------------------------------------------------------------------------- /crates/core/src/dim.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | fmt, 3 | ops::{Deref, Index, IndexMut}, 4 | slice::SliceIndex, 5 | }; 6 | 7 | use crate::fixed_dim::{FixedDimension, FixedDimensions}; 8 | 9 | #[derive(Clone, PartialEq, Eq, Hash)] 10 | pub enum Dimension { 11 | Fixed(FixedDimension), 12 | Dynamic(String), 13 | } 14 | 15 | /// An alternative to `FixedDimensions` that allows dynamic shape. 
16 | #[derive(Clone, PartialEq, Eq, Hash)] 17 | pub struct Dimensions(pub Vec); 18 | 19 | impl Dimensions { 20 | pub const fn new(dims: Vec) -> Self { 21 | Self(dims) 22 | } 23 | 24 | pub fn is_fixed(&self) -> bool { 25 | self.0.iter().all(|d| matches!(d, Dimension::Fixed(_))) 26 | } 27 | 28 | pub fn is_dynamic(&self) -> bool { 29 | self.0.iter().any(|d| matches!(d, Dimension::Dynamic(_))) 30 | } 31 | 32 | pub fn as_fixed_dims(&self) -> Option { 33 | if self.is_dynamic() { 34 | return None; 35 | } 36 | 37 | Some(FixedDimensions( 38 | self.iter() 39 | .map(|d| match d { 40 | Dimension::Fixed(d) => *d, 41 | Dimension::Dynamic(_) => unreachable!(), 42 | }) 43 | .collect(), 44 | )) 45 | } 46 | } 47 | 48 | impl AsRef for Dimensions { 49 | fn as_ref(&self) -> &Dimensions { 50 | self 51 | } 52 | } 53 | 54 | impl Index for Dimensions 55 | where 56 | I: SliceIndex<[Dimension]>, 57 | { 58 | type Output = >::Output; 59 | 60 | fn index(&self, index: I) -> &Self::Output { 61 | &self.0[index] 62 | } 63 | } 64 | 65 | impl IndexMut for Dimensions 66 | where 67 | I: SliceIndex<[Dimension]>, 68 | { 69 | fn index_mut(&mut self, index: I) -> &mut Self::Output { 70 | &mut self.0[index] 71 | } 72 | } 73 | 74 | impl From> for Dimensions { 75 | fn from(v: Vec) -> Dimensions { 76 | Dimensions(v) 77 | } 78 | } 79 | 80 | impl Deref for Dimensions { 81 | type Target = Vec; 82 | fn deref(&self) -> &Self::Target { 83 | &self.0 84 | } 85 | } 86 | 87 | impl fmt::Debug for Dimension { 88 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 89 | match self { 90 | Dimension::Fixed(d) => write!(f, "{d}"), 91 | Dimension::Dynamic(s) => write!(f, "{s}"), 92 | } 93 | } 94 | } 95 | 96 | impl fmt::Debug for Dimensions { 97 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 98 | write!(f, "{:?}", self.0) 99 | } 100 | } 101 | 102 | #[test] 103 | fn use_symdims() { 104 | let _ = Dimensions(vec![ 105 | Dimension::Dynamic("batch".into()), 106 | Dimension::Fixed(8), 107 | ]); 108 | } 
109 | -------------------------------------------------------------------------------- /crates/core/src/flops.rs: -------------------------------------------------------------------------------- 1 | use rustc_hash::FxHashMap; 2 | 3 | use crate::{ 4 | analysis::shape::{ShapeError, infer_shapes}, 5 | model::Model, 6 | op::Op, 7 | }; 8 | 9 | pub fn compute_flops(model: &Model) -> Result { 10 | let nodes = model.topo_sort_nodes(); // TODO: Dead node elimination 11 | let mut inferred_shapes = FxHashMap::default(); 12 | let mut value_shapes = FxHashMap::default(); 13 | infer_shapes(model, &mut inferred_shapes, &mut value_shapes)?; 14 | let mut flops = 0; 15 | for node_id in nodes { 16 | let node = &model.graph.nodes[node_id]; 17 | flops += match &node.op { 18 | Op::MatMul => { 19 | let a_shape = &value_shapes[&node.inputs[0]]; 20 | let b_shape = &value_shapes[&node.inputs[1]]; 21 | let m = a_shape.dims[a_shape.dims.len() - 2]; 22 | let k = a_shape.dims[a_shape.dims.len() - 1]; 23 | let n = b_shape.dims[b_shape.dims.len() - 1]; 24 | let rem = a_shape.dims[..a_shape.dims.len() - 2] 25 | .iter() 26 | .product::(); 27 | 2 * rem * m * n * k 28 | } 29 | Op::Conv2d(c) => { 30 | let input_shape = &value_shapes[&node.inputs[0]]; 31 | let kernel_shape = &value_shapes[&node.inputs[1]]; 32 | let output_shape = &value_shapes[&node.outputs[0]]; 33 | output_shape.dims.total_elems() 34 | * (input_shape.dims[1] / c.group as usize 35 | * kernel_shape.dims[2..].iter().product::()) 36 | * (1 + (node.inputs.len() == 3) as usize) 37 | } 38 | Op::Gemm(_) => { 39 | let a_shape = &value_shapes[&node.inputs[0]]; 40 | let b_shape = &value_shapes[&node.inputs[1]]; 41 | assert_eq!(a_shape.dims.len(), 2); 42 | assert_eq!(b_shape.dims.len(), 2); 43 | let m = a_shape.dims[0]; 44 | let k = a_shape.dims[1]; 45 | let n = b_shape.dims[1]; 46 | 2 * m * n * k + 3 * m * n 47 | } 48 | _ => 0, 49 | }; 50 | } 51 | Ok(flops) 52 | } 53 | 54 | #[test] 55 | fn test_compute_flops() { 56 | let model = 
Model::default(); 57 | let flops = compute_flops(&model).unwrap(); 58 | assert_eq!(flops, 0); 59 | } 60 | -------------------------------------------------------------------------------- /crates/core/src/graph.rs: -------------------------------------------------------------------------------- 1 | use rustc_hash::FxHashMap as HashMap; 2 | 3 | use crate::{ 4 | node::{Node, NodeArena, NodeId}, 5 | tensor::Tensor, 6 | value::{ValueArena, ValueId}, 7 | }; 8 | 9 | #[derive(Default, Clone)] 10 | pub struct Graph { 11 | pub nodes: NodeArena, 12 | pub values: ValueArena, 13 | pub inits: HashMap, 14 | pub inputs: Vec, 15 | pub outputs: Vec, 16 | } 17 | 18 | impl Graph { 19 | pub fn add_node(&mut self, node: Node) -> NodeId { 20 | self.nodes.alloc(node) 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /crates/core/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::excessive_precision)] 2 | 3 | pub mod analysis; 4 | pub mod dim; 5 | pub mod fixed_dim; 6 | pub mod flops; 7 | pub mod graph; 8 | pub mod model; 9 | pub mod node; 10 | pub mod onnx; 11 | pub mod op; 12 | pub mod optimize; 13 | pub mod tensor; 14 | pub mod value; 15 | -------------------------------------------------------------------------------- /crates/core/src/node.rs: -------------------------------------------------------------------------------- 1 | use crate::{op::Op, value::ValueId}; 2 | use id_arena::{Arena, Id}; 3 | 4 | pub type NodeId = Id; 5 | pub type NodeArena = Arena; 6 | 7 | #[derive(Debug, Clone)] 8 | pub struct Node { 9 | pub op: Op, 10 | pub name: Option, 11 | pub inputs: Vec, 12 | pub outputs: Vec, 13 | pub deleted: bool, 14 | } 15 | 16 | impl Node { 17 | pub fn new(op: Op) -> Self { 18 | Self { 19 | op, 20 | name: None, 21 | inputs: Vec::new(), 22 | outputs: Vec::new(), 23 | deleted: false, 24 | } 25 | } 26 | 27 | pub fn with_name(mut self, name: impl Into>) -> Self { 28 | self.name 
= name.into(); 29 | self 30 | } 31 | 32 | pub fn with_in(mut self, id: ValueId) -> Self { 33 | self.inputs.push(id); 34 | self 35 | } 36 | 37 | pub fn with_ins(mut self, mut ids: Vec) -> Self { 38 | self.inputs.append(&mut ids); 39 | self 40 | } 41 | 42 | pub fn with_out(mut self, id: ValueId) -> Self { 43 | self.outputs.push(id); 44 | self 45 | } 46 | 47 | pub fn with_outs(mut self, mut ids: Vec) -> Self { 48 | self.outputs.append(&mut ids); 49 | self 50 | } 51 | 52 | pub fn alloc(self, arena: &mut NodeArena) -> NodeId { 53 | arena.alloc(self) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /crates/core/src/onnx/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod load; 2 | pub mod save; 3 | pub use load::{load_onnx, load_onnx_from_buffer}; 4 | -------------------------------------------------------------------------------- /crates/core/src/optimize/conv_act_fusion.rs: -------------------------------------------------------------------------------- 1 | use std::time::Instant; 2 | 3 | use crate::{ 4 | model::Model, 5 | op::{FusedActivation, Op}, 6 | }; 7 | 8 | pub fn fuse_conv_act(model: &mut Model) { 9 | let start = Instant::now(); 10 | let nodes = model.topo_sort_nodes(); 11 | let value_users = model.get_value_users(); 12 | 13 | let mut list = vec![]; 14 | let mut delete_list = vec![]; 15 | 16 | for node_id in nodes { 17 | let conv_id = node_id; 18 | let conv = &model.graph.nodes[conv_id]; 19 | if !matches!(conv.op, Op::Conv2d(_)) { 20 | continue; 21 | } 22 | if value_users[&conv.outputs[0]].len() != 1 { 23 | continue; 24 | } 25 | 26 | let act_id = value_users[&conv.outputs[0]] 27 | .iter() 28 | .next() 29 | .copied() 30 | .unwrap(); 31 | let act = &model.graph.nodes[act_id]; 32 | let fused_act = match act.op { 33 | Op::ReLU => FusedActivation::Relu, 34 | Op::HardSigmoid(ref h) => FusedActivation::HardSigmoid(*h), 35 | _ => continue, 36 | }; 37 | 38 | // 
Conv+Activation Detected! 39 | 40 | list.push((fused_act, conv_id, conv.outputs[0], act.outputs[0])); 41 | delete_list.push(act_id); 42 | } 43 | 44 | let count = list.len(); 45 | 46 | for (fused_act, conv_id, conv_out, act) in list { 47 | if let Op::Conv2d(c) = &mut model.graph.nodes[conv_id].op { 48 | c.activation = Some(fused_act); 49 | } 50 | 51 | for user_id in &value_users[&act] { 52 | let user = &mut model.graph.nodes[*user_id]; 53 | let idx = user.inputs.iter().position(|&i| i == act).unwrap(); 54 | user.inputs[idx] = conv_out 55 | } 56 | } 57 | 58 | for node in delete_list { 59 | model.graph.nodes[node].deleted = true 60 | } 61 | 62 | model.remove_unnecessary_nodes(); 63 | 64 | log::info!("fuse_conv_act({count}): {:?}", start.elapsed()); 65 | } 66 | -------------------------------------------------------------------------------- /crates/core/src/optimize/elemwise_fusion.rs: -------------------------------------------------------------------------------- 1 | use std::time::Instant; 2 | 3 | use rustc_hash::{FxHashMap, FxHashSet}; 4 | 5 | use crate::{ 6 | analysis::shape::{ShapeError, infer_shapes}, 7 | model::Model, 8 | node::{Node, NodeId}, 9 | op::{FusedElemwise, Op}, 10 | }; 11 | 12 | pub fn fuse_elemwise_ops(model: &mut Model) -> Result<(), ShapeError> { 13 | let start = Instant::now(); 14 | let nodes = model.topo_sort_nodes(); 15 | let value_users = model.get_value_users(); 16 | let value_parents = model.get_value_parents(); 17 | 18 | let mut value_shapes = FxHashMap::default(); 19 | infer_shapes(model, &mut FxHashMap::default(), &mut value_shapes)?; 20 | 21 | let mut list = vec![]; 22 | let mut visited: FxHashSet = FxHashSet::default(); 23 | 24 | for node_id in nodes { 25 | if visited.contains(&node_id) { 26 | continue; 27 | } 28 | 29 | let mut fusible_nodes = vec![]; 30 | let mut last_node_id = None; 31 | let mut cur_node_id = node_id; 32 | loop { 33 | let node = &model.graph.nodes[cur_node_id]; 34 | let fusible = node.op.is_elemwise() 35 | && 
node.outputs.len() == 1 36 | && node.inputs.iter().all(|id| { 37 | // The input is either: 38 | // - an initializer 39 | // - a value from a previous node 40 | // - a value of the first node in the chain 41 | !value_parents.contains_key(id) 42 | || (last_node_id.is_some() && Some(value_parents[id]) == last_node_id) 43 | || last_node_id.is_none() 44 | }) 45 | && (last_node_id.is_none_or(|last_node_id| { 46 | let last_node = &model.graph.nodes[last_node_id]; 47 | last_node.inputs.len() == 2 48 | && last_node.outputs[0] == node.inputs[0] 49 | && (node.inputs.len() == 1 50 | || (node.inputs.len() == 2 51 | && value_shapes[&node.inputs[1]] 52 | == value_shapes[&last_node.inputs[1]])) 53 | })); 54 | let end_of_chain = fusible 55 | && value_users 56 | .get(&node.outputs[0]) 57 | .is_none_or(|users| users.len() != 1); 58 | if fusible { 59 | fusible_nodes.push(cur_node_id); 60 | last_node_id = Some(cur_node_id); 61 | if !end_of_chain { 62 | cur_node_id = *value_users[&node.outputs[0]].iter().next().unwrap(); 63 | continue; 64 | } 65 | } 66 | break; 67 | } 68 | 69 | if fusible_nodes.len() > 1 { 70 | visited.extend(fusible_nodes.iter()); 71 | list.push(fusible_nodes); 72 | } 73 | } 74 | 75 | #[cfg(debug_assertions)] 76 | for nodes in list.iter() { 77 | log::debug!( 78 | "Fusible chain: {}", 79 | nodes 80 | .iter() 81 | .map(|&id| model.graph.nodes[id] 82 | .name 83 | .as_deref() 84 | .unwrap_or(model.graph.nodes[id].op.name())) 85 | .collect::>() 86 | .join(" -> ") 87 | ); 88 | } 89 | 90 | let count = list.len(); 91 | 92 | for chain in list { 93 | let mut input_map = Vec::new(); 94 | { 95 | let mut prev_node_id = None; 96 | for &node_id in &chain { 97 | let node = &model.graph.nodes[node_id]; 98 | if let Some(prev) = prev_node_id { 99 | input_map.extend( 100 | node.inputs 101 | .iter() 102 | .filter(|i| !model.graph.nodes[prev].outputs.contains(i)), 103 | ); 104 | } else { 105 | input_map.extend(node.inputs.iter()); 106 | } 107 | prev_node_id = Some(node_id); 108 | } 109 | 
110 | // Deduplicate values 111 | let mut present = FxHashSet::default(); 112 | input_map.retain(|&id| present.insert(id)); 113 | } 114 | let last_node = &model.graph.nodes[*chain.last().unwrap()]; 115 | 116 | let fused_elemwise = Node::new(Op::FusedElemwise(FusedElemwise { 117 | input_map: input_map.clone(), 118 | chain: chain 119 | .iter() 120 | .map(|&id| { 121 | let node = &model.graph.nodes[id]; 122 | (node.op.clone(), node.inputs.clone(), node.outputs.clone()) 123 | }) 124 | .collect(), 125 | })) 126 | .with_ins(input_map) 127 | .with_out(last_node.outputs[0]); 128 | model.graph.add_node(fused_elemwise); 129 | 130 | for &node_id in &chain { 131 | model.graph.nodes[node_id].deleted = true; 132 | } 133 | } 134 | 135 | log::info!("fuse_elemwise_ops({count}): {:?}", start.elapsed()); 136 | 137 | Ok(()) 138 | } 139 | -------------------------------------------------------------------------------- /crates/core/src/optimize/gelu_fusion.rs: -------------------------------------------------------------------------------- 1 | use std::time::Instant; 2 | 3 | use rustc_hash::{FxHashMap, FxHashSet}; 4 | 5 | use crate::{ 6 | model::Model, 7 | node::{Node, NodeId}, 8 | op::Op, 9 | value::ValueId, 10 | }; 11 | 12 | fn extract_gelu( 13 | model: &Model, 14 | value_users: &FxHashMap>, 15 | root_id: NodeId, 16 | ) -> Option<(ValueId, ValueId, Vec)> { 17 | let div = &model.graph.nodes[root_id]; 18 | let approx_sqrt_two = 1.4142099618911743f32; 19 | if div.op != Op::Div 20 | || model 21 | .graph 22 | .inits 23 | .get(&div.inputs[1]) 24 | .is_none_or(|rhs| !rhs.allclose(&[approx_sqrt_two])) 25 | { 26 | return None; 27 | } 28 | 29 | let erf_id = value_users[&div.outputs[0]].iter().next().copied()?; 30 | let erf = &model.graph.nodes[erf_id]; 31 | if erf.op != Op::Erf { 32 | return None; 33 | } 34 | 35 | let add_id = value_users[&erf.outputs[0]].iter().next().copied()?; 36 | let add = &model.graph.nodes[add_id]; 37 | if add.op != Op::Add { 38 | return None; 39 | } 40 | let 
is_erf_add_lhs = add.inputs[0] == erf.outputs[0]; 41 | if model 42 | .graph 43 | .inits 44 | .get(&add.inputs[is_erf_add_lhs as usize]) 45 | .is_none_or(|one| !one.elem_ty().is_f32() || one.data::()[0] != 1.) 46 | { 47 | return None; 48 | } 49 | 50 | let mul1_id = value_users[&add.outputs[0]].iter().next().copied()?; 51 | let mul1 = &model.graph.nodes[mul1_id]; 52 | if mul1.op != Op::Mul { 53 | return None; 54 | } 55 | let is_add_mul1_lhs = mul1.inputs[0] == add.outputs[0]; 56 | if mul1.inputs[is_add_mul1_lhs as usize] != div.inputs[0] { 57 | return None; 58 | } 59 | 60 | let mul2_id = value_users[&mul1.outputs[0]].iter().next().copied()?; 61 | let mul2 = &model.graph.nodes[mul2_id]; 62 | if mul2.op != Op::Mul { 63 | return None; 64 | } 65 | let is_mul1_mul2_lhs = mul2.inputs[0] == mul1.outputs[0]; 66 | if model 67 | .graph 68 | .inits 69 | .get(&mul2.inputs[is_mul1_mul2_lhs as usize]) 70 | .is_none_or(|half| !half.elem_ty().is_f32() || half.data::()[0] != 0.5) 71 | { 72 | return None; 73 | } 74 | 75 | // Gelu Detected! 
76 | 77 | Some(( 78 | div.inputs[0], 79 | mul2.outputs[0], 80 | vec![root_id, erf_id, add_id, mul1_id, mul2_id], 81 | )) 82 | } 83 | 84 | pub fn fuse_gelu(model: &mut Model) { 85 | let start = Instant::now(); 86 | let nodes = model.topo_sort_nodes(); 87 | let value_users = model.get_value_users(); 88 | 89 | let mut subgraphs = vec![]; 90 | let mut unnecessary_nodes = vec![]; 91 | 92 | for node_id in nodes { 93 | if let Some((start, end, nodes)) = extract_gelu(model, &value_users, node_id) { 94 | subgraphs.push((start, end)); 95 | unnecessary_nodes.extend(nodes); 96 | } 97 | } 98 | 99 | let count = subgraphs.len(); 100 | 101 | for (start, end) in subgraphs { 102 | let gelu_out = model.graph.values.new_val(); 103 | let gelu = Node::new(Op::Gelu).with_in(start).with_out(gelu_out); 104 | model.graph.add_node(gelu); 105 | 106 | let Some(users) = value_users.get(&end) else { 107 | for output in &mut model.graph.outputs { 108 | if *output == end { 109 | *output = gelu_out; 110 | } 111 | } 112 | continue; 113 | }; 114 | for &user_id in users { 115 | let user = &mut model.graph.nodes[user_id]; 116 | for input in &mut user.inputs { 117 | if *input == end { 118 | *input = gelu_out; 119 | } 120 | } 121 | } 122 | } 123 | 124 | for node in unnecessary_nodes { 125 | model.graph.nodes[node].deleted = true 126 | } 127 | 128 | model.remove_unnecessary_nodes(); 129 | 130 | log::info!("fuse_gelu({count}): {:?}", start.elapsed()); 131 | } 132 | -------------------------------------------------------------------------------- /crates/core/src/optimize/identity_elim.rs: -------------------------------------------------------------------------------- 1 | use std::time::Instant; 2 | 3 | use crate::{model::Model, op::Op}; 4 | 5 | pub fn eliminate_identity(model: &mut Model) { 6 | let start = Instant::now(); 7 | 8 | let value_users = model.get_value_users(); 9 | let nodes = model.topo_sort_nodes(); 10 | for node_id in nodes { 11 | let node = &model.graph.nodes[node_id]; 12 | if node.op != 
Op::Identity { 13 | continue; 14 | } 15 | assert!(node.inputs.len() == 1); 16 | assert!(node.outputs.len() == 1); 17 | 18 | let id_in = node.inputs[0]; 19 | let id_out = node.outputs[0]; 20 | if let Some(users) = value_users.get(&id_out) { 21 | for &uid in users { 22 | let user = &mut model.graph.nodes[uid]; 23 | for input in &mut user.inputs { 24 | if *input == id_out { 25 | *input = id_in; 26 | } 27 | } 28 | } 29 | } else { 30 | continue; 31 | // NOTE: Prevent for graph input value and output value from being the same. 32 | // for output in &mut model.graph.outputs { 33 | // if *output == id_out { 34 | // *output = id_in; 35 | // } 36 | // } 37 | }; 38 | 39 | model.graph.nodes[node_id].deleted = true; 40 | } 41 | 42 | model.remove_unnecessary_nodes(); 43 | 44 | log::info!("eliminate_identity: {:?}", start.elapsed()); 45 | } 46 | -------------------------------------------------------------------------------- /crates/core/src/optimize/layer_norm_fusion.rs: -------------------------------------------------------------------------------- 1 | use std::time::Instant; 2 | 3 | use crate::{ 4 | model::Model, 5 | node::Node, 6 | op::{LayerNormalization, Op}, 7 | }; 8 | 9 | // From ONNX Runtime: 10 | // +---------------------+ 11 | // | | 12 | // | v 13 | // X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add 14 | // | ^ 15 | // | | 16 | // +-----------------------------------------------+ 17 | 18 | pub fn fuse_layer_norm(model: &mut Model) { 19 | let start = Instant::now(); 20 | let nodes = model.topo_sort_nodes(); 21 | let value_users = model.get_value_users(); 22 | 23 | let mut list = vec![]; 24 | let mut delete_list = vec![]; 25 | 26 | for node_id in nodes { 27 | let mean_id = node_id; 28 | let mean = &model.graph.nodes[mean_id]; 29 | if !matches!(mean.op, Op::ReduceMean(_)) { 30 | continue; 31 | } 32 | 33 | let Some(users) = value_users.get(&mean.outputs[0]) else { 34 | continue; 35 | }; 36 | let sub_id = 
users.iter().next().copied().unwrap(); 37 | let sub = &model.graph.nodes[sub_id]; 38 | if !matches!(sub.op, Op::Sub) { 39 | continue; 40 | } 41 | 42 | let Some(users) = value_users.get(&sub.outputs[0]) else { 43 | continue; 44 | }; 45 | let pow_id = users.iter().next().copied().unwrap(); 46 | let pow = &model.graph.nodes[pow_id]; 47 | if !matches!(pow.op, Op::Pow) { 48 | continue; 49 | } 50 | 51 | let Some(users) = value_users.get(&pow.outputs[0]) else { 52 | continue; 53 | }; 54 | let mean2_id = users.iter().next().copied().unwrap(); 55 | let mean2 = &model.graph.nodes[mean2_id]; 56 | if !matches!(mean2.op, Op::ReduceMean(_)) { 57 | continue; 58 | } 59 | 60 | let Some(users) = value_users.get(&mean2.outputs[0]) else { 61 | continue; 62 | }; 63 | let add_id = users.iter().next().copied().unwrap(); 64 | let add = &model.graph.nodes[add_id]; 65 | if !matches!(add.op, Op::Add) { 66 | continue; 67 | } 68 | 69 | let Some(users) = value_users.get(&add.outputs[0]) else { 70 | continue; 71 | }; 72 | let sqrt_id = users.iter().next().copied().unwrap(); 73 | let sqrt = &model.graph.nodes[sqrt_id]; 74 | if !matches!(sqrt.op, Op::Sqrt) { 75 | continue; 76 | } 77 | 78 | let Some(users) = value_users.get(&sqrt.outputs[0]) else { 79 | continue; 80 | }; 81 | let div_id = users.iter().next().copied().unwrap(); 82 | let div = &model.graph.nodes[div_id]; 83 | if !matches!(div.op, Op::Div) { 84 | continue; 85 | } 86 | 87 | let Some(users) = value_users.get(&div.outputs[0]) else { 88 | continue; 89 | }; 90 | let mul_id = users.iter().next().copied().unwrap(); 91 | let mul = &model.graph.nodes[mul_id]; 92 | if !matches!(mul.op, Op::Mul) { 93 | continue; 94 | } 95 | 96 | let Some(users) = value_users.get(&mul.outputs[0]) else { 97 | continue; 98 | }; 99 | let add2_id = users.iter().next().copied().unwrap(); 100 | let add2 = &model.graph.nodes[add2_id]; 101 | if !matches!(add2.op, Op::Add) { 102 | continue; 103 | } 104 | 105 | // LayerNormalization Detected! 
106 | 107 | list.push(( 108 | mean.inputs[0], 109 | add2.outputs[0], 110 | mul.inputs[1], 111 | add2.inputs[1], 112 | add.inputs[1], 113 | )); 114 | delete_list.push(mean_id); 115 | delete_list.push(sub_id); 116 | delete_list.push(pow_id); 117 | delete_list.push(mean2_id); 118 | delete_list.push(add_id); 119 | delete_list.push(sqrt_id); 120 | delete_list.push(div_id); 121 | delete_list.push(mul_id); 122 | delete_list.push(add2_id); 123 | } 124 | 125 | let count = list.len(); 126 | 127 | for (data, end, scale, bias, epsilon) in list { 128 | let epsilon = model.graph.inits.get(&epsilon).unwrap().data::()[0]; 129 | let ln_out = model.graph.values.new_val(); 130 | let ln = Node::new(Op::LayerNormalization(LayerNormalization { 131 | axis: -1, 132 | epsilon, 133 | stash_type: 1, 134 | })) 135 | .with_in(data) 136 | .with_in(scale) 137 | .with_in(bias) 138 | .with_out(ln_out); 139 | let _ln_id = model.graph.add_node(ln); 140 | 141 | let Some(users) = value_users.get(&end) else { 142 | let idx = model.graph.outputs.iter().position(|&i| i == end).unwrap(); 143 | model.graph.outputs[idx] = ln_out; 144 | continue; 145 | }; 146 | for user_id in users { 147 | let user = &mut model.graph.nodes[*user_id]; 148 | let idx = user.inputs.iter().position(|&i| i == end).unwrap(); 149 | user.inputs[idx] = ln_out; 150 | } 151 | } 152 | 153 | for node in delete_list { 154 | model.graph.nodes[node].deleted = true 155 | } 156 | 157 | model.remove_unnecessary_nodes(); 158 | 159 | log::info!("fuse_layer_norm({count}): {:?}", start.elapsed()); 160 | } 161 | -------------------------------------------------------------------------------- /crates/core/src/optimize/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod conv_act_fusion; 2 | pub mod elemwise_fusion; 3 | pub mod fast_gelu_fusion; 4 | pub mod gelu_fusion; 5 | pub mod identity_elim; 6 | pub mod layer_norm_fusion; 7 | // pub mod transpose_fusion; // TODO: Implemented but I figured out it's 
not useful for now. 8 | -------------------------------------------------------------------------------- /crates/core/src/optimize/transpose_fusion.rs: -------------------------------------------------------------------------------- 1 | use std::time::Instant; 2 | 3 | use crate::{ 4 | model::Model, 5 | node::Node, 6 | op::{Gemm, Op}, 7 | value::ValueId, 8 | }; 9 | 10 | pub fn fuse_transpose_matmul(model: &mut Model) { 11 | let start = Instant::now(); 12 | let nodes = model.topo_sort_nodes(); 13 | let value_users = model.get_value_users(); 14 | let value_producer = model.get_value_parents(); 15 | 16 | let mut list = vec![]; 17 | let mut delete_list = vec![]; 18 | 19 | enum MMInput { 20 | Transpose(ValueId), 21 | Other(ValueId), 22 | } 23 | 24 | impl MMInput { 25 | fn val(&self) -> ValueId { 26 | match self { 27 | MMInput::Transpose(v) => *v, 28 | MMInput::Other(v) => *v, 29 | } 30 | } 31 | } 32 | 33 | // TODO: WIP 34 | for node_id in nodes { 35 | let mm_id = node_id; 36 | let mm = &model.nodes[mm_id]; 37 | if !matches!(mm.op, Op::MatMul) { 38 | continue; 39 | } 40 | let mut lhs_transpose = false; 41 | let mut rhs_transpose = false; 42 | let mut lhs_input = MMInput::Other(mm.inputs[0]); 43 | let mut rhs_input = MMInput::Other(mm.inputs[1]); 44 | 45 | if let Some(&transpose_id) = value_producer.get(&mm.inputs[0]) { 46 | let transpose = &model.nodes[transpose_id]; 47 | if matches!(transpose.op, Op::Transpose(ref t) 48 | if t.perm == [0, 1, 3, 2] || t.perm == [1, 0] || t.perm == [0, 2, 1]) 49 | { 50 | lhs_transpose = true; 51 | lhs_input = MMInput::Transpose(transpose.inputs[0]); 52 | delete_list.push(transpose_id); 53 | } 54 | } 55 | if let Some(&transpose_id) = value_producer.get(&mm.inputs[1]) { 56 | let transpose = &model.nodes[transpose_id]; 57 | if matches!(transpose.op, Op::Transpose(ref t) 58 | if t.perm == [0, 1, 3, 2] || t.perm == [1, 0] || t.perm == [0, 2, 1]) 59 | { 60 | rhs_transpose = true; 61 | rhs_input = MMInput::Transpose(transpose.inputs[0]); 62 | 
delete_list.push(transpose_id); 63 | } 64 | } 65 | 66 | if lhs_transpose || rhs_transpose { 67 | list.push((lhs_input, rhs_input, mm.outputs[0])); 68 | delete_list.push(mm_id); 69 | } 70 | } 71 | 72 | let count = list.len(); 73 | 74 | for (lhs_input, rhs_input, mm_output) in list { 75 | let gemm_out = model.values.new_val(); 76 | let gemm = Node::new(Op::Gemm(Gemm { 77 | alpha: 1.0, 78 | beta: 0.0, 79 | trans_a: matches!(lhs_input, MMInput::Transpose(_)), 80 | trans_b: matches!(rhs_input, MMInput::Transpose(_)), 81 | })) 82 | .with_in(lhs_input.val()) 83 | .with_in(rhs_input.val()) 84 | .with_out(gemm_out); 85 | model.add_node(gemm); 86 | 87 | for user_id in &value_users[&mm_output] { 88 | let user = &mut model.nodes[*user_id]; 89 | for i in &mut user.inputs { 90 | if *i == mm_output { 91 | *i = gemm_out; 92 | } 93 | } 94 | } 95 | } 96 | 97 | for node in delete_list { 98 | model.nodes[node].deleted = true 99 | } 100 | 101 | model.remove_unnecessary_nodes(); 102 | 103 | log::info!("fuse_transpose_matmul({count}): {:?}", start.elapsed()); 104 | } 105 | -------------------------------------------------------------------------------- /crates/core/src/snapshots/altius_core__model__mnist_model.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: core/src/model.rs 3 | assertion_line: 200 4 | expression: order 5 | --- 6 | [ 7 | Id { 8 | idx: 9, 9 | }, 10 | Id { 11 | idx: 0, 12 | }, 13 | Id { 14 | idx: 1, 15 | }, 16 | Id { 17 | idx: 2, 18 | }, 19 | Id { 20 | idx: 3, 21 | }, 22 | Id { 23 | idx: 4, 24 | }, 25 | Id { 26 | idx: 5, 27 | }, 28 | Id { 29 | idx: 6, 30 | }, 31 | Id { 32 | idx: 7, 33 | }, 34 | Id { 35 | idx: 8, 36 | }, 37 | Id { 38 | idx: 10, 39 | }, 40 | Id { 41 | idx: 11, 42 | }, 43 | ] 44 | -------------------------------------------------------------------------------- /crates/core/src/snapshots/altius_core__tensor__dump_bool_tensor.snap: 
-------------------------------------------------------------------------------- 1 | --- 2 | source: core/src/tensor.rs 3 | assertion_line: 535 4 | expression: t 5 | --- 6 | Tensor([2, 3, 4], Bool, [false, false, false, false, false, ..., false, false, false, false, false, false]) 7 | -------------------------------------------------------------------------------- /crates/core/src/snapshots/altius_core__tensor__dump_f32_tensor.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: core/src/tensor.rs 3 | assertion_line: 532 4 | expression: t 5 | --- 6 | Tensor([2, 3, 4], F32, [0.0, 0.0, 0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) 7 | -------------------------------------------------------------------------------- /crates/core/src/snapshots/altius_core__tensor__dump_i32_tensor.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: core/src/tensor.rs 3 | assertion_line: 533 4 | expression: t 5 | --- 6 | Tensor([2, 3, 4], I32, [0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0]) 7 | -------------------------------------------------------------------------------- /crates/core/src/snapshots/altius_core__tensor__dump_i64_tensor.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: core/src/tensor.rs 3 | assertion_line: 534 4 | expression: t 5 | --- 6 | Tensor([2, 3, 4], I64, [0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0]) 7 | -------------------------------------------------------------------------------- /crates/core/src/value.rs: -------------------------------------------------------------------------------- 1 | use std::ops::{Index, IndexMut}; 2 | 3 | use id_arena::{Arena, Id}; 4 | 5 | use crate::tensor::TypedShape; 6 | 7 | pub type ValueId = Id; 8 | 9 | #[derive(Debug, Clone, PartialEq, Eq, Hash)] 10 | pub struct Value { 11 | pub name: Option, 12 | pub shape: Option, 13 | } 14 | 15 | #[derive(Debug, Default, Clone)] 16 | pub 
struct ValueArena(Arena); 17 | 18 | impl ValueArena { 19 | pub fn new_val(&mut self) -> ValueId { 20 | self.0.alloc(Value { 21 | name: None, 22 | shape: None, 23 | }) 24 | } 25 | 26 | pub fn new_val_named(&mut self, name: impl Into) -> ValueId { 27 | self.0.alloc(Value { 28 | name: Some(name.into()), 29 | shape: None, 30 | }) 31 | } 32 | 33 | pub fn new_val_named_and_shaped( 34 | &mut self, 35 | name: impl Into, 36 | shape: impl Into, 37 | ) -> ValueId { 38 | self.0.alloc(Value { 39 | name: Some(name.into()), 40 | shape: Some(shape.into()), 41 | }) 42 | } 43 | 44 | pub fn inner(&self) -> &Arena { 45 | &self.0 46 | } 47 | 48 | pub fn inner_mut(&mut self) -> &mut Arena { 49 | &mut self.0 50 | } 51 | } 52 | 53 | impl Index for ValueArena { 54 | type Output = Value; 55 | 56 | fn index(&self, index: ValueId) -> &Self::Output { 57 | &self.0[index] 58 | } 59 | } 60 | 61 | impl IndexMut for ValueArena { 62 | fn index_mut(&mut self, index: ValueId) -> &mut Self::Output { 63 | &mut self.0[index] 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /crates/session/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "altius_session" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | altius-core = { path = "../core" } 8 | log = { workspace = true } 9 | rustc-hash = { workspace = true } 10 | thiserror = { workspace = true } 11 | libloading = "^0.8.1" 12 | cranelift-module = { workspace = true } 13 | 14 | [dev-dependencies] 15 | color-backtrace = "0.5.1" 16 | env_logger = "0.9.0" 17 | image = "0.24.2" 18 | structopt = "0.3.26" 19 | criterion = "0.4.0" 20 | ort = "1.16.2" 21 | ndarray = "0.15.6" 22 | tempfile = "^3.8.1" 23 | -------------------------------------------------------------------------------- /crates/session/src/lib.rs: -------------------------------------------------------------------------------- 1 | 
#![allow(clippy::excessive_precision)] 2 | 3 | pub mod plan; 4 | 5 | use std::borrow::Cow; 6 | 7 | use altius_core::{analysis::shape::ShapeError, tensor::Tensor}; 8 | use cranelift_module::ModuleError; 9 | use thiserror::Error; 10 | 11 | #[derive(Debug, Error)] 12 | pub enum SessionError { 13 | /// Errors arised from shape inference. 14 | #[error("Shape: {0}")] 15 | Shape(#[from] ShapeError), 16 | 17 | #[error("Io: {0}")] 18 | Io(#[from] std::io::Error), 19 | 20 | #[error("Libloading: {0}")] 21 | Libloading(#[from] libloading::Error), 22 | 23 | #[error("Cranelift: {0}")] 24 | Cranelift(#[from] Box), 25 | 26 | /// General error messages (including TODOs). 27 | #[error("Something went wrong: {0}")] 28 | Message(Cow<'static, str>), 29 | } 30 | 31 | impl From for SessionError { 32 | fn from(err: ModuleError) -> Self { 33 | SessionError::Cranelift(Box::new(err)) 34 | } 35 | } 36 | 37 | pub trait Session { 38 | fn run(&self, inputs: Vec) -> Result, SessionError>; 39 | } 40 | -------------------------------------------------------------------------------- /crates/session/src/plan.rs: -------------------------------------------------------------------------------- 1 | use altius_core::{model::Model, node::NodeId, value::ValueId}; 2 | use rustc_hash::FxHashMap; 3 | 4 | /// Represents a node to execute and values to be freed after the execution of the node. 5 | #[derive(Debug)] 6 | pub struct NodeExecutionPlan { 7 | /// The node to execute. 8 | pub node_id: NodeId, 9 | 10 | /// Values to be freed after the execution of the node. 
11 | pub free_vals: Vec, 12 | } 13 | 14 | pub fn create_execution_plan(model: &Model) -> Vec { 15 | let sorted_nodes = model.topo_sort_nodes(); 16 | let node_order: FxHashMap = sorted_nodes 17 | .iter() 18 | .enumerate() 19 | .map(|(i, id)| (*id, i)) 20 | .collect(); 21 | let mut new_sorted_nodes = vec![]; 22 | let mut node_to_free_vals = FxHashMap::default(); 23 | let value_users = model.get_value_users(); 24 | 25 | for node_id in sorted_nodes { 26 | let node = &model.graph.nodes[node_id]; 27 | let mut plan = NodeExecutionPlan { 28 | node_id, 29 | free_vals: Vec::new(), 30 | }; 31 | 32 | for &output_id in &node.outputs { 33 | if !value_users.contains_key(&output_id) { 34 | continue; 35 | } 36 | 37 | let users = &value_users[&output_id]; 38 | let last_user = users 39 | .iter() 40 | .map(|id| (node_order[id], id)) 41 | .max_by(|x, y| x.0.cmp(&y.0)) 42 | .unwrap() 43 | .1; 44 | node_to_free_vals 45 | .entry(last_user) 46 | .or_insert_with(Vec::new) 47 | .push(output_id) 48 | } 49 | 50 | if let Some(mut vals) = node_to_free_vals.remove(&node_id) { 51 | plan.free_vals.append(&mut vals); 52 | } 53 | 54 | new_sorted_nodes.push(plan); 55 | } 56 | 57 | new_sorted_nodes 58 | } 59 | -------------------------------------------------------------------------------- /crates/session/tests/ort.rs: -------------------------------------------------------------------------------- 1 | use altius_core::{ 2 | graph::Graph, 3 | model::Model, 4 | node::Node, 5 | onnx::save::save_onnx, 6 | op::Op, 7 | tensor::{TensorElemType, TypedFixedShape}, 8 | }; 9 | use ndarray::CowArray; 10 | use ort::{Environment, ExecutionProvider, SessionBuilder, Value}; 11 | 12 | #[test] 13 | fn ort_add() { 14 | let path = tempfile::NamedTempFile::new().unwrap(); 15 | export_onnx(path.path().to_str().unwrap()); 16 | 17 | let env = Environment::builder() 18 | .with_execution_providers(&[ExecutionProvider::CPU(Default::default())]) 19 | .build() 20 | .unwrap() 21 | .into_arc(); 22 | let sess = 
SessionBuilder::new(&env) 23 | .unwrap() 24 | .with_model_from_file(path) 25 | .unwrap(); 26 | let x = CowArray::from(&[1.0f32, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0]) 27 | .into_shape((4, 2)) 28 | .unwrap() 29 | .into_dimensionality() 30 | .unwrap(); 31 | let y = CowArray::from(&[2.0f32, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]) 32 | .into_shape((4, 2)) 33 | .unwrap() 34 | .into_dimensionality() 35 | .unwrap(); 36 | let x = Value::from_array(sess.allocator(), &x).unwrap(); 37 | let y = Value::from_array(sess.allocator(), &y).unwrap(); 38 | let z = &sess.run(vec![x, y]).unwrap()[0]; 39 | let z = z.try_extract::().unwrap(); 40 | let z = z.view(); 41 | 42 | assert!(z.shape() == &[4, 2]); 43 | assert!(z.as_slice() == Some(&[3.0f32, 7.0, 11.0, 15.0, 19.0, 23.0, 27.0, 31.0])); 44 | } 45 | 46 | #[cfg(test)] 47 | fn export_onnx(path: &str) { 48 | // TODO: We need a better interface for building models. 49 | let mut model = Model { 50 | graph: Graph::default(), 51 | opset_version: 12, 52 | }; 53 | let x = model.graph.values.new_val_named_and_shaped( 54 | "x", 55 | TypedFixedShape::new(vec![4, 2].into(), TensorElemType::F32), 56 | ); 57 | let y = model.graph.values.new_val_named_and_shaped( 58 | "y", 59 | TypedFixedShape::new(vec![4, 2].into(), TensorElemType::F32), 60 | ); 61 | let z = model.graph.values.new_val_named_and_shaped( 62 | "z", 63 | TypedFixedShape::new(vec![4, 2].into(), TensorElemType::F32), 64 | ); 65 | model 66 | .graph 67 | .nodes 68 | .alloc(Node::new(Op::Add).with_ins(vec![x, y]).with_outs(vec![z])); 69 | model.graph.inputs.push(x); 70 | model.graph.inputs.push(y); 71 | model.graph.outputs.push(z); 72 | save_onnx(&model, path).unwrap(); 73 | } 74 | -------------------------------------------------------------------------------- /crates/session_clang/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "altius_session_clang" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | 
log = { workspace = true }
rustc-hash = { workspace = true }
altius-core = { path = "../core" }
altius_session = { path = "../session" }
cranelift = { workspace = true }
cranelift-module = { workspace = true }
cranelift-object = { workspace = true }
cranelift-codegen = { workspace = true }
libloading = "^0.8.1"
indent = "0.1.1"
sha1 = "0.10.5"
glob = "0.3.1"
num_cpus = "1.15.0"
target-lexicon = "^0.12.7"
tempfile = "^3.8.1"

[target.'cfg(target_os = "linux")'.dependencies]
blis-src = { version = "*", features = [ "openmp" ], default-features = false }

[dev-dependencies]
ndarray = "0.15.6"
color-backtrace = "0.5.1"
env_logger = "0.9.0"
image = "0.24.2"
structopt = "0.3.26"
criterion = "0.4.0"
ort = { version = "1.16.2", features = [ "profiling" ] }
rayon = "1.8.0"

--------------------------------------------------------------------------------
/crates/session_clang/examples/deit_cpu.rs:
--------------------------------------------------------------------------------
use std::{cmp::Ordering, fs::read_to_string, path::Path, time::Instant};

use ndarray::CowArray;
use ort::{Environment, GraphOptimizationLevel, SessionBuilder, Value};
use structopt::StructOpt;

#[derive(Debug, StructOpt)]
#[structopt(name = "compile")]
pub struct Opt {
    #[structopt(long = "profile", help = "Enable profiling")]
    pub profile: bool,

    #[structopt(long = "iters", help = "The number of iterations", default_value = "1")]
    pub iters: usize,

    #[structopt(
        long = "threads",
        help = "The number of computation threads",
        default_value = "1"
    )]
    pub threads: usize,

    #[structopt(long = "ort", help = "Use ONNX Runtime")]
    pub ort: bool,
}

/// Loads `cat.png` from `root` and turns it into a `(1, 3, 224, 224)` array.
/// The image is resized to 224x224 and each channel is normalized with
/// mean 0.5 and std 0.5.
fn load_input_image(root: &Path) -> ndarray::Array4<f32> {
    let rgb = image::open(root.join("cat.png")).unwrap().to_rgb8();
    let resized = image::imageops::resize(&rgb, 224, 224, image::imageops::FilterType::Triangle);
    ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| {
        let mean = [0.5, 0.5, 0.5][c];
        let std = [0.5, 0.5, 0.5][c];
        (resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std
    })
}

/// Classifies `cat.png` with DeiT, either with the clang backend (after
/// layer-norm / GELU / element-wise fusion) or, with `--ort`, with ONNX Runtime.
fn main() {
    use altius_core::optimize::elemwise_fusion::fuse_elemwise_ops;
    use altius_core::optimize::gelu_fusion::fuse_gelu;
    use altius_core::optimize::layer_norm_fusion::fuse_layer_norm;
    use altius_core::{onnx::load_onnx, tensor::Tensor};
    use altius_session_clang::ClangSessionBuilder;

    env_logger::init();
    color_backtrace::install();

    let opt = Opt::from_args();

    if opt.ort {
        return run_on_ort(&opt);
    }

    let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models");
    let mut model = load_onnx(root.join("deit.onnx"))
        .expect("Failed to load model. Have you run altius_py/deit.py?");
    fuse_layer_norm(&mut model);
    fuse_gelu(&mut model);
    fuse_elemwise_ops(&mut model).unwrap();

    let input = Tensor::new(
        vec![1, 3, 224, 224].into(),
        load_input_image(&root).into_raw_vec(),
    );

    let session = ClangSessionBuilder::new(model)
        .with_profiling_enabled(opt.profile)
        .with_intra_op_num_threads(opt.threads)
        .build()
        .unwrap();
    let classes = read_to_string(root.join("imagenet_classes.txt")).unwrap();
    let classes = classes.split('\n').collect::<Vec<_>>();
    for _ in 0..opt.iters {
        let out = session.run(vec![input.clone()]).expect("Inference failed");
        let mut out = out[0].data::<f32>().iter().enumerate().collect::<Vec<_>>();
        out.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(Ordering::Equal));

        println!("inference result: {}", classes[out[0].0]);
        println!("top5: {:?}", &out[..5]);
    }
}

/// Runs the same DeiT classification with ONNX Runtime (Level3 optimizations,
/// `opt.threads` intra-op threads) and prints the timing of each run.
fn run_on_ort(opt: &Opt) {
    let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models");
    let env = Environment::builder()
        .with_name("altius")
        .build()
        .unwrap()
        .into_arc();
    let session = SessionBuilder::new(&env)
        .unwrap()
        .with_optimization_level(GraphOptimizationLevel::Level3)
        .unwrap()
        .with_intra_threads(opt.threads as i16)
        .unwrap()
        .with_model_from_file(root.join("deit.onnx"))
        .unwrap();

    let input = CowArray::from(load_input_image(&root))
        .into_shape((1, 3, 224, 224))
        .unwrap()
        .into_dimensionality()
        .unwrap();

    let classes = read_to_string(root.join("imagenet_classes.txt")).unwrap();
    // Split once instead of on every iteration.
    let classes = classes.split('\n').collect::<Vec<_>>();
    for _ in 0..opt.iters {
        let input = Value::from_array(session.allocator(), &input).unwrap();
        let now = Instant::now();
        let output = &session.run(vec![input]).unwrap()[0];
        let elapsed = now.elapsed();
        let output = output.try_extract::<f32>().unwrap();
        let output = output.view();
        let mut output = output.iter().enumerate().collect::<Vec<_>>();
        output.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(Ordering::Equal));

        println!("inferred: {} (in {elapsed:?})", classes[output[0].0]);
    }
}

--------------------------------------------------------------------------------
/crates/session_clang/examples/mnist_cpu.rs:
--------------------------------------------------------------------------------
use structopt::StructOpt;

#[derive(Debug, StructOpt)]
#[structopt(name = "compile")]
pub struct Opt {
    #[structopt(long = "iters", help = "The number of iterations", default_value = "1")]
    pub iters: usize,
}

/// Upper bound on the number of test samples evaluated per iteration.
const MAX_VALIDATION_COUNT: usize = 10000;

/// Evaluates `mnist-8.onnx` on `MNIST_test.txt` with the clang backend and
/// reports the elapsed time, throughput and accuracy of each iteration.
///
/// Each non-empty line of the test file is `label,pixel0,...,pixel783`. The
/// pixels are scaled by 1/255 into a `[1, 1, 28, 28]` tensor.
fn main() {
    use altius_core::onnx::load_onnx;
    use altius_core::tensor::*;
    use altius_session_clang::ClangSessionBuilder;
    use std::cmp::Ordering;
    use std::fs;
    use std::path::Path;
    use std::time::Instant;

    env_logger::init();
    color_backtrace::install();

    let opt = Opt::from_args();
    let model_root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models");
    let model = load_onnx(model_root.join("mnist-8.onnx")).unwrap();

    let mut inputs = vec![];
    // `lines()` also strips a trailing '\r', so CRLF files parse correctly.
    for line in fs::read_to_string(model_root.join("MNIST_test.txt"))
        .unwrap()
        .lines()
    {
        if line.is_empty() {
            continue;
        }
        let nums: Vec<&str> = line.split(',').collect();
        let expected: i32 = nums[0].parse().unwrap();
        let pixels: Tensor = Tensor::new(
            vec![1, 1, 28, 28].into(),
            nums[1..]
                .iter()
                .map(|s| s.parse::<f32>().unwrap() / 255.0)
                .collect::<Vec<_>>(),
        );
        inputs.push((expected, pixels));
    }

    // Use the number of samples actually evaluated, so that fps and accuracy
    // stay correct when the test file has fewer samples than the cap.
    let validation_count = inputs.len().min(MAX_VALIDATION_COUNT);
    let sess = ClangSessionBuilder::new(model)
        .with_profiling_enabled(false)
        .build()
        .unwrap();

    for _ in 0..opt.iters {
        let start = Instant::now();

        let correct: i32 = inputs
            .iter()
            .take(validation_count)
            .map(|(expected, input)| {
                let v = sess.run(vec![input.clone()]).expect("Inference failed");
                let inferred = v[0]
                    .data::<f32>()
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
                    .map(|(index, _)| index)
                    .unwrap();
                (*expected == inferred as i32) as i32
            })
            .sum();

        let end = start.elapsed();
        println!("elapsed: {end:?}");
        println!("fps: {:?}", (validation_count as f64) / end.as_secs_f64());
        println!("accuracy: {}", correct as f32 / validation_count as f32);
    }
}

--------------------------------------------------------------------------------
/crates/session_clang/examples/mobilenet_cpu.rs:
--------------------------------------------------------------------------------
use ndarray::CowArray;
use ort::{Environment, ExecutionProvider, SessionBuilder, Value};
use structopt::StructOpt;

#[derive(Debug, StructOpt)]
#[structopt(name = "compile")]
struct Opt {
    #[structopt(long = "profile", help = "Enable profiling")]
    profile: bool,

    #[structopt(long = "iters", help = "The number of iterations", default_value = "1")]
    iters: usize,

    #[structopt(
        long = "threads",
        help = "The number of computation threads",
        default_value = "1"
    )]
    threads: usize,

    #[structopt(long = "ort", help = "Use ONNX Runtime")]
    ort: bool,
}

/// Classifies `cat.png` with MobileNetV3, using ONNX Runtime when `--ort` is
/// given and the clang backend otherwise. Both paths honour `--threads`.
fn main() {
    use altius_core::{onnx::load_onnx, tensor::Tensor};
    use altius_session_clang::ClangSessionBuilder;
    use std::cmp::Ordering;
    use std::fs;
    use std::path::Path;
    use std::time::Instant;

    env_logger::init();
    color_backtrace::install();

    let opt = Opt::from_args();
    let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models");

    let classes = fs::read_to_string(root.join("imagenet_classes.txt")).unwrap();
    let classes = classes.split('\n').collect::<Vec<_>>();
    let image = image::open(root.join("cat.png")).unwrap().to_rgb8();
    let resized = image::imageops::resize(&image, 224, 224, image::imageops::FilterType::Triangle);
    // NCHW input, normalized with the ImageNet per-channel mean/std.
    let image = ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| {
        let mean = [0.485, 0.456, 0.406][c];
        let std = [0.229, 0.224, 0.225][c];
        (resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std
    });
    let input = CowArray::from(image.as_slice().unwrap())
        .into_shape((1, 3, 224, 224))
        .unwrap()
        .into_dimensionality()
        .unwrap();

    if opt.ort {
        let env = Environment::builder()
            .with_execution_providers(&[ExecutionProvider::CPU(Default::default())])
            .build()
            .unwrap()
            .into_arc();
        let sess = SessionBuilder::new(&env)
            .unwrap()
            .with_optimization_level(ort::GraphOptimizationLevel::Level3)
            .unwrap()
            // Previously hard-coded to 8, which silently ignored `--threads`.
            .with_intra_threads(opt.threads as i16)
            .unwrap()
            .with_model_from_file(root.join("mobilenetv3.onnx"))
            .unwrap();
        for _ in 0..opt.iters {
            let x = Value::from_array(sess.allocator(), &input).unwrap();
            let start = Instant::now();
            let out = &sess.run(vec![x]).unwrap()[0];
            log::info!("ort: {:?}", start.elapsed());
            let out = out.try_extract::<f32>().unwrap();
            let out = out.view();
            let out = out.as_slice().unwrap();
            let mut out = out.iter().enumerate().collect::<Vec<_>>();
            out.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(Ordering::Equal));

            println!("prediction: {}", classes[out[0].0]);
            println!("top5: {:?}", &out[..5]);
        }
    } else {
        let model = load_onnx(root.join("mobilenetv3.onnx")).unwrap();
        let session = ClangSessionBuilder::new(model)
            .with_profiling_enabled(opt.profile)
            .with_intra_op_num_threads(opt.threads)
            .build()
            .unwrap();
        for _ in 0..opt.iters {
            let out = session
                .run(vec![Tensor::from(&input)])
                .expect("Inference failed");
            let mut out = out[0].data::<f32>().iter().enumerate().collect::<Vec<_>>();
            out.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(Ordering::Equal));

            println!("prediction: {}", classes[out[0].0]);
            println!("top5: {:?}", &out[..5]);
        }
    }
}

--------------------------------------------------------------------------------
/crates/session_clang/examples/vit_cpu.rs:
--------------------------------------------------------------------------------
use structopt::StructOpt;

#[derive(Debug, StructOpt)]
#[structopt(name = "compile")]
pub struct Opt {
    #[structopt(long = "profile", help = "Enable profiling")]
    pub profile: bool,

    #[structopt(long = "iters", help = "The number of iterations", default_value = "1")]
    pub iters: usize,

    #[structopt(
        long = "threads",
        help = "The number of computation threads",
        default_value = "1"
    )]
    pub threads: usize,
}

/// Classifies `cat.png` with ViT-B/16 on the clang backend, after fusing
/// layer norms, GELUs and element-wise ops.
fn main() {
    use altius_core::optimize::elemwise_fusion::fuse_elemwise_ops;
    use altius_core::optimize::gelu_fusion::fuse_gelu;
    use altius_core::optimize::layer_norm_fusion::fuse_layer_norm;
    use altius_core::{onnx::load_onnx, tensor::Tensor};
    use altius_session_clang::ClangSessionBuilder;
    use std::cmp::Ordering;
    use std::fs;
    use std::path::Path;

    env_logger::init();
    color_backtrace::install();

    let opt = Opt::from_args();
    let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models");
    let mut model = load_onnx(root.join("vit_b_16.onnx")).expect("Failed to load model");
    fuse_layer_norm(&mut model);
    fuse_gelu(&mut model);
    fuse_elemwise_ops(&mut model).unwrap();

    let image = image::open(root.join("cat.png")).unwrap().to_rgb8();
    let resized = image::imageops::resize(&image, 224, 224, image::imageops::FilterType::Triangle);
    // NCHW input, normalized with mean 0.5 / std 0.5 per channel.
    let image = ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| {
        let mean = [0.5, 0.5, 0.5][c];
        let std = [0.5, 0.5, 0.5][c];
        (resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std
    });
    let input = Tensor::new(vec![1, 3, 224, 224].into(), image.into_raw_vec());

    let session = ClangSessionBuilder::new(model)
        .with_profiling_enabled(opt.profile)
        .with_intra_op_num_threads(opt.threads)
        .build()
        .unwrap();
    let classes = fs::read_to_string(root.join("imagenet_classes.txt")).unwrap();
    // Split once instead of on every iteration.
    let classes = classes.split('\n').collect::<Vec<_>>();
    for _ in 0..opt.iters {
        let out = session.run(vec![input.clone()]).expect("Inference failed");
        let mut out = out[0].data::<f32>().iter().enumerate().collect::<Vec<_>>();
        out.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(Ordering::Equal));

        println!("inferred: {}", classes[out[0].0]);
        println!("top5: {:?}", &out[..5]);
    }
}

--------------------------------------------------------------------------------
/crates/session_clang/src/builder.rs:
--------------------------------------------------------------------------------
use altius_core::{analysis::shape::infer_shapes, model::Model};
use rustc_hash::FxHashMap;

use altius_session::SessionError;

use super::{session::ClangSession, translator::Translator};

pub struct ClangSessionBuilder {
    model: Model,
    intra_op_num_threads: usize,
    enable_profiling: bool,
}

impl ClangSessionBuilder {
    pub const fn new(model: Model) -> Self {
        Self {
            model,
            intra_op_num_threads: 1,
            enable_profiling: false,
        }
    }

    pub const fn with_intra_op_num_threads(mut self, intra_op_num_threads: usize) -> Self {
        self.intra_op_num_threads = intra_op_num_threads;
        self
    }

    pub const fn with_profiling_enabled(mut self, enable_profiling: bool) -> Self {
        self.enable_profiling = enable_profiling;
        self
    }

    pub fn build(self) -> Result<ClangSession, SessionError> {
        let mut inferred_shapes = FxHashMap::default();
        let mut value_shapes = FxHashMap::default();
        infer_shapes(&self.model, &mut inferred_shapes, &mut value_shapes)?;

        let mut profile_symbols = FxHashMap::default();
        let product = Translator::new(&self.model, &inferred_shapes, &value_shapes)?
40 | .with_profiling_enabled(self.enable_profiling) 41 | .with_intra_op_num_threads(self.intra_op_num_threads) 42 | .compile()?; 43 | 44 | #[cfg(target_os = "linux")] 45 | let lib: libloading::Library = unsafe { 46 | // Load library with `RTLD_NOW | RTLD_NODELETE` to fix a SIGSEGV 47 | libloading::os::unix::Library::open( 48 | Some(product.target_dir.join("model.so")), 49 | 0x2 | 0x1000, 50 | )? 51 | .into() 52 | }; 53 | #[cfg(not(target_os = "linux"))] 54 | let lib = unsafe { libloading::Library::new(product.target_dir.join("model.so")) }?; 55 | { 56 | let initializer: libloading::Symbol = 57 | unsafe { lib.get(b"initialize")? }; 58 | unsafe { initializer() }; 59 | } 60 | let trampoline: libloading::Symbol = 61 | unsafe { lib.get(b"trampoline")? }; 62 | let trampoline = *trampoline; 63 | 64 | for (&val_id, tensor) in &self.model.graph.inits { 65 | let name = product.value_name(val_id); 66 | let entry: libloading::Symbol<*const *const u8> = unsafe { lib.get(name.as_bytes())? }; 67 | unsafe { *entry.cast_mut() = tensor.data_as_ptr() }; 68 | } 69 | 70 | if self.enable_profiling { 71 | for name in product.used_op_names { 72 | let symbol: libloading::Symbol<*const f64> = 73 | unsafe { lib.get(format!("elapsed_{name}").as_bytes())? 
}; 74 | profile_symbols.insert(name, unsafe { *symbol.into_raw() }); 75 | } 76 | } 77 | 78 | Ok(ClangSession { 79 | target_dir: product.target_dir, 80 | model: self.model, 81 | lib, 82 | value_shapes, 83 | trampoline, 84 | enable_profiling: self.enable_profiling, 85 | profile_symbols, 86 | }) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /crates/session_clang/src/lib.rs: -------------------------------------------------------------------------------- 1 | #[cfg(target_os = "linux")] 2 | #[allow(unused)] 3 | #[allow(clippy::single_component_path_imports)] 4 | use blis_src; 5 | 6 | mod builder; 7 | mod session; 8 | mod translator; 9 | 10 | pub use builder::ClangSessionBuilder; 11 | pub use session::ClangSession; 12 | -------------------------------------------------------------------------------- /crates/session_clang/src/session.rs: -------------------------------------------------------------------------------- 1 | use altius_core::{ 2 | flops::compute_flops, 3 | model::Model, 4 | tensor::{Tensor, TypedFixedShape}, 5 | value::ValueId, 6 | }; 7 | use altius_session::SessionError; 8 | use rustc_hash::FxHashMap; 9 | 10 | use std::{path::PathBuf, time::Instant}; 11 | 12 | pub struct ClangSession { 13 | pub(super) model: Model, 14 | #[allow(dead_code)] 15 | pub(super) target_dir: PathBuf, 16 | pub(super) value_shapes: FxHashMap, 17 | #[allow(dead_code)] 18 | pub(super) lib: libloading::Library, 19 | pub(super) trampoline: extern "C" fn(*const *const u8, *const *mut u8), 20 | pub(super) enable_profiling: bool, 21 | pub(super) profile_symbols: FxHashMap, 22 | } 23 | 24 | // TODO: Is this really safe? 
25 | unsafe impl Send for ClangSession {} 26 | 27 | impl ClangSession { 28 | pub fn model(&self) -> &Model { 29 | &self.model 30 | } 31 | 32 | pub fn run(&self, inputs: Vec) -> Result, SessionError> { 33 | let mut outputs = self 34 | .model 35 | .graph 36 | .outputs 37 | .iter() 38 | .map(|id| { 39 | let shape = &self.value_shapes[id]; 40 | Tensor::uninit_of_type(shape.elem_ty, shape.dims.clone()) 41 | }) 42 | .collect::>(); 43 | 44 | let start = Instant::now(); 45 | 46 | { 47 | let mut inputs_ = Vec::with_capacity(inputs.len()); 48 | let mut outputs_ = Vec::with_capacity(outputs.len()); 49 | for tensor in inputs.iter() { 50 | inputs_.push(tensor.data_as_ptr()); 51 | } 52 | for tensor in outputs.iter_mut() { 53 | outputs_.push(tensor.data_as_mut_ptr()); 54 | } 55 | (self.trampoline)(inputs_.as_ptr(), outputs_.as_ptr()); 56 | } 57 | 58 | if self.enable_profiling { 59 | let entire_duration = start.elapsed().as_secs_f32() * 1000.0; 60 | let mut durations = self 61 | .profile_symbols 62 | .iter() 63 | .map(|(name, &duration)| { 64 | let duration = unsafe { *duration }; 65 | (name.as_str(), duration as f32 * 1000.0) 66 | }) 67 | .collect::>(); 68 | let sum_durations = durations.iter().map(|(_, d)| d).sum::(); 69 | durations.push(("All (Kernel)", sum_durations)); 70 | durations.sort_by(|(_, b), (_, a)| a.partial_cmp(b).unwrap()); 71 | let width = durations.iter().map(|(op, _)| op.len()).max().unwrap(); 72 | for (op, duration) in durations { 73 | log::info!("{op:width$}: {duration:.5} ms"); 74 | } 75 | if let Ok(flops) = compute_flops(&self.model) { 76 | log::info!( 77 | "[ {:.5} ms, {:.5} GFLOPS ]", 78 | entire_duration, 79 | flops as f32 / (entire_duration / 1000.0) / 1_000_000_000.0 80 | ); 81 | } 82 | } 83 | 84 | Ok(outputs) 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /crates/session_clang/tests/ops_bin.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 
| 3 | use altius_core::{ 4 | graph::Graph, 5 | model::Model, 6 | node::Node, 7 | onnx::{load_onnx, save::save_onnx}, 8 | op::Op, 9 | tensor::{Tensor, TensorElemType, TypedFixedShape}, 10 | }; 11 | use altius_session_clang::ClangSessionBuilder; 12 | use ndarray::CowArray; 13 | use ort::{Environment, ExecutionProvider, SessionBuilder, Value}; 14 | use rayon::iter::{IntoParallelIterator, ParallelIterator}; 15 | 16 | #[test] 17 | fn cpu_ops_bin() { 18 | env_logger::init(); 19 | 20 | [Op::Add, Op::Sub, Op::Mul, Op::Div] 21 | .into_par_iter() 22 | .for_each(|op| { 23 | let path = tempfile::NamedTempFile::new().unwrap(); 24 | let path = path.path(); 25 | Exporter::new(path, op).export(); 26 | 27 | let x_ = Tensor::rand_of_type(TensorElemType::F32, vec![4, 2].into()); 28 | let y_ = Tensor::rand_of_type(TensorElemType::F32, vec![4, 2].into()); 29 | 30 | let env = Environment::builder() 31 | .with_execution_providers(&[ExecutionProvider::CPU(Default::default())]) 32 | .build() 33 | .unwrap() 34 | .into_arc(); 35 | let sess = SessionBuilder::new(&env) 36 | .unwrap() 37 | .with_model_from_file(path) 38 | .unwrap(); 39 | let x = CowArray::from(x_.data::()) 40 | .into_shape((4, 2)) 41 | .unwrap() 42 | .into_dimensionality() 43 | .unwrap(); 44 | let y = CowArray::from(y_.data::()) 45 | .into_shape((4, 2)) 46 | .unwrap() 47 | .into_dimensionality() 48 | .unwrap(); 49 | let x = Value::from_array(sess.allocator(), &x).unwrap(); 50 | let y = Value::from_array(sess.allocator(), &y).unwrap(); 51 | let z = &sess.run(vec![x, y]).unwrap()[0]; 52 | let z = z.try_extract::().unwrap(); 53 | let ort_z = z.view(); 54 | assert!(ort_z.shape() == &[4, 2]); 55 | 56 | let sess = ClangSessionBuilder::new(load_onnx(path).unwrap()) 57 | .build() 58 | .unwrap(); 59 | let altius_z = &sess.run(vec![x_, y_]).unwrap()[0]; 60 | assert!(altius_z.dims().as_slice() == &[4, 2]); 61 | 62 | ort_z 63 | .as_slice() 64 | .unwrap() 65 | .iter() 66 | .zip(altius_z.data::()) 67 | .for_each(|(ort, altius)| { 68 | 
assert!((ort - altius).abs() < 1e-6); 69 | }); 70 | }) 71 | } 72 | 73 | #[cfg(test)] 74 | struct Exporter<'a> { 75 | path: &'a Path, 76 | op: Op, 77 | } 78 | 79 | #[cfg(test)] 80 | impl<'a> Exporter<'a> { 81 | fn new(path: &'a Path, op: Op) -> Self { 82 | Self { path, op } 83 | } 84 | 85 | fn export(self) { 86 | // TODO: We need a better interface for building models. 87 | let mut model = Model { 88 | graph: Graph::default(), 89 | opset_version: 12, 90 | }; 91 | let x = model.graph.values.new_val_named_and_shaped( 92 | "x", 93 | TypedFixedShape::new(vec![4, 2].into(), TensorElemType::F32), 94 | ); 95 | let y = model.graph.values.new_val_named_and_shaped( 96 | "y", 97 | TypedFixedShape::new(vec![4, 2].into(), TensorElemType::F32), 98 | ); 99 | let z = model.graph.values.new_val_named_and_shaped( 100 | "z", 101 | TypedFixedShape::new(vec![4, 2].into(), TensorElemType::F32), 102 | ); 103 | model 104 | .graph 105 | .nodes 106 | .alloc(Node::new(self.op).with_ins(vec![x, y]).with_outs(vec![z])); 107 | model.graph.inputs.push(x); 108 | model.graph.inputs.push(y); 109 | model.graph.outputs.push(z); 110 | save_onnx(&model, self.path).unwrap(); 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /crates/session_clang/tests/ops_conv.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | 3 | use altius_core::{ 4 | graph::Graph, 5 | model::Model, 6 | node::Node, 7 | onnx::{load_onnx, save::save_onnx}, 8 | op::{Conv2d, Op}, 9 | tensor::{Tensor, TensorElemType, TypedFixedShape}, 10 | }; 11 | use altius_session_clang::ClangSessionBuilder; 12 | use ndarray::CowArray; 13 | use ort::{Environment, ExecutionProvider, SessionBuilder, Value}; 14 | 15 | #[test] 16 | fn cpu_ops_conv() { 17 | env_logger::init(); 18 | 19 | let path = tempfile::NamedTempFile::new().unwrap(); 20 | let path = path.path(); 21 | Exporter::new(path).export(); 22 | 23 | let x_ = 
Tensor::rand_of_type(TensorElemType::F32, vec![1, 1, 28, 28].into()); 24 | 25 | let env = Environment::builder() 26 | .with_execution_providers(&[ExecutionProvider::CPU(Default::default())]) 27 | .build() 28 | .unwrap() 29 | .into_arc(); 30 | let sess = SessionBuilder::new(&env) 31 | .unwrap() 32 | .with_model_from_file(path) 33 | .unwrap(); 34 | let x = CowArray::from(x_.data::()) 35 | .into_shape((1, 1, 28, 28)) 36 | .unwrap() 37 | .into_dimensionality() 38 | .unwrap(); 39 | let x = Value::from_array(sess.allocator(), &x).unwrap(); 40 | let z = &sess.run(vec![x]).unwrap()[0]; 41 | let z = z.try_extract::().unwrap(); 42 | let ort_z = z.view(); 43 | assert!(ort_z.shape() == &[1, 8, 28, 28]); 44 | 45 | let sess = ClangSessionBuilder::new(load_onnx(path).unwrap()) 46 | .build() 47 | .unwrap(); 48 | let altius_z = &sess.run(vec![x_]).unwrap()[0]; 49 | assert!(altius_z.dims().as_slice() == &[1, 8, 28, 28]); 50 | 51 | ort_z 52 | .as_slice() 53 | .unwrap() 54 | .iter() 55 | .zip(altius_z.data::()) 56 | .for_each(|(ort, altius)| { 57 | assert!((ort - altius).abs() < 1e-5, "{} != {}", ort, altius); 58 | }); 59 | } 60 | 61 | #[cfg(test)] 62 | struct Exporter<'a> { 63 | path: &'a Path, 64 | } 65 | 66 | #[cfg(test)] 67 | impl<'a> Exporter<'a> { 68 | fn new(path: &'a Path) -> Self { 69 | Self { path } 70 | } 71 | 72 | fn export(self) { 73 | // TODO: We need a better interface for building models. 
74 | let mut model = Model { 75 | graph: Graph::default(), 76 | opset_version: 12, 77 | }; 78 | let x = model.graph.values.new_val_named_and_shaped( 79 | "x", 80 | TypedFixedShape::new(vec![1, 1, 28, 28].into(), TensorElemType::F32), 81 | ); 82 | let y = model.graph.values.new_val_named_and_shaped( 83 | "y", 84 | TypedFixedShape::new(vec![8, 1, 5, 5].into(), TensorElemType::F32), 85 | ); 86 | let z = model.graph.values.new_val_named_and_shaped( 87 | "z", 88 | TypedFixedShape::new(vec![1, 8, 28, 28].into(), TensorElemType::F32), 89 | ); 90 | model.graph.nodes.alloc( 91 | Node::new(Op::Conv2d(Conv2d { 92 | auto_pad: "SAME_UPPER".into(), 93 | kernel_shape: vec![5, 5].into(), 94 | strides: vec![1, 1].into(), 95 | group: 1, 96 | dilations: vec![1, 1].into(), 97 | ..Default::default() 98 | })) 99 | .with_ins(vec![x, y]) 100 | .with_outs(vec![z]), 101 | ); 102 | model.graph.inputs.push(x); 103 | model.graph.outputs.push(z); 104 | model.graph.inits.insert( 105 | y, 106 | Tensor::rand_of_type(TensorElemType::F32, vec![8, 1, 5, 5].into()), 107 | ); 108 | save_onnx(&model, self.path).unwrap(); 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /crates/session_interpreter/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "altius_session_interpreter" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | log = { workspace = true } 8 | rustc-hash = { workspace = true } 9 | altius-core = { path = "../core" } 10 | altius_session = { path = "../session" } 11 | ndarray = "0.15.6" 12 | core_affinity = "^0.7.6" 13 | matrixmultiply = "0.3.2" 14 | fastapprox = "^0.3.0" 15 | thread_local = "^1.1" 16 | paste = "1.0.11" 17 | threadpool = "^1.8.1" 18 | cblas-sys = "0.1.4" 19 | cudnn = { git = "https://github.com/Rust-GPU/Rust-CUDA", optional = true } 20 | cust = { git = "https://github.com/Rust-GPU/Rust-CUDA", optional = true } 21 | 22 | 
[target.'cfg(not(target_arch = "wasm32"))'.dependencies] 23 | mimalloc = { version = "0.1.46", default-features = false, features = ["local_dynamic_tls"] } 24 | 25 | [target.'cfg(target_os = "linux")'.dependencies] 26 | blis-src = { version = "*", features = [ "openmp" ], default-features = false } 27 | procfs = "0.14.2" 28 | 29 | [target.'cfg(target_os = "macos")'.dependencies] 30 | blas-src = { version = "0.8", features = ["accelerate"] } 31 | 32 | [features] 33 | default = ["cblas"] 34 | matrixmultiply-threading = ["matrixmultiply/threading"] 35 | cblas = [] 36 | heavy-log = [] 37 | cuda = ["cudnn", "cust"] 38 | 39 | [dev-dependencies] 40 | color-backtrace = "0.5.1" 41 | env_logger = "0.9.0" 42 | image = "0.24.2" 43 | structopt = "0.3.26" 44 | criterion = "0.4.0" 45 | 46 | [[bench]] 47 | name = "interpreter" 48 | harness = false 49 | -------------------------------------------------------------------------------- /crates/session_interpreter/benches/interpreter.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | 3 | use altius_core::{ 4 | onnx::load_onnx, 5 | optimize::{gelu_fusion::fuse_gelu, layer_norm_fusion::fuse_layer_norm}, 6 | tensor::Tensor, 7 | }; 8 | use altius_session_interpreter::InterpreterSessionBuilder; 9 | use criterion::{criterion_group, criterion_main, Criterion}; 10 | 11 | const THREADS: usize = 1; 12 | 13 | fn without_gelu(c: &mut Criterion) { 14 | let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models"); 15 | let model = load_onnx(root.join("deit.onnx")) 16 | .expect("Failed to load model. 
Have you run altius_py/deit.py?"); 17 | 18 | let input = Tensor::rand::(vec![1, 3, 224, 224].into()); 19 | 20 | let sess = InterpreterSessionBuilder::new(model) 21 | .with_intra_op_num_threads(THREADS) 22 | .build() 23 | .unwrap(); 24 | c.bench_function("No fusion", |b| b.iter(|| sess.run(vec![input.clone()]))); 25 | } 26 | 27 | fn with_gelu(c: &mut Criterion) { 28 | let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models"); 29 | let mut model = load_onnx(root.join("deit.onnx")) 30 | .expect("Failed to load model. Have you run altius_py/deit.py?"); 31 | fuse_gelu(&mut model); 32 | 33 | let input = Tensor::rand::(vec![1, 3, 224, 224].into()); 34 | 35 | let sess = InterpreterSessionBuilder::new(model) 36 | .with_intra_op_num_threads(THREADS) 37 | .build() 38 | .unwrap(); 39 | c.bench_function("Gelu fusion", |b| b.iter(|| sess.run(vec![input.clone()]))); 40 | } 41 | 42 | fn with_gelu_ln(c: &mut Criterion) { 43 | let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models"); 44 | let mut model = load_onnx(root.join("deit.onnx")) 45 | .expect("Failed to load model. Have you run altius_py/deit.py?"); 46 | fuse_layer_norm(&mut model); 47 | fuse_gelu(&mut model); 48 | 49 | let input = Tensor::rand::(vec![1, 3, 224, 224].into()); 50 | 51 | let sess = InterpreterSessionBuilder::new(model) 52 | .with_intra_op_num_threads(THREADS) 53 | .build() 54 | .unwrap(); 55 | c.bench_function("LN fusion, Gelu fusion", |b| { 56 | b.iter(|| sess.run(vec![input.clone()])) 57 | }); 58 | } 59 | 60 | fn with_gelu_ln2(c: &mut Criterion) { 61 | let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models"); 62 | let mut model = load_onnx(root.join("deit.onnx")) 63 | .expect("Failed to load model. 
Have you run altius_py/deit.py?"); 64 | fuse_gelu(&mut model); 65 | fuse_layer_norm(&mut model); 66 | 67 | let input = Tensor::rand::(vec![1, 3, 224, 224].into()); 68 | 69 | let sess = InterpreterSessionBuilder::new(model) 70 | .with_intra_op_num_threads(THREADS) 71 | .build() 72 | .unwrap(); 73 | c.bench_function("Gelu fusion, LN fusion", |b| { 74 | b.iter(|| sess.run(vec![input.clone()])) 75 | }); 76 | } 77 | 78 | criterion_group! { 79 | fusion, 80 | with_gelu, 81 | with_gelu_ln, 82 | with_gelu_ln2, 83 | without_gelu, 84 | } 85 | 86 | criterion_main!(fusion); 87 | -------------------------------------------------------------------------------- /crates/session_interpreter/examples/deit.rs: -------------------------------------------------------------------------------- 1 | use altius_core::optimize::gelu_fusion::fuse_gelu; 2 | use altius_core::optimize::layer_norm_fusion::fuse_layer_norm; 3 | use altius_core::{onnx::load_onnx, tensor::Tensor}; 4 | use altius_session_interpreter::InterpreterSessionBuilder; 5 | use std::cmp::Ordering; 6 | use std::fs; 7 | use std::path::Path; 8 | use structopt::StructOpt; 9 | 10 | #[derive(Debug, StructOpt)] 11 | #[structopt(name = "compile")] 12 | pub struct Opt { 13 | #[structopt(long = "profile", help = "Enable profiling")] 14 | pub profile: bool, 15 | 16 | #[structopt(long = "iters", help = "The number of iterations", default_value = "1")] 17 | pub iters: usize, 18 | } 19 | 20 | fn main() { 21 | env_logger::init(); 22 | color_backtrace::install(); 23 | 24 | let opt = Opt::from_args(); 25 | let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models"); 26 | let mut model = load_onnx(root.join("deit.onnx")) 27 | .expect("Failed to load model. 
Have you run altius_py/deit.py?"); 28 | fuse_gelu(&mut model); 29 | fuse_layer_norm(&mut model); 30 | 31 | let image = image::open(root.join("cat.png")).unwrap().to_rgb8(); 32 | let resized = image::imageops::resize(&image, 224, 224, image::imageops::FilterType::Triangle); 33 | let image = ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| { 34 | let mean = [0.5, 0.5, 0.5][c]; 35 | let std = [0.5, 0.5, 0.5][c]; 36 | (resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std 37 | }); 38 | let input = Tensor::new(vec![1, 3, 224, 224].into(), image.into_raw_vec()); 39 | 40 | let i = InterpreterSessionBuilder::new(model) 41 | .with_profiling_enabled(opt.profile) 42 | .build() 43 | .unwrap(); 44 | for _ in 0..opt.iters { 45 | let out = i.run(vec![input.clone()]).expect("Inference failed"); 46 | let mut out = out[0].data::().iter().enumerate().collect::>(); 47 | out.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(Ordering::Equal)); 48 | 49 | let classes = fs::read_to_string(Path::new(&root).join("imagenet_classes.txt")).unwrap(); 50 | let classes = classes.split('\n').collect::>(); 51 | println!("inferred: {}", classes[out[0].0]); 52 | println!("top5: {:?}", &out[..5]); 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /crates/session_interpreter/examples/infer.rs: -------------------------------------------------------------------------------- 1 | use altius_core::onnx::load_onnx; 2 | use altius_core::optimize::gelu_fusion::fuse_gelu; 3 | use altius_core::tensor::{Tensor, TensorElemType}; 4 | use altius_session_interpreter::InterpreterSessionBuilder; 5 | use std::path::PathBuf; 6 | use std::process::exit; 7 | use std::time::Instant; 8 | use structopt::StructOpt; 9 | 10 | #[derive(Debug, StructOpt)] 11 | #[structopt(name = "compile")] 12 | pub struct Opt { 13 | #[structopt(parse(from_os_str))] 14 | pub onnx_path: PathBuf, 15 | 16 | #[structopt(long = "profile", help = "Enable profiling")] 17 | pub 
profile: bool, 18 | 19 | #[structopt( 20 | long = "threads", 21 | help = "The number of threads for computation", 22 | default_value = "1" 23 | )] 24 | pub threads: usize, 25 | } 26 | 27 | fn main() { 28 | env_logger::init(); 29 | 30 | let opt = Opt::from_args(); 31 | 32 | log::info!("load onnx: start ({:?})", opt.onnx_path); 33 | let start = Instant::now(); 34 | let mut model = load_onnx(opt.onnx_path).unwrap(); 35 | log::info!("load onnx: finished in {:?}", start.elapsed()); 36 | 37 | fuse_gelu(&mut model); 38 | 39 | log::info!( 40 | "create session: start (profile={:?}, threads={})", 41 | opt.profile, 42 | opt.threads 43 | ); 44 | let start = Instant::now(); 45 | let sess = InterpreterSessionBuilder::new(model) 46 | .with_profiling_enabled(opt.profile) 47 | .with_intra_op_num_threads(opt.threads) 48 | .build() 49 | .unwrap(); 50 | log::info!("create session: finished in {:?}", start.elapsed()); 51 | 52 | let mut inputs = vec![]; 53 | for (i, &input_id) in sess.model().graph.inputs.iter().enumerate() { 54 | if sess.model().graph.inits.contains_key(&input_id) { 55 | continue; 56 | } 57 | 58 | let input = &sess.model().graph.values.inner()[input_id]; 59 | let name = input.name.as_deref().unwrap_or(""); 60 | let Some(shape) = input.shape.as_ref() else { 61 | log::error!( 62 | "failed to feed input({i}, name={name}): unknown shape (or dynamic shape?)" 63 | ); 64 | exit(1); 65 | }; 66 | 67 | log::info!( 68 | "feed input({i}, name={}, ty={:?}, shape={:?}): random input", 69 | name, 70 | shape.elem_ty, 71 | shape.dims 72 | ); 73 | 74 | inputs.push(Tensor::rand_of_type( 75 | shape.elem_ty, 76 | shape.dims.as_fixed_dims().unwrap(), 77 | )); 78 | } 79 | 80 | let outputs = match sess.run(inputs) { 81 | Ok(outputs) => outputs, 82 | Err(e) => { 83 | log::error!("inference failed: {:?}", e); 84 | exit(1); 85 | } 86 | }; 87 | 88 | for (i, (output, output_id)) in outputs 89 | .iter() 90 | .zip(sess.model().graph.outputs.iter()) 91 | .enumerate() 92 | { 93 | let name = 
sess.model().graph.values.inner()[*output_id] 94 | .name 95 | .as_deref() 96 | .unwrap_or(""); 97 | // TODO: Dirty. 98 | let stat = match output.elem_ty() { 99 | TensorElemType::F32 => { 100 | format!("{:?}", output.statistics::()) 101 | } 102 | TensorElemType::I32 => { 103 | format!("{:?}", output.statistics::()) 104 | } 105 | TensorElemType::I64 => "no stats".to_string(), 106 | TensorElemType::Bool => "no stats".to_string(), 107 | }; 108 | log::info!( 109 | "output({i}, name={}, ty={:?}, shape={:?}): {}", 110 | name, 111 | output.elem_ty(), 112 | output.dims(), 113 | stat 114 | ); 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /crates/session_interpreter/examples/mnist.rs: -------------------------------------------------------------------------------- 1 | use altius_core::onnx::load_onnx; 2 | use altius_core::tensor::*; 3 | use altius_session_interpreter::InterpreterSessionBuilder; 4 | use std::cmp::Ordering; 5 | use std::fs; 6 | use std::path::Path; 7 | use std::time::Instant; 8 | 9 | fn main() { 10 | env_logger::init(); 11 | color_backtrace::install(); 12 | 13 | let mnist_root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models"); 14 | let mnist = load_onnx(mnist_root.join("mnist-8.onnx")).unwrap(); 15 | 16 | let mut inputs = vec![]; 17 | for line in fs::read_to_string(Path::new(&mnist_root).join("MNIST_test.txt")) 18 | .unwrap() 19 | .split('\n') 20 | { 21 | if line.is_empty() { 22 | continue; 23 | } 24 | let nums: Vec<&str> = line.split(',').collect(); 25 | let expected: i32 = nums[0].parse().unwrap(); 26 | let pixels: Tensor = Tensor::new( 27 | vec![1, 1, 28, 28].into(), 28 | nums[1..] 
29 | .iter() 30 | .map(|s| s.parse::().unwrap() / 255.0) 31 | .collect::>(), 32 | ); 33 | inputs.push((expected, pixels)); 34 | } 35 | 36 | let start = Instant::now(); 37 | 38 | let validation_count = 10000; 39 | let sess = InterpreterSessionBuilder::new(mnist).build().unwrap(); 40 | 41 | let correct: i32 = inputs 42 | .iter() 43 | .take(validation_count) 44 | .map(|(expected, input)| { 45 | let v = sess.run(vec![input.clone()]).expect("Inference failed"); 46 | let inferred = v[0] 47 | .data::() 48 | .iter() 49 | .enumerate() 50 | .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) 51 | .map(|(index, _)| index) 52 | .unwrap(); 53 | (*expected == inferred as i32) as i32 54 | }) 55 | .sum(); 56 | 57 | let end = start.elapsed(); 58 | println!("elapsed: {end:?}"); 59 | println!("fps: {:?}", (validation_count as f64) / end.as_secs_f64()); 60 | 61 | // for (_expected, input) in &inputs { 62 | // for x in 0..28 { 63 | // for y in 0..28 { 64 | // let pixel = input.at(&[0, 0, x, y]); 65 | // print!("{}", if pixel > 0.5 { '#' } else { ' ' }); 66 | // } 67 | // println!(); 68 | // } 69 | // // break; 70 | // } 71 | 72 | println!("accuracy: {}", correct as f32 / validation_count as f32); 73 | } 74 | -------------------------------------------------------------------------------- /crates/session_interpreter/examples/mobilenet.rs: -------------------------------------------------------------------------------- 1 | use altius_core::{onnx::load_onnx, tensor::Tensor}; 2 | use altius_session_interpreter::InterpreterSessionBuilder; 3 | use std::cmp::Ordering; 4 | use std::fs; 5 | use std::path::Path; 6 | use structopt::StructOpt; 7 | 8 | #[derive(Debug, StructOpt)] 9 | #[structopt(name = "compile")] 10 | pub struct Opt { 11 | #[structopt(long = "profile", help = "Enable profiling")] 12 | pub profile: bool, 13 | 14 | #[structopt(long = "iters", help = "The number of iterations", default_value = "1")] 15 | pub iters: usize, 16 | } 17 | 18 | fn main() { 19 | 
env_logger::init(); 20 | color_backtrace::install(); 21 | 22 | let opt = Opt::from_args(); 23 | let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models"); 24 | let model = load_onnx(root.join("mobilenetv3.onnx")).unwrap(); 25 | 26 | let image = image::open(root.join("cat.png")).unwrap().to_rgb8(); 27 | let resized = image::imageops::resize(&image, 224, 224, image::imageops::FilterType::Triangle); 28 | let image = ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| { 29 | let mean = [0.485, 0.456, 0.406][c]; 30 | let std = [0.229, 0.224, 0.225][c]; 31 | (resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std 32 | }); 33 | let input = Tensor::new(vec![1, 3, 224, 224].into(), image.into_raw_vec()); 34 | 35 | let session = InterpreterSessionBuilder::new(model) 36 | .with_profiling_enabled(opt.profile) 37 | .build() 38 | .unwrap(); 39 | let classes = fs::read_to_string(Path::new(&root).join("imagenet_classes.txt")).unwrap(); 40 | for _ in 0..opt.iters { 41 | let out = session.run(vec![input.clone()]).expect("Inference failed"); 42 | let mut out = out[0].data::().iter().enumerate().collect::>(); 43 | out.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(Ordering::Equal)); 44 | 45 | let classes = classes.split('\n').collect::>(); 46 | println!("prediction: {}", classes[out[0].0]); 47 | println!("top5: {:?}", &out[..5]); 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /crates/session_interpreter/examples/vit.rs: -------------------------------------------------------------------------------- 1 | use altius_core::optimize::gelu_fusion::fuse_gelu; 2 | use altius_core::{onnx::load_onnx, tensor::Tensor}; 3 | use altius_session_interpreter::InterpreterSessionBuilder; 4 | use std::cmp::Ordering; 5 | use std::fs; 6 | use std::path::Path; 7 | use structopt::StructOpt; 8 | 9 | #[derive(Debug, StructOpt)] 10 | #[structopt(name = "compile")] 11 | pub struct Opt { 12 | #[structopt(long = "profile", help = 
"Enable profiling")] 13 | pub profile: bool, 14 | 15 | #[structopt(long = "iters", help = "The number of iterations", default_value = "1")] 16 | pub iters: usize, 17 | } 18 | 19 | fn main() { 20 | env_logger::init(); 21 | color_backtrace::install(); 22 | 23 | let opt = Opt::from_args(); 24 | let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models"); 25 | let mut model = load_onnx(root.join("vit_b_16.onnx")).unwrap(); 26 | fuse_gelu(&mut model); 27 | 28 | let image = image::open(root.join("cat.png")).unwrap().to_rgb8(); 29 | let resized = image::imageops::resize(&image, 224, 224, image::imageops::FilterType::Triangle); 30 | let image = ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| { 31 | let mean = [0.485, 0.456, 0.406][c]; 32 | let std = [0.229, 0.224, 0.225][c]; 33 | (resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std 34 | }); 35 | let input = Tensor::new(vec![1, 3, 224, 224].into(), image.into_raw_vec()); 36 | 37 | let i = InterpreterSessionBuilder::new(model) 38 | .with_profiling_enabled(opt.profile) 39 | .with_intra_op_num_threads(16) 40 | .build() 41 | .unwrap(); 42 | for _ in 0..opt.iters { 43 | let out = i.run(vec![input.clone()]).expect("Inference failed"); 44 | let mut out = out[0].data::().iter().enumerate().collect::>(); 45 | out.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(Ordering::Equal)); 46 | 47 | let classes = fs::read_to_string(Path::new(&root).join("imagenet_classes.txt")).unwrap(); 48 | let classes = classes.split('\n').collect::>(); 49 | println!("inferred: {}", classes[out[0].0]); 50 | println!("top5: {:?}", &out[..5]); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /crates/session_interpreter/src/builder.rs: -------------------------------------------------------------------------------- 1 | use altius_core::{analysis::shape::infer_shapes, model::Model, tensor::Tensor}; 2 | use altius_session::{plan::create_execution_plan, SessionError}; 3 | use 
rustc_hash::FxHashMap; 4 | use thread_local::ThreadLocal; 5 | 6 | #[cfg(feature = "cuda")] 7 | use super::session::SafeCudnnContext; 8 | use super::{session::InterpreterSession, thread::ThreadCtx}; 9 | #[cfg(feature = "cuda")] 10 | use cudnn::CudnnContext; 11 | 12 | pub struct InterpreterSessionBuilder { 13 | model: Model, 14 | intra_op_num_threads: usize, 15 | enable_profiling: bool, 16 | } 17 | 18 | impl InterpreterSessionBuilder { 19 | pub const fn new(model: Model) -> Self { 20 | Self { 21 | model, 22 | intra_op_num_threads: 1, 23 | enable_profiling: false, 24 | } 25 | } 26 | 27 | pub const fn with_intra_op_num_threads(mut self, intra_op_num_threads: usize) -> Self { 28 | self.intra_op_num_threads = intra_op_num_threads; 29 | self 30 | } 31 | 32 | pub const fn with_profiling_enabled(mut self, enable_profiling: bool) -> Self { 33 | self.enable_profiling = enable_profiling; 34 | self 35 | } 36 | 37 | pub fn build(self) -> Result { 38 | let model = self.model; 39 | let enable_profiling = self.enable_profiling; 40 | let intra_op_num_threads = self.intra_op_num_threads; 41 | 42 | let mut inferred_shapes = FxHashMap::default(); 43 | infer_shapes(&model, &mut inferred_shapes, &mut FxHashMap::default())?; 44 | 45 | #[cfg(target_os = "linux")] 46 | { 47 | // Suppose that blis is used for BLAS. 
48 | extern "C" { 49 | fn bli_thread_set_num_threads(n_threads: usize); 50 | } 51 | unsafe { bli_thread_set_num_threads(intra_op_num_threads) }; 52 | } 53 | 54 | Ok(InterpreterSession { 55 | #[cfg(feature = "cuda")] 56 | cudnn_ctx: SafeCudnnContext(CudnnContext::new().expect("cudnn context init failed")), 57 | execution_plans: create_execution_plan(&model), 58 | model, 59 | inferred_shapes, 60 | enable_profiling, 61 | values: ThreadLocal::new(), 62 | dummy_value: Tensor::zeros::(vec![0].into()), 63 | tctx: ThreadCtx::new_with_num_threads(intra_op_num_threads), 64 | }) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /crates/session_interpreter/src/gemm.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::too_many_arguments)] 2 | 3 | pub fn sgemm( 4 | m: usize, 5 | k: usize, 6 | n: usize, 7 | alpha: f32, 8 | a: &[f32], 9 | lda: usize, 10 | b: &[f32], 11 | ldb: usize, 12 | beta: f32, 13 | c: &mut [f32], 14 | ldc: usize, 15 | ) { 16 | unsafe { 17 | #[cfg(not(feature = "cblas"))] 18 | matrixmultiply::sgemm( 19 | m, 20 | k, 21 | n, 22 | alpha, 23 | a.as_ptr(), 24 | lda as isize, 25 | 1, 26 | b.as_ptr(), 27 | ldb as isize, 28 | 1, 29 | beta, 30 | c.as_mut_ptr(), 31 | ldc as isize, 32 | 1, 33 | ); 34 | #[cfg(feature = "cblas")] 35 | { 36 | cblas_sys::cblas_sgemm( 37 | cblas_sys::CblasRowMajor, 38 | cblas_sys::CblasNoTrans, 39 | cblas_sys::CblasNoTrans, 40 | m as i32, 41 | n as i32, 42 | k as i32, 43 | alpha, 44 | a.as_ptr(), 45 | lda as i32, 46 | b.as_ptr(), 47 | ldb as i32, 48 | beta, 49 | c.as_mut_ptr(), 50 | ldc as i32, 51 | ); 52 | } 53 | } 54 | } 55 | 56 | #[allow(dead_code)] 57 | #[allow(unused_variables)] 58 | pub fn sgemm2( 59 | trans_a: bool, 60 | trans_b: bool, 61 | m: usize, 62 | k: usize, 63 | n: usize, 64 | alpha: f32, 65 | a: &[f32], 66 | lda: usize, 67 | b: &[f32], 68 | ldb: usize, 69 | beta: f32, 70 | c: &mut [f32], 71 | ldc: usize, 72 | ) { 73 | 
#[cfg(not(feature = "cblas"))] 74 | panic!(); 75 | // matrixmultiply::sgemm( 76 | // m, 77 | // k, 78 | // n, 79 | // alpha, 80 | // a.as_ptr(), 81 | // lda as isize, 82 | // 1, 83 | // b.as_ptr(), 84 | // ldb as isize, 85 | // 1, 86 | // beta, 87 | // c.as_mut_ptr(), 88 | // ldc as isize, 89 | // 1, 90 | // ); 91 | 92 | #[cfg(feature = "cblas")] 93 | { 94 | unsafe { 95 | cblas_sys::cblas_sgemm( 96 | cblas_sys::CblasRowMajor, 97 | if trans_a { 98 | cblas_sys::CblasTrans 99 | } else { 100 | cblas_sys::CblasNoTrans 101 | }, 102 | if trans_b { 103 | cblas_sys::CblasTrans 104 | } else { 105 | cblas_sys::CblasNoTrans 106 | }, 107 | m as i32, 108 | n as i32, 109 | k as i32, 110 | alpha, 111 | a.as_ptr(), 112 | lda as i32, 113 | b.as_ptr(), 114 | ldb as i32, 115 | beta, 116 | c.as_mut_ptr(), 117 | ldc as i32, 118 | ); 119 | } 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /crates/session_interpreter/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![feature(portable_simd)] 2 | #![allow(clippy::excessive_precision)] 3 | 4 | #[cfg(all(feature = "cblas", target_os = "macos"))] 5 | #[allow(unused)] 6 | #[allow(clippy::single_component_path_imports)] 7 | use blas_src; // For accelerate, this is necessary to link the library. 
8 | #[cfg(all(feature = "cblas", target_os = "linux"))] 9 | #[allow(unused)] 10 | #[allow(clippy::single_component_path_imports)] 11 | use blis_src; 12 | 13 | mod builder; 14 | mod conv2d; 15 | mod fast_math; 16 | mod gemm; 17 | mod session; 18 | mod thread; 19 | 20 | pub use builder::InterpreterSessionBuilder; 21 | pub use session::InterpreterSession; 22 | 23 | #[cfg(not(target_arch = "wasm32"))] 24 | use mimalloc::MiMalloc; 25 | 26 | #[cfg(not(target_arch = "wasm32"))] 27 | #[global_allocator] 28 | static GLOBAL: MiMalloc = MiMalloc; 29 | -------------------------------------------------------------------------------- /crates/session_interpreter/src/thread.rs: -------------------------------------------------------------------------------- 1 | #[cfg(target_arch = "wasm32")] 2 | use std::marker::PhantomData; 3 | 4 | #[cfg(not(target_arch = "wasm32"))] 5 | use threadpool::ThreadPool; 6 | 7 | pub struct ThreadCtx { 8 | #[cfg(not(target_arch = "wasm32"))] 9 | pub tp: ThreadPool, 10 | 11 | #[cfg(not(target_arch = "wasm32"))] 12 | num_threads: usize, 13 | } 14 | 15 | pub struct Scope<'a> { 16 | #[cfg(not(target_arch = "wasm32"))] 17 | ctx: &'a ThreadCtx, 18 | 19 | #[cfg(target_arch = "wasm32")] 20 | _phantom: PhantomData<&'a fn()>, 21 | } 22 | 23 | impl<'a> Scope<'a> { 24 | pub fn spawn(&self, f: F) 25 | where 26 | F: FnOnce() + Send + 'a, 27 | { 28 | #[cfg(not(target_arch = "wasm32"))] 29 | { 30 | if self.ctx.num_threads() == 1 { 31 | f() 32 | } else { 33 | let f = unsafe { 34 | std::mem::transmute::< 35 | Box, 36 | Box, 37 | >(Box::new(f)) 38 | }; 39 | self.ctx.tp.execute(f) 40 | } 41 | } 42 | 43 | #[cfg(target_arch = "wasm32")] 44 | { 45 | f() 46 | } 47 | } 48 | } 49 | 50 | impl ThreadCtx { 51 | #[allow(dead_code)] 52 | #[cfg(not(target_arch = "wasm32"))] 53 | pub fn new() -> Self { 54 | let tp = ThreadPool::new(1); 55 | tp.execute(move || { 56 | core_affinity::set_for_current(core_affinity::CoreId { id: 0 }); 57 | }); 58 | tp.join(); 59 | Self { tp, num_threads: 1 
} 60 | } 61 | 62 | #[allow(dead_code)] 63 | #[cfg(target_arch = "wasm32")] 64 | pub fn new() -> Self { 65 | Self {} 66 | } 67 | 68 | #[cfg(not(target_arch = "wasm32"))] 69 | pub fn new_with_num_threads(n: usize) -> Self { 70 | #[cfg(target_os = "linux")] 71 | let apicid_to_processor = if let Ok(cpuinfo) = procfs::CpuInfo::new() { 72 | let mut apicid_to_processor = vec![0; cpuinfo.cpus.len()]; 73 | let is_bijective = cpuinfo 74 | .cpus 75 | .iter() 76 | .enumerate() 77 | .all(|(i, c)| c["apicid"].parse::().unwrap_or(i) < cpuinfo.cpus.len()); 78 | if is_bijective { 79 | for (i, cpu) in cpuinfo.cpus.iter().enumerate() { 80 | *apicid_to_processor 81 | .get_mut(cpu["apicid"].parse().unwrap_or(i)) 82 | .unwrap() = cpu["processor"].parse().unwrap_or(i); 83 | } 84 | apicid_to_processor 85 | } else { 86 | (0..n).collect::>() 87 | } 88 | } else { 89 | (0..n).collect::>() 90 | }; 91 | #[cfg(not(target_os = "linux"))] 92 | let apicid_to_processor = (0..n).collect::>(); 93 | let tp = ThreadPool::new(n); 94 | if apicid_to_processor.len() >= n { 95 | for &id in &apicid_to_processor[0..n] { 96 | tp.execute(move || { 97 | core_affinity::set_for_current(core_affinity::CoreId { id }); 98 | }) 99 | } 100 | } 101 | tp.join(); 102 | Self { tp, num_threads: n } 103 | } 104 | 105 | #[cfg(target_arch = "wasm32")] 106 | pub fn new_with_num_threads(_n: usize) -> Self { 107 | Self {} 108 | } 109 | 110 | #[cfg(not(target_arch = "wasm32"))] 111 | pub const fn num_threads(&self) -> usize { 112 | self.num_threads 113 | } 114 | 115 | #[cfg(target_arch = "wasm32")] 116 | pub fn num_threads(&self) -> usize { 117 | 1 118 | } 119 | 120 | #[cfg(not(target_arch = "wasm32"))] 121 | pub fn scope(&self, mut f: F) 122 | where 123 | F: FnMut(&Scope), 124 | { 125 | let scope = Scope { ctx: self }; 126 | f(&scope); 127 | if self.num_threads != 1 { 128 | self.tp.join(); 129 | } 130 | } 131 | 132 | #[cfg(target_arch = "wasm32")] 133 | pub fn scope(&self, mut f: F) 134 | where 135 | F: FnMut(&Scope), 136 | { 137 
| let scope = Scope { 138 | _phantom: PhantomData::default(), 139 | }; 140 | f(&scope); 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /crates/session_interpreter/tests/mobilenet.rs: -------------------------------------------------------------------------------- 1 | use altius_core::dim::Dimension; 2 | use altius_core::{onnx::load_onnx, tensor::Tensor}; 3 | use altius_session_interpreter::InterpreterSessionBuilder; 4 | use std::cmp::Ordering; 5 | use std::path::Path; 6 | 7 | #[test] 8 | fn mobilenet() { 9 | env_logger::init(); 10 | let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../models"); 11 | let mut model = load_onnx(root.join("mobilenetv3.onnx")).unwrap(); 12 | let input_value = model.lookup_named_value("input").unwrap(); 13 | 14 | // Change input batch size from 1 to 4. 15 | model.graph.values.inner_mut()[input_value] 16 | .shape 17 | .as_mut() 18 | .unwrap() 19 | .dims 20 | .0[0] = Dimension::Fixed(4); 21 | 22 | let image = image::open(root.join("cat.png")).unwrap().to_rgb8(); 23 | let resized = image::imageops::resize(&image, 224, 224, image::imageops::FilterType::Triangle); 24 | let image = ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| { 25 | let mean = [0.485, 0.456, 0.406][c]; 26 | let std = [0.229, 0.224, 0.225][c]; 27 | (resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std 28 | }); 29 | let input = Tensor::new( 30 | vec![4, 3, 224, 224].into(), 31 | image 32 | .clone() 33 | .into_raw_vec() 34 | .into_iter() 35 | .chain(image.clone().into_iter()) 36 | .chain(image.clone().into_iter()) 37 | .chain(image.into_iter()) 38 | .collect::>(), 39 | ); 40 | 41 | let i = InterpreterSessionBuilder::new(model) 42 | .with_profiling_enabled(true) 43 | .build() 44 | .unwrap(); 45 | let out = i.run(vec![input]).expect("Inference failed"); 46 | let mut out = out[0].data::().iter().enumerate().collect::>(); 47 | out[0..1000].sort_by(|(_, a), (_, b)| 
b.partial_cmp(a).unwrap_or(Ordering::Equal)); 48 | out[1000..2000].sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(Ordering::Equal)); 49 | out[2000..3000].sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(Ordering::Equal)); 50 | out[3000..4000].sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(Ordering::Equal)); 51 | 52 | assert!(out[0].0 == 285 && out[1000].0 == 1285 && out[2000].0 == 2285 && out[3000].0 == 3285); 53 | } 54 | -------------------------------------------------------------------------------- /crates/session_interpreter/tests/op_bin.rs: -------------------------------------------------------------------------------- 1 | use altius_core::{fixed_dim::FixedDimensions, model::Model, node::Node, op::Op, tensor::Tensor}; 2 | use altius_session_interpreter::InterpreterSessionBuilder; 3 | 4 | macro_rules! test_op { 5 | ($name:ident, $op:ident, $shape:expr) => { 6 | #[test] 7 | fn $name() { 8 | Tensor::seed_rng_from_u64(42); 9 | $op($shape.into()) 10 | } 11 | }; 12 | 13 | (! $name:ident, $op:ident, $shape:expr) => { 14 | #[test] 15 | #[should_panic] 16 | fn $name() { 17 | Tensor::seed_rng_from_u64(42); 18 | $op($shape.into()) 19 | } 20 | }; 21 | } 22 | 23 | macro_rules! 
op { 24 | ($name:ident, $op:ident, $bin:tt) => { 25 | #[cfg(test)] 26 | fn $name(shape: FixedDimensions) { 27 | let mut model = Model::default(); 28 | let x = model.graph.values.new_val_named("x"); 29 | let y = model.graph.values.new_val_named("y"); 30 | let z = model.graph.values.new_val_named("z"); 31 | 32 | model.graph.add_node(Node::new(Op::$op).with_ins(vec![x, y]).with_out(z)); 33 | model.graph.inputs.push(x); 34 | model.graph.inputs.push(y); 35 | model.graph.outputs.push(z); 36 | 37 | let sess = InterpreterSessionBuilder::new(model).build().unwrap(); 38 | let x_val = Tensor::rand::(shape.to_owned()); 39 | let y_val = Tensor::rand::(shape); 40 | 41 | let expected = x_val 42 | .data::() 43 | .iter() 44 | .zip(y_val.data::().iter()) 45 | .map(|(&x, &y)| x $bin y) 46 | .collect::>(); 47 | let actual = sess.run(vec![x_val, y_val]).unwrap(); 48 | assert_eq!(actual.len(), 1); 49 | assert!(allclose(actual[0].data::(), expected.as_slice()), 50 | "actual: {:?} vs expected: {:?}", 51 | &actual[0].data::()[..10], &expected.as_slice()[..10]); 52 | } 53 | }; 54 | } 55 | 56 | op!(op_add, Add, +); 57 | op!(op_sub, Sub, -); 58 | op!(op_mul, Mul, *); 59 | op!(op_div, Div, /); 60 | 61 | test_op!(test_op_add_1, op_add, vec![1, 2]); 62 | test_op!(test_op_add_2, op_add, vec![3, 1, 10]); 63 | test_op!(test_op_add_3, op_add, vec![128, 3, 224, 224]); 64 | 65 | test_op!(test_op_sub_1, op_sub, vec![1, 2]); 66 | test_op!(test_op_sub_2, op_sub, vec![3, 1, 10]); 67 | test_op!(test_op_sub_3, op_sub, vec![128, 3, 224, 224]); 68 | 69 | test_op!(test_op_mul_1, op_mul, vec![1, 2]); 70 | test_op!(test_op_mul_2, op_mul, vec![3, 1, 10]); 71 | test_op!(test_op_mul_3, op_mul, vec![128, 3, 224, 224]); 72 | 73 | test_op!(test_op_div_1, op_div, vec![1, 2]); 74 | test_op!(test_op_div_2, op_div, vec![3, 1, 10]); 75 | test_op!(test_op_div_3, op_div, vec![128, 3, 224, 224]); 76 | 77 | #[cfg(test)] 78 | fn allclose(x: &[f32], y: &[f32]) -> bool { 79 | let atol = 1e-5; 80 | let rtol = 1e-5; 81 | 82 | if 
x.len() != y.len() { 83 | return false; 84 | } 85 | 86 | x.iter().zip(y.iter()).all(|(x, y)| { 87 | ((x - y).abs() <= (atol + rtol * y.abs())) 88 | || (x.is_infinite() && y.is_infinite() && x.is_sign_positive() == y.is_sign_positive()) 89 | }) 90 | } 91 | -------------------------------------------------------------------------------- /crates/wasm/.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | dist 3 | -------------------------------------------------------------------------------- /crates/wasm/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "altius_wasm" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [lib] 7 | crate-type = ["cdylib"] 8 | 9 | [dependencies] 10 | wasm-bindgen = "0.2" 11 | altius-core = { path = "../core" } 12 | altius_session = { path = "../session" } 13 | altius_session_interpreter = { path = "../session_interpreter" } 14 | console_error_panic_hook = "0.1.7" 15 | image = "0.24.2" 16 | ndarray = "0.15.6" 17 | -------------------------------------------------------------------------------- /crates/wasm/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "altius_wasm", 3 | "version": "1.0.0", 4 | "description": "", 5 | "main": "index.js", 6 | "scripts": { 7 | "build": "WEBPACK_PRODUCTION=1 webpack --progress", 8 | "serve": "webpack-dev-server --progress" 9 | }, 10 | "author": "", 11 | "license": "ISC", 12 | "dependencies": { 13 | "@babel/core": "^7.20.5", 14 | "@babel/preset-env": "^7.20.2", 15 | "@babel/preset-react": "^7.18.6", 16 | "@babel/preset-typescript": "^7.18.6", 17 | "@babel/register": "^7.18.9", 18 | "@emotion/react": "^11.10.5", 19 | "@emotion/styled": "^11.10.5", 20 | "@mui/icons-material": "^5.11.0", 21 | "@mui/material": "^5.11.0", 22 | "@types/react": "^18.0.26", 23 | "@types/react-dom": "^18.0.9", 24 | "babel-loader": "^9.1.0", 25 | 
"react": "^18.2.0", 26 | "react-dom": "^18.2.0", 27 | "ts-loader": "^9.4.2", 28 | "ts-node": "^10.9.1", 29 | "typescript": "^4.9.4", 30 | "webpack": "^5.75.0", 31 | "webpack-cli": "^5.0.1", 32 | "webpack-dev-server": "^4.11.1" 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /crates/wasm/src/index.tsx: -------------------------------------------------------------------------------- 1 | import { Button } from "@mui/material"; 2 | import React, { useState } from "react"; 3 | import { createRoot } from 'react-dom/client'; 4 | import Box from '@mui/material/Box'; 5 | import Card from '@mui/material/Card'; 6 | import CardActions from '@mui/material/CardActions'; 7 | import CardContent from '@mui/material/CardContent'; 8 | import Typography from '@mui/material/Typography'; 9 | import init, { load_and_run } from "../pkg/altius_wasm.js"; 10 | 11 | const App: React.FC = () => { 12 | const [image, setImage] = useState(); 13 | const [model, setModel] = useState(); 14 | const [imgBtnVariant, setImgBtnVariant] = useState('outlined'); 15 | 16 | const setImgUploaded = () => { 17 | if (imgBtnVariant === 'outlined') { 18 | setImgBtnVariant('contained'); 19 | } else { 20 | setImgBtnVariant('outlined'); 21 | } 22 | } 23 | 24 | function loadImage() { 25 | const img = document.querySelector("#img"); 26 | if (!img || !img.files) return; 27 | 28 | const file = img.files[0]; 29 | const reader = new FileReader(); 30 | 31 | reader.onload = (_: Event) => { 32 | const loadedImage = new Uint8Array(reader.result as ArrayBufferLike); 33 | setImage(loadedImage); 34 | console.log(reader.result as string); 35 | var blob = new Blob([loadedImage]); 36 | var urlCreator = window.URL || window.webkitURL; 37 | var imageUrl = urlCreator.createObjectURL( blob ); 38 | (document.getElementById("image") as HTMLImageElement).src = imageUrl; 39 | }; 40 | 41 | reader.readAsArrayBuffer(file); 42 | setImgUploaded(); 43 | } 44 | 45 | function loadModel() { 46 | 
const onnx = document.querySelector("#onnx"); 47 | if (!onnx || !onnx.files) return; 48 | 49 | const file = onnx.files[0]; 50 | const reader = new FileReader(); 51 | 52 | reader.onload = (_: Event) => { 53 | const model = new Uint8Array(reader.result as ArrayBufferLike); 54 | setModel(model); 55 | }; 56 | 57 | reader.readAsArrayBuffer(file); 58 | } 59 | 60 | function setResultHtml(html: string) { 61 | (document.getElementById("results") as HTMLDivElement).innerHTML = html; 62 | } 63 | 64 | function runInference() { 65 | if (!model || !image) return; 66 | const msg = load_and_run(model, image); 67 | setResultHtml(msg ?? "Failed to run inference"); 68 | } 69 | 70 | return ( 71 |
72 |

73 | Altius on Web 74 |

75 | 76 | 77 | 83 | 84 | 85 | 86 | 87 | 88 | Select an image and an ONNX model first. 89 | 90 | 91 | 92 | 93 | 94 | 103 | 112 | 120 | 121 | 122 |
123 | ); 124 | }; 125 | 126 | init(); 127 | 128 | const container = document.getElementById('app'); 129 | if (container) { 130 | const root = createRoot(container); 131 | root.render(); 132 | } 133 | -------------------------------------------------------------------------------- /crates/wasm/src/lib.rs: -------------------------------------------------------------------------------- 1 | use std::{cmp::Ordering, io::Cursor}; 2 | 3 | use altius_core::{onnx::load_onnx_from_buffer, tensor::Tensor}; 4 | use altius_session_interpreter::InterpreterSessionBuilder; 5 | use image::io::Reader; 6 | use wasm_bindgen::prelude::*; 7 | 8 | #[wasm_bindgen] 9 | extern "C" { 10 | pub fn alert(s: &str); 11 | } 12 | 13 | #[wasm_bindgen] 14 | pub fn load_and_run(onnx: &[u8], img: &[u8]) -> Option { 15 | std::panic::set_hook(Box::new(console_error_panic_hook::hook)); 16 | 17 | let model = load_onnx_from_buffer(onnx).expect("failed to load onnx"); 18 | let sess = InterpreterSessionBuilder::new(model).build().unwrap(); 19 | 20 | let image = Reader::new(Cursor::new(img)) 21 | .with_guessed_format() 22 | .unwrap() 23 | .decode() 24 | .unwrap() 25 | .to_rgb8(); 26 | 27 | let resized = image::imageops::resize(&image, 224, 224, image::imageops::FilterType::Triangle); 28 | let image = ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| { 29 | let mean = [0.485, 0.456, 0.406][c]; 30 | let std = [0.229, 0.224, 0.225][c]; 31 | (resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std 32 | }); 33 | let input = Tensor::new( 34 | vec![1, 3, 224, 224].into(), 35 | image.into_raw_vec().into_iter().collect::>(), 36 | ); 37 | 38 | let results = sess.run(vec![input]).ok()?; 39 | 40 | let mut out = results[0] 41 | .data::() 42 | .iter() 43 | .enumerate() 44 | .collect::>(); 45 | out.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(Ordering::Equal)); 46 | 47 | let classes = include_str!("../../../models/imagenet_classes.txt"); 48 | let classes = classes.split('\n').collect::>(); 49 | let mut 
result_str = "".to_string(); 50 | for i in 0..5 { 51 | result_str.push_str(&format!("top{}: {}
", i + 1, classes[out[i].0])); 52 | } 53 | 54 | Some(result_str) 55 | } 56 | -------------------------------------------------------------------------------- /crates/wasm/static/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0px; 3 | } 4 | 5 | h1 { 6 | margin-top: 10px; 7 | margin-bottom: 10px; 8 | } 9 | 10 | #main { 11 | margin: 20px; 12 | } 13 | 14 | #image { 15 | width: 224px; 16 | height: 224px; 17 | border: 1px solid; 18 | } 19 | 20 | #results { 21 | margin-top: 10px; 22 | } 23 | -------------------------------------------------------------------------------- /crates/wasm/static/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Altius on Web 6 | 7 | 8 | 9 | 10 |
11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /crates/wasm/webpack.config.ts: -------------------------------------------------------------------------------- 1 | import path from 'path'; 2 | import { Configuration } from 'webpack'; 3 | import WebpackDevServer from 'webpack-dev-server'; 4 | 5 | const config: Configuration = { 6 | mode: process.env.WEBPACK_PRODUCTION ? "production" : "development", 7 | context: path.join(__dirname, 'src'), 8 | entry: { 9 | index: [path.resolve(__dirname, "src", "index.tsx")], 10 | }, 11 | output: { 12 | path: path.join(__dirname, 'dist'), 13 | filename: '[name].bundle.js', 14 | }, 15 | module: { 16 | rules: [ 17 | { 18 | test: /\.tsx?$/, 19 | use: 'ts-loader', 20 | }, 21 | ], 22 | }, 23 | resolve: { 24 | extensions: ['.ts', '.tsx', '.js', '.jsx'], 25 | }, 26 | optimization: { 27 | usedExports: !!process.env.WEBPACK_PRODUCTION, 28 | }, 29 | devtool: process.env.WEBPACK_PRODUCTION ? false : "eval-cheap-source-map", 30 | devServer: { 31 | static: { 32 | directory: "static", 33 | publicPath: "/", 34 | }, 35 | onBeforeSetupMiddleware: (devserver: WebpackDevServer) => { 36 | devserver.app?.use("/", (req, res, next) => { 37 | console.log(`${req.ip} - ${req.method} - ${req.originalUrl}`); 38 | next(); 39 | }); 40 | }, 41 | }, 42 | }; 43 | 44 | export default config; 45 | -------------------------------------------------------------------------------- /models/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eux 2 | 3 | if [ "${1:-}" = "CI" ]; then 4 | wget https://pub-edba5feea2c145019e8be2a71dbeea81.r2.dev/mnist-8.onnx 5 | wget https://pub-edba5feea2c145019e8be2a71dbeea81.r2.dev/mobilenetv3.onnx 6 | wget https://pub-edba5feea2c145019e8be2a71dbeea81.r2.dev/cat.png 7 | else 8 | wget https://pub-edba5feea2c145019e8be2a71dbeea81.r2.dev/mnist-8.onnx 9 | wget https://pub-edba5feea2c145019e8be2a71dbeea81.r2.dev/mobilenetv3.onnx 10 | 
wget https://pub-edba5feea2c145019e8be2a71dbeea81.r2.dev/deeplab_mobilenetv3.onnx 11 | wget https://pub-edba5feea2c145019e8be2a71dbeea81.r2.dev/fcn-resnet50.onnx 12 | wget https://pub-edba5feea2c145019e8be2a71dbeea81.r2.dev/yolov5s.onnx 13 | wget https://pub-edba5feea2c145019e8be2a71dbeea81.r2.dev/realesrgan_256x256.onnx 14 | wget https://pub-edba5feea2c145019e8be2a71dbeea81.r2.dev/cat.png 15 | wget https://pub-edba5feea2c145019e8be2a71dbeea81.r2.dev/dog.jpg 16 | fi 17 | 18 | -------------------------------------------------------------------------------- /rust-toolchain.toml: -------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "nightly-2025-06-01" 3 | components = [ "rustfmt", "rust-analyzer", "clippy" ] 4 | -------------------------------------------------------------------------------- /snippets/coreml/mobilenet.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import coremltools 4 | import torch 5 | import torchvision 6 | 7 | from PIL import Image 8 | import numpy as np 9 | from torchvision import transforms 10 | 11 | 12 | labels = open("../../models/imagenet_classes.txt").readlines() 13 | image = Image.open("../../models/cat.png") 14 | 15 | model = torchvision.models.mobilenet_v3_large(pretrained=True) 16 | model.eval() 17 | model = torch.jit.trace(model, torch.zeros(1, 3, 224, 224)) 18 | 19 | coreml_model = coremltools.convert( 20 | model, 21 | inputs=[coremltools.TensorType(name="input_1", shape=(1, 3, 224, 224))], 22 | outputs=[coremltools.TensorType(name="output_1")], 23 | ) 24 | 25 | preprocess = transforms.Compose( 26 | [ 27 | transforms.Resize(256), 28 | transforms.CenterCrop(224), 29 | transforms.ToTensor(), 30 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 31 | ] 32 | ) 33 | input = preprocess(image) 34 | input = input.unsqueeze(0).numpy() 35 | 36 | for i in range(100): 37 | start = time.time() 38 | pred 
= coreml_model.predict({"input_1": input})["output_1"][0] 39 | print(f"elapsed: {(time.time() - start) * 1000:.2f}ms") 40 | output = np.argsort(pred)[::-1][:5] 41 | output = [labels[i].strip() for i in output] 42 | print(f"top5: {output}") 43 | -------------------------------------------------------------------------------- /snippets/coreml/requirements.txt: -------------------------------------------------------------------------------- 1 | coremltools==8.0b2 2 | torch==2.3.0 3 | torchvision 4 | -------------------------------------------------------------------------------- /snippets/cuda/Makefile: -------------------------------------------------------------------------------- 1 | # SRCS=$(wildcard *.cc) 2 | # OBJS=$(SRCS:%.cc=%.o) 3 | # DEPS=$(SRCS:%.cc=%.d) 4 | CCFLAGS= 5 | LDFLAGS= 6 | CC=clang 7 | 8 | all: main 9 | 10 | main: main.c libcuda-gemm-act.so 11 | $(CC) -O3 $(LDFLAGS) -o $@ main.c -L. -lcuda-gemm-act -lblis 12 | 13 | libcuda-gemm-act.so: cuda-gemm-act.cu 14 | nvcc -O3 -shared -Xcompiler -fPIC -lcublas -lcurand -o $@ $< 15 | 16 | format: main.c cuda-gemm-act.cu 17 | clang-format -i -style=llvm $^ 18 | 19 | test: all 20 | @./main 21 | 22 | clean: 23 | @rm -rf main *.so 24 | 25 | -include $(DEPS) 26 | 27 | .PHONY: all test clean 28 | -------------------------------------------------------------------------------- /snippets/cuda/cuda-gemm-act.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #define ATTEMPT 10 12 | 13 | double now_in_sec() { 14 | struct timeval tv; 15 | gettimeofday(&tv, NULL); 16 | return (double)tv.tv_sec + (double)tv.tv_usec / 1000.f / 1000.f; 17 | } 18 | 19 | __global__ void relu(float *x, int n) { 20 | const int i = blockIdx.x * blockDim.x + threadIdx.x; 21 | if (i < n) { 22 | x[i] = x[i] > 0.f ? 
x[i] : 0.f; 23 | } 24 | } 25 | 26 | __global__ void sigmoid(float *x, int n) { 27 | const int i = blockIdx.x * blockDim.x + threadIdx.x; 28 | if (i < n) { 29 | x[i] = 1.f / (1.f + expf(-x[i])); 30 | } 31 | } 32 | 33 | extern "C" void entry() { 34 | cublasHandle_t handle; 35 | cublasCreate(&handle); 36 | 37 | float *lhs, *rhs, *result; 38 | const int m = 1000, k = 200, n = 100; 39 | const float alpha = 1.f, beta = 0.f; 40 | 41 | cudaMalloc(&lhs, m * k * sizeof(float)); 42 | cudaMalloc(&rhs, k * n * sizeof(float)); 43 | cudaMalloc(&result, m * n * sizeof(float)); 44 | 45 | { 46 | // fill lhs and rhs with random numbers 47 | curandGenerator_t gen; 48 | curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT); 49 | curandSetPseudoRandomGeneratorSeed(gen, 1234ULL); 50 | curandGenerateUniform(gen, lhs, m * k); 51 | curandGenerateUniform(gen, rhs, k * n); 52 | curandDestroyGenerator(gen); 53 | } 54 | 55 | for (int attempt = 0; attempt < ATTEMPT; attempt++) { 56 | const double start = now_in_sec(); 57 | for (int i = 0; i < 1000; i++) { 58 | cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, rhs, n, 59 | lhs, k, &beta, result, n); 60 | // relu<<<(m * n + 31) / 32, 32>>>(result, m * n); 61 | // sigmoid<<<(m * n + 31) / 32, 32>>>(result, m * n); 62 | } 63 | const double end = now_in_sec(); 64 | printf("GPU Time: %lf[ms]\n", (end - start) * 1000.0); 65 | } 66 | 67 | float *gpu_result = (float *)malloc(m * n * sizeof(float)); 68 | cudaMemcpy(gpu_result, result, m * n * sizeof(float), cudaMemcpyDeviceToHost); 69 | 70 | float *cpu_lhs = (float *)malloc(m * k * sizeof(float)); 71 | float *cpu_rhs = (float *)malloc(k * n * sizeof(float)); 72 | float *cpu_result = (float *)malloc(m * n * sizeof(float)); 73 | 74 | cudaMemcpy(cpu_lhs, lhs, m * k * sizeof(float), cudaMemcpyDeviceToHost); 75 | cudaMemcpy(cpu_rhs, rhs, k * n * sizeof(float), cudaMemcpyDeviceToHost); 76 | 77 | for (int attempt = 0; attempt < ATTEMPT; attempt++) { 78 | const double start = now_in_sec(); 79 | for 
(int i = 0; i < 1000; i++) { 80 | cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha, 81 | cpu_lhs, k, cpu_rhs, n, beta, cpu_result, n); 82 | } 83 | const double end = now_in_sec(); 84 | printf("CPU Time: %lf[ms]\n", (end - start) * 1000.0); 85 | } 86 | 87 | for (int i = 0; i < m * n; i++) { 88 | const float diff = fabs(gpu_result[i] - cpu_result[i]); 89 | assert(diff < 1e-3); 90 | } 91 | 92 | cudaFree(lhs); 93 | cudaFree(rhs); 94 | cudaFree(result); 95 | } 96 | -------------------------------------------------------------------------------- /snippets/cuda/main.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | int entry(); 4 | 5 | int main() { 6 | entry(); 7 | return 0; 8 | } 9 | -------------------------------------------------------------------------------- /snippets/float.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define MAX(a, b) ((a) > (b) ? 
(a) : (b)) 5 | 6 | typedef struct { 7 | unsigned int frac : 23; 8 | unsigned int exp : 8; 9 | unsigned int sign : 1; 10 | } f32; 11 | 12 | float add(float x, float y) { 13 | const f32 xi = *(f32*)&x; 14 | const f32 yi = *(f32*)&y; 15 | printf("xi.exp = %d, yi.exp = %d\n", xi.exp, yi.exp); 16 | unsigned int exp = MAX(xi.exp, yi.exp); 17 | unsigned int xi_frac = (xi.frac | (1 << 23)) >> (exp - xi.exp); 18 | unsigned int yi_frac = (yi.frac | (1 << 23)) >> (exp - yi.exp); 19 | const int carry = (xi_frac + yi_frac) > 0xffffff; 20 | 21 | const f32 zi = { 22 | .sign = 0, 23 | .exp = exp + carry, 24 | .frac = ((xi_frac + yi_frac) >> carry) & 0x7fffff, 25 | }; 26 | return *(float*)&zi; 27 | 28 | #if 0 29 | const int xi = *(int*)&x; 30 | const int yi = *(int*)&y; 31 | const int xi_sign = (xi >> 31) & 0x1; 32 | const int xi_exp = (xi >> 23) & 0x7f; 33 | const int xi_frac = xi & 0x7fffff; 34 | const int yi_sign = (yi >> 31) & 0x1; 35 | const int yi_exp = (yi >> 23) & 0x7f; 36 | const int yi_frac = yi & 0x7fffff; 37 | assert(xi_sign == 0); 38 | assert(yi_sign == 0); 39 | assert(xi_exp == 127); 40 | assert(yi_exp == 127 - 1); 41 | /* assert(xi_frac == 0); */ 42 | /* assert(yi_frac == 0); */ 43 | const int zi_sign = 0; 44 | const int zi_exp = 127; 45 | const int zi_frac = xi_frac + (yi_frac >> 1) + 0x400000; 46 | const int zi = (zi_sign << 31) | (zi_exp << 23) | zi_frac; 47 | return *(float*)&zi; 48 | #endif 49 | } 50 | 51 | int main() { 52 | /* 1 + 8 + 7 */ 53 | /* sign exp frac */ 54 | float x = 10.2; 55 | float y = 0.1; 56 | float z = add(x, y); 57 | printf("%f\n", z); 58 | return 0; 59 | } 60 | -------------------------------------------------------------------------------- /snippets/onnx_float16.py: -------------------------------------------------------------------------------- 1 | # Convert float32 values into float16 values in ONNX model 2 | 3 | import argparse 4 | 5 | import onnx 6 | from onnxmltools.utils.float16_converter import convert_float_to_float16 7 | 8 | 9 | if 
__name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--input", type=str, required=True) 12 | parser.add_argument("--output", type=str, required=True) 13 | args = parser.parse_args() 14 | 15 | model = onnx.load(args.input) 16 | f16_model = convert_float_to_float16(model) 17 | onnx.save(f16_model, args.output) 18 | -------------------------------------------------------------------------------- /snippets/q.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #define SZ 8 * 8 * 1024 * 1024 6 | 7 | float input_0[SZ]; 8 | float input_1[SZ]; 9 | float output[SZ]; 10 | 11 | int16_t input_0_int16[SZ]; 12 | int16_t input_1_int16[SZ]; 13 | int16_t output_int16[SZ]; 14 | 15 | double now_in_sec() { 16 | struct timeval tv; 17 | gettimeofday(&tv, NULL); 18 | return (double)tv.tv_sec + (double)tv.tv_usec / 1000.f / 1000.f; 19 | } 20 | 21 | void add(const float *a, const float *b, float *c, const int n) { 22 | #pragma omp parallel for 23 | for (int i = 0; i < n; i++) { 24 | c[i] = a[i] + b[i]; 25 | } 26 | } 27 | 28 | void add_int16(const int16_t *a, const int16_t *b, int16_t *c, const int n) { 29 | #pragma omp parallel for 30 | for (int i = 0; i < n; i++) { 31 | c[i] = a[i] + b[i]; 32 | } 33 | } 34 | 35 | int main() { 36 | const int attempts = 10; 37 | int count = 0; 38 | double sum = 0; 39 | for (int i = 0; i < attempts; i++) { 40 | const auto start = now_in_sec(); 41 | { add(input_0, input_1, output, SZ); } 42 | const auto end = now_in_sec(); 43 | if (i > 0) { 44 | sum += (end - start); 45 | count++; 46 | } 47 | std::cout << "float mean: " << sum / count << "s\t" << '\r'; 48 | // std::cout << (end - start) << "s" << std::endl; 49 | } 50 | 51 | std::cout << std::endl; 52 | 53 | count = 0; 54 | sum = 0; 55 | for (int i = 0; i < attempts; i++) { 56 | const auto start = now_in_sec(); 57 | { add_int16(input_0_int16, input_1_int16, output_int16, SZ); } 58 | const auto end = 
now_in_sec(); 59 | if (i > 0) { 60 | sum += (end - start); 61 | count++; 62 | } 63 | std::cout << "int16_t mean: " << sum / count << "s\t" << '\r'; 64 | } 65 | 66 | std::cout << std::endl; 67 | 68 | return 0; 69 | } 70 | -------------------------------------------------------------------------------- /snippets/sgemm/.gitignore: -------------------------------------------------------------------------------- 1 | main 2 | -------------------------------------------------------------------------------- /snippets/sgemm/Makefile: -------------------------------------------------------------------------------- 1 | SRCS=$(wildcard *.cc) 2 | OBJS=$(SRCS:%.cc=%.o) 3 | DEPS=$(SRCS:%.cc=%.d) 4 | CXX=clang++ 5 | 6 | all: main 7 | 8 | main: main.cc 9 | $(CXX) -O3 -fopenmp -march=native -ffp-contract=fast -lm -lblis -lomp -o $@ $< 10 | 11 | test: all 12 | @./main 13 | 14 | clean: 15 | @rm -rf main *.so 16 | 17 | -include $(DEPS) 18 | 19 | .PHONY: all test clean 20 | -------------------------------------------------------------------------------- /snippets/sgemm/gemm.deit.cc: -------------------------------------------------------------------------------- 1 | // [2024-10-05T16:04:57Z DEBUG altius_session_clang::translator] m=197, k=1536, n=384 2 | // [2024-10-05T16:04:57Z DEBUG altius_session_clang::translator] m=197, k=384, n=384 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | static const size_t N = 4000; 14 | static const size_t ThreadBlockSizeI = N/16; // 250; // 5周(×16スレッド) 15 | static const size_t ThreadBlockSizeK = N; // 50周 16 | static const size_t ThreadBlockSizeJ = N; // 50周 17 | static const size_t L3CacheBlockSizeI = 50; // 1周 18 | static const size_t L3CacheBlockSizeK = 80; // 2周 19 | static const size_t L3CacheBlockSizeJ = 80; // 2周 20 | static const size_t L1DCacheBlockSizeI = 50; // 10周 21 | static const size_t L1DCacheBlockSizeK = 40; // 10周 22 | static const size_t L1DCacheBlockSizeJ = 80; 
// 80 = 8 SIMD lanes x 10 iterations (vectorization along j is intended) 23 | static const size_t RegisterBlockSizeI = 5; // 5 registers in parallel 24 | static const size_t RegisterBlockSizeK = 4; // chain of 4 FMAs 25 | // Blocked c += a * b for one thread's tile: i in [0, ThreadBlockSizeI), k in [0, ThreadBlockSizeK), j in [0, ThreadBlockSizeJ); all three are indexed row-major with row stride N. 26 | void mm( const float *__restrict__ a, const float *__restrict__ b, float *__restrict__ c ) { 27 | for( int i1 = 0; i1 < ThreadBlockSizeI; i1 += L3CacheBlockSizeI ) 28 | for( int k1 = 0; k1 < ThreadBlockSizeK; k1 += L3CacheBlockSizeK ) 29 | for( int j1 = 0; j1 < ThreadBlockSizeJ; j1 += L3CacheBlockSizeJ ) 30 | for( int i2 = 0; i2 < L3CacheBlockSizeI; i2 += L1DCacheBlockSizeI ) 31 | for( int k2 = 0; k2 < L3CacheBlockSizeK; k2 += L1DCacheBlockSizeK ) 32 | for( int j2 = 0; j2 < L3CacheBlockSizeJ; j2 += L1DCacheBlockSizeJ ) 33 | for( int i3 = 0; i3 < L1DCacheBlockSizeI; i3 += RegisterBlockSizeI ) 34 | for( int k3 = 0; k3 < L1DCacheBlockSizeK; k3 += RegisterBlockSizeK ) 35 | for( int j3 = 0; j3 < L1DCacheBlockSizeJ; j3 += 1 ) 36 | for( int i4 = 0; i4 < RegisterBlockSizeI; i4 += 1 ) 37 | for( int k4 = 0; k4 < RegisterBlockSizeK; k4 += 1 ) 38 | { 39 | int i = i1 + i2 + i3 + i4; 40 | int k = k1 + k2 + k3 + k4; 41 | int j = j1 + j2 + j3; 42 | 43 | c[i*N+j] = std::fma( a[i*N+k], b[k*N+j], c[i*N+j] ); // std::fma keeps this in float; unqualified fma on floats can bind to the C double version 44 | } 45 | } 46 | // Operands: ch is updated by mm(), ch_cblas by cblas_sgemm as the reference. 47 | alignas(64) float ah[N*N]; 48 | alignas(64) float bh[N*N]; 49 | alignas(64) float ch[N*N], ch_cblas[N*N]; 50 | // Benchmarks mm() on 16 OpenMP threads in an endless loop, checking every pass against cblas_sgemm. 51 | int main() { 52 | std::mt19937_64 mt; 53 | std::uniform_real_distribution dist(-1.0, 1.0); 54 | 55 | for( int i = 0; i < N*N; ++i ) { 56 | ah[i] = dist( mt ); 57 | bh[i] = dist( mt ); 58 | ch[i] = ch_cblas[i] = dist( mt ); 59 | } 60 | 61 | std::cout << "initialized."
<< std::endl; 62 | 63 | while (true) { 64 | const auto start = std::chrono::steady_clock::now(); // steady_clock is monotonic, unlike system_clock, so it is the right clock for intervals 65 | 66 | #pragma omp parallel for num_threads(16) 67 | for( int tid = 0; tid < 16; ++tid ) 68 | { 69 | int i0 = tid % 16 * ThreadBlockSizeI; 70 | int j0 = tid / 16 * ThreadBlockSizeJ; 71 | mm( &ah[i0*N], &bh[j0], &ch[i0*N+j0] ); 72 | } 73 | 74 | const auto finish = std::chrono::steady_clock::now(); 75 | 76 | cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, N, N, N, 1.0, ah, N, bh, N, 1.0, ch_cblas, N); // beta = 1: accumulates like mm(), so ch and ch_cblas stay comparable across passes 77 | 78 | for( int i = 0; i < N*N; ++i ) 79 | if( std::abs( ch[i] - ch_cblas[i] ) > 1e-2 ) { 80 | std::cerr << "mismatch at " << i << ": " << ch[i] << " != " << ch_cblas[i] << std::endl; 81 | exit(EXIT_FAILURE); 82 | } 83 | 84 | const double s = std::chrono::duration_cast<std::chrono::nanoseconds>( finish - start ).count() * 1e-9; 85 | static constexpr double flop_per_fma = 2.0; 86 | std::cout << s << " seconds, " << N*N*N*flop_per_fma/s * 1e-9 << " GFLOPS" << std::endl; 87 | } 88 | } 89 | 90 | 91 | -------------------------------------------------------------------------------- /snippets/sgemm/main.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | const int m = 128; 11 | const int n = 256; 12 | const int k = 1024; 13 | // const int m = 8; 14 | // const int n = 8; 15 | // const int k = 8; 16 | 17 | const int iter = 10; 18 | // Wall-clock time in seconds (gettimeofday, microsecond resolution). 19 | double now_in_sec() { 20 | struct timeval tv; 21 | gettimeofday(&tv, NULL); 22 | return tv.tv_sec + tv.tv_usec * 1e-6; 23 | } 24 | // Reference kernel: c += a * b for row-major a (m x k), b (k x n), c (m x n) with leading dimensions lda, ldb, ldc. 25 | void myblas_sgemm_1(int m, int n, int k, const float *a, int lda, 26 | const float *b, int ldb, float *c, int ldc) { 27 | for (int i = 0; i < m; i++) 28 | for (int j = 0; j < n; j++) { 29 | float sum = c[i * ldc + j]; 30 | #pragma clang loop vectorize(enable) 31 | for (int l = 0; l < k; l++) 32 | sum += a[i * lda + l] * b[l * ldb + j]; 33 | c[i * ldc + j] = sum; 34 | } 35 | } 36 | // AVX/FMA-intrinsic kernel: c = a * b (c is overwritten) in 8x8 output tiles; m, n and k must be multiples of 8. 37 | void
myblas_sgemm_2(int m, int n, int k, const float *a, int lda, 38 | const float *b, int ldb, float *c, int ldc) { 39 | assert(m % 8 == 0); 40 | assert(n % 8 == 0); 41 | assert(k % 8 == 0); 42 | __m256 sum[8] = {_mm256_setzero_ps(), _mm256_setzero_ps(), 43 | _mm256_setzero_ps(), _mm256_setzero_ps(), 44 | _mm256_setzero_ps(), _mm256_setzero_ps(), 45 | _mm256_setzero_ps(), _mm256_setzero_ps()}; 46 | for (int i = 0; i < m; i += 8) 47 | for (int j = 0; j < n; j += 8) { 48 | 49 | #pragma unroll 50 | for (int l = 0; l < 8; l++) 51 | sum[l] = _mm256_setzero_ps(); 52 | for (int l = 0; l < k; l++) { 53 | _mm_prefetch((const char *)(b + (l + 0) * ldb + j), _MM_HINT_T0); 54 | _mm_prefetch((const char *)(b + (l + 1) * ldb + j), _MM_HINT_T0); 55 | _mm_prefetch((const char *)(b + (l + 2) * ldb + j), _MM_HINT_T0); 56 | _mm_prefetch((const char *)(b + (l + 3) * ldb + j), _MM_HINT_T0); 57 | __m256 as[8]; 58 | #pragma unroll 59 | for (int ll = 0; ll < 8; ll++) 60 | as[ll] = _mm256_broadcast_ss(a + (i + ll) * lda + l); 61 | __m256 bs = _mm256_loadu_ps(b + l * ldb + j); 62 | #pragma unroll 63 | for (int ll = 0; ll < 8; ll++) 64 | sum[ll] = _mm256_fmadd_ps(as[ll], bs, sum[ll]); 65 | } 66 | #pragma unroll 67 | for (int l = 0; l < 8; l++) 68 | _mm256_storeu_ps(c + (i + l) * ldc + j, sum[l]); 69 | } 70 | } 71 | 72 | const int simd_lane = 8; 73 | // 8x8x8 step: c (8x8 tile) += a (8x8 tile) * b (8x8 tile), each addressed with its own row stride. 74 | inline void micro_kernel(float *a, int lda, float *b, int ldb, float *c, 75 | int ldc) { 76 | for (int i = 0; i < simd_lane; i++) { 77 | __m256 sum = _mm256_loadu_ps(c + i * ldc); 78 | for (int l = 0; l < simd_lane; l++) { 79 | _mm_prefetch((const char *)(a + i * lda + l), _MM_HINT_T0); 80 | __m256 as = _mm256_broadcast_ss(a + i * lda + l); 81 | __m256 bs = _mm256_loadu_ps(b + l * ldb); 82 | sum = _mm256_fmadd_ps(as, bs, sum); 83 | } 84 | _mm256_storeu_ps(c + i * ldc, sum); 85 | } 86 | } 87 | // Blocked kernel: C = A * B (C is zeroed first) assembled from 8x8x8 micro_kernel calls; m, n and k must be multiples of 8. 88 | void myblas_sgemm_3(int m, int n, int k, float *A, int lda, float *B, int ldb, 89 | float *C, int ldc) { 90 | for (int i = 0; i < m; ++i) { 91 |
#pragma clang loop vectorize(enable) 92 | for (int j = 0; j < n; ++j) { 93 | C[i * ldc + j] = 0; 94 | } 95 | } 96 | 97 | assert(m % simd_lane == 0 && n % simd_lane == 0 && k % simd_lane == 0); // micro_kernel only handles whole 8x8x8 tiles 98 | int kc = simd_lane; 99 | int mc = (m / simd_lane) / simd_lane * simd_lane; if (mc < simd_lane) mc = simd_lane; // panel height: a multiple of mr so micro_kernel never straddles two panels 100 | int mr = simd_lane; 101 | int nr = simd_lane; 102 | for (int i0 = 0; i0 < k; i0 += kc) { 103 | float *Ap = A + i0; // m*kc 104 | float *Bp = B + i0 * ldb; // kc*n 105 | for (int i1 = 0; i1 < m; i1 += mc) { 106 | float *Ai = Ap + i1 * lda; // mc*kc 107 | float *Cp = C + i1 * ldc; // mc*n 108 | for (int i2 = 0; i2 < n; i2 += nr) { 109 | float *Bi = Bp + i2; // kc*nr 110 | float *Ci = Cp + i2; // mc*nr 111 | for (int i3 = 0; i3 < mc && i1 + i3 < m; i3 += mr) { // the last panel may be shorter than mc 112 | // 8x8x8 113 | micro_kernel(Ai + i3 * lda, lda, Bi, ldb, Ci + i3 * ldc, ldc); 114 | } 115 | } 116 | } 117 | } 118 | } 119 | // Fills x[0..n) with rand() / RAND_MAX, i.e. values in [0, 1]. 120 | void fill_random(float *x, int n) { 121 | for (int i = 0; i < n; i++) 122 | x[i] = (float)rand() / (float)RAND_MAX; 123 | } 124 | // True iff |x[i] - y[i]| <= 1e-3 for every i (absolute tolerance only). 125 | bool allclose(const float *x, const float *y, int n) { 126 | for (int i = 0; i < n; i++) { 127 | // std::cout << x[i] << " vs " << y[i] << std::endl; 128 | // std::cout << fabs(x[i] - y[i]) << std::endl; 129 | if (fabs(x[i] - y[i]) > 1e-3) 130 | return false; 131 | } 132 | return true; 133 | } 134 | // Compares cblas_sgemm with myblas_sgemm_3: asserts matching results and prints mean times over 30 calls, for 10 rounds. 135 | int main() { 136 | float *x = (float *)calloc(m * k, sizeof(float)); 137 | float *y = (float *)calloc(k * n, sizeof(float)); 138 | float *cblas_z = (float *)calloc(m * n, sizeof(float)); 139 | float *myblas_z = (float *)calloc(m * n, sizeof(float)); 140 | 141 | fill_random(x, m * k); 142 | fill_random(y, k * n); 143 | 144 | for (int i = 0; i < iter; i++) { 145 | const int ave = 30; 146 | double cblas_elapsed = 0.0; 147 | double myblas_elapsed = 0.0; 148 | 149 | for (int j = 0; j < ave; j++) { 150 | const double cblas_start = now_in_sec(); 151 | cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0, x, k, 152 | y, n, 0.0, cblas_z, n); 153 | cblas_elapsed += now_in_sec() - cblas_start; 154 | 155 | const double myblas_start = now_in_sec(); 156 |
myblas_sgemm_3(m, n, k, x, k, y, n, myblas_z, n); 157 | myblas_elapsed += now_in_sec() - myblas_start; 158 | 159 | assert(allclose(cblas_z, myblas_z, m * n)); 160 | } 161 | 162 | std::cout << "[blis] " << (cblas_elapsed * 1000.0 / ave) << " [ms]" 163 | << std::endl; 164 | std::cout << "[mine] " << (myblas_elapsed * 1000.0 / ave) << " [ms]" 165 | << std::endl; 166 | } 167 | free(x); free(y); free(cblas_z); free(myblas_z); 168 | return 0; 169 | } 170 | -------------------------------------------------------------------------------- /snippets/softmax.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | // Wall-clock time in seconds (gettimeofday, microsecond resolution). 6 | double now_in_sec() { 7 | struct timeval tv; 8 | gettimeofday(&tv, NULL); 9 | return (double)tv.tv_sec + (double)tv.tv_usec / 1000.f / 1000.f; 10 | } 11 | // Row-wise softmax over a fixed 1200 x 100 row-major buffer: output = exp(x - row max) / row sum, with exp from a range-reduced polynomial. 12 | void softmax(float *input_name, float *output_name) { 13 | int batch = 1200; 14 | int axis_len = 100; 15 | // exp(x) = 2^n * p(r) with x = n * ln2 + r; ln2 is split into high/low parts (LOG2HIGH + LOG2LOW ~= -ln2) for accuracy. 16 | const float LOWER_RANGE = -88.37626; 17 | const float ROUNDING_BIAS = 12582912.0; 18 | const float LOG2RECIPROCAL = 1.44269504088896341; 19 | const float LOG2HIGH = -6.93145752e-1; 20 | const float LOG2LOW = -1.42860677e-6; 21 | const float POLY_0 = 0.0013780593872; 22 | const float POLY_1 = 0.0083731245250; 23 | const float POLY_2 = 0.0416695363820; 24 | const float POLY_3 = 0.1666647195816; 25 | const float POLY_4 = 0.4999998509884; 26 | const float POLY_56 = 1.0000000000000; 27 | const int32_t MAXIMUM_EXPONENT = 0x3F800000; 28 | // Rows are independent, so they are spread over 8 OpenMP threads. 29 | #pragma omp parallel for num_threads(8) 30 | for (int i = 0; i < batch; i++) { 31 | const float *input = input_name + i * axis_len; 32 | float *output = output_name + i * axis_len; 33 | 34 | float max = -INFINITY; 35 | for (int j = 0; j < axis_len; j++) { 36 | max = fmaxf(input[j], max); 37 | } 38 | // exp(input[j] - max); every intermediate must stay float, since int truncation breaks the approximation. 39 | float sum = 0.0; 40 | #pragma clang loop vectorize(enable) 41 | for (int j = 0; j < axis_len; j++) { 42 | const float val0 = fmaxf(input[j] - max, LOWER_RANGE); // clamp keeps n (below) within the range the exponent-bit trick can represent 43 | const float biased = fmaf(val0, LOG2RECIPROCAL, ROUNDING_BIAS); // + 1.5 * 2^23 rounds val0 / ln2 to an integer n held in the low mantissa bits 44
| const float m = biased - ROUNDING_BIAS; // n as a float 45 | const float val1 = fmaf(m, LOG2HIGH, val0); // r = val0 - n * ln2 (high part) 46 | const float val2 = fmaf(m, LOG2LOW, val1); // ... then the low part 47 | const union { float f; int32_t i; } biased_bits = {biased}; const int32_t normal = (int32_t)((unsigned)biased_bits.i << 23); // n << 23; union read instead of a pointer cast (strict aliasing), unsigned shift instead of an overflowing signed one 48 | const int32_t normal2 = normal + MAXIMUM_EXPONENT; // add the exponent bias: bit pattern of 2^n 49 | const float p0 = POLY_0; 50 | const float p1 = fmaf(p0, val2, POLY_1); 51 | const float p2 = fmaf(p1, val2, POLY_2); 52 | const float p3 = fmaf(p2, val2, POLY_3); 53 | const float p4 = fmaf(p3, val2, POLY_4); 54 | const float p5 = fmaf(p4, val2, POLY_56); 55 | const float p6 = fmaf(p5, val2, POLY_56); 56 | const union { int32_t i; float f; } scale = {normal2}; const float p7 = p6 * scale.f; // p(r) * 2^n 57 | sum += p7; 58 | output[j] = p7; 59 | } 60 | 61 | const float recip_sum = 1.0 / sum; 62 | #pragma clang loop vectorize(enable) 63 | for (int j = 0; j < axis_len; j++) { 64 | output[j] = output[j] * recip_sum; 65 | } 66 | } 67 | } 68 | // Benchmarks softmax(): 10 warm-up calls, then prints the mean time over 10000 calls. 69 | int main() { 70 | float *input = (float *)calloc(120000, sizeof(float)); // zero-filled so softmax never reads indeterminate values 71 | float *output = (float *)malloc(120000 * sizeof(float)); 72 | 73 | { // warm-up; these timings are discarded 74 | double elapsed = 0; 75 | int attempts = 10; 76 | for (int i = 0; i < attempts; i++) { 77 | double start = now_in_sec(); 78 | softmax(input, output); 79 | elapsed += now_in_sec() - start; 80 | } 81 | } 82 | { 83 | double elapsed = 0; 84 | int attempts = 10000; 85 | for (int i = 0; i < attempts; i++) { 86 | double start = now_in_sec(); 87 | softmax(input, output); 88 | elapsed += now_in_sec() - start; 89 | } 90 | printf("%f ms/call\n", (elapsed / attempts) * 1000.0); 91 | } 92 | free(input); free(output); 93 | return 0; 94 | } 95 | 96 | --------------------------------------------------------------------------------