├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── README.md ├── examples ├── resnet18.rs └── test_image.json ├── iree-sys ├── Cargo.lock ├── Cargo.toml ├── build.rs ├── examples │ ├── resnet18.rs │ ├── simple_mul.rs │ ├── simple_mul_module.vmfb │ └── test_image.json └── src │ ├── helper.rs │ ├── iree │ ├── mod.rs │ └── runtime │ │ ├── api.rs │ │ └── mod.rs │ └── lib.rs ├── scripts ├── dump_mlir.ipynb └── torchscript_resnet18.py ├── src ├── err │ └── mod.rs ├── ffi │ └── mod.rs ├── lib.rs └── types │ ├── allocator.rs │ ├── bytespan.rs │ ├── hal_allocator.rs │ ├── hal_buffer.rs │ ├── hal_device.rs │ ├── mod.rs │ ├── runtime │ ├── call.rs │ ├── instance.rs │ ├── mod.rs │ └── session.rs │ └── status.rs └── tests ├── test_hal.rs └── test_runtime.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | version = 3 4 | 5 | [[package]] 6 | name = "adler" 7 | version = "1.0.2" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" 10 | 11 | [[package]] 12 | name = "aes" 13 | version = "0.7.5" 14 | source = "registry+https://github.com/rust-lang/crates.io-index" 15 | checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8" 16 | dependencies = [ 17 | "cfg-if", 18 | "cipher", 19 | "cpufeatures", 20 | "opaque-debug", 21 | ] 22 | 23 | [[package]] 24 | name = "anyhow" 25 | version = "1.0.69" 26 | source = "registry+https://github.com/rust-lang/crates.io-index" 27 | checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" 28 | 29 | [[package]] 30 | name = "autocfg" 31 | version = "1.1.0" 32 | source = "registry+https://github.com/rust-lang/crates.io-index" 33 | checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" 34 | 35 | [[package]] 36 | name = "base64ct" 37 | version = "1.5.3" 38 | source = "registry+https://github.com/rust-lang/crates.io-index" 39 | checksum = "b645a089122eccb6111b4f81cbc1a49f5900ac4666bb93ac027feaecf15607bf" 40 | 41 | [[package]] 42 | name = "bindgen" 43 | version = "0.63.0" 44 | source = "registry+https://github.com/rust-lang/crates.io-index" 45 | checksum = "36d860121800b2a9a94f9b5604b332d5cffb234ce17609ea479d723dbc9d3885" 46 | dependencies = [ 47 | "bitflags", 48 | "cexpr", 49 | "clang-sys", 50 | "lazy_static", 51 | "lazycell", 52 | "log", 53 | "peeking_take_while", 54 | "proc-macro2", 55 | "quote", 56 | "regex", 57 | "rustc-hash", 58 | "shlex", 59 | "syn", 60 | "which", 61 | ] 62 | 63 | [[package]] 64 | name = "bitflags" 65 | version = "1.3.2" 66 | source = "registry+https://github.com/rust-lang/crates.io-index" 67 | checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" 68 | 69 | [[package]] 70 | name = "block-buffer" 71 | version = "0.10.3" 72 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 73 | checksum = "69cce20737498f97b993470a6e536b8523f0af7892a4f928cceb1ac5e52ebe7e" 74 | dependencies = [ 75 | "generic-array", 76 | ] 77 | 78 | [[package]] 79 | name = "byteorder" 80 | version = "1.4.3" 81 | source = "registry+https://github.com/rust-lang/crates.io-index" 82 | checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" 83 | 84 | [[package]] 85 | name = "bzip2" 86 | version = "0.4.4" 87 | source = "registry+https://github.com/rust-lang/crates.io-index" 88 | checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" 89 | dependencies = [ 90 | "bzip2-sys", 91 | "libc", 92 | ] 93 | 94 | [[package]] 95 | name = "bzip2-sys" 96 | version = "0.1.11+1.0.8" 97 | source = "registry+https://github.com/rust-lang/crates.io-index" 98 | checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" 99 | dependencies = [ 100 | "cc", 101 | "libc", 102 | "pkg-config", 103 | ] 104 | 105 | [[package]] 106 | name = "cc" 107 | version = "1.0.79" 108 | source = "registry+https://github.com/rust-lang/crates.io-index" 109 | checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" 110 | dependencies = [ 111 | "jobserver", 112 | ] 113 | 114 | [[package]] 115 | name = "cexpr" 116 | version = "0.6.0" 117 | source = "registry+https://github.com/rust-lang/crates.io-index" 118 | checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" 119 | dependencies = [ 120 | "nom", 121 | ] 122 | 123 | [[package]] 124 | name = "cfg-if" 125 | version = "1.0.0" 126 | source = "registry+https://github.com/rust-lang/crates.io-index" 127 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 128 | 129 | [[package]] 130 | name = "cipher" 131 | version = "0.3.0" 132 | source = "registry+https://github.com/rust-lang/crates.io-index" 133 | checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7" 134 | 
dependencies = [ 135 | "generic-array", 136 | ] 137 | 138 | [[package]] 139 | name = "clang-sys" 140 | version = "1.4.0" 141 | source = "registry+https://github.com/rust-lang/crates.io-index" 142 | checksum = "fa2e27ae6ab525c3d369ded447057bca5438d86dc3a68f6faafb8269ba82ebf3" 143 | dependencies = [ 144 | "glob", 145 | "libc", 146 | "libloading", 147 | ] 148 | 149 | [[package]] 150 | name = "cmake" 151 | version = "0.1.49" 152 | source = "registry+https://github.com/rust-lang/crates.io-index" 153 | checksum = "db34956e100b30725f2eb215f90d4871051239535632f84fea3bc92722c66b7c" 154 | dependencies = [ 155 | "cc", 156 | ] 157 | 158 | [[package]] 159 | name = "constant_time_eq" 160 | version = "0.1.5" 161 | source = "registry+https://github.com/rust-lang/crates.io-index" 162 | checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" 163 | 164 | [[package]] 165 | name = "cpufeatures" 166 | version = "0.2.5" 167 | source = "registry+https://github.com/rust-lang/crates.io-index" 168 | checksum = "28d997bd5e24a5928dd43e46dc529867e207907fe0b239c3477d924f7f2ca320" 169 | dependencies = [ 170 | "libc", 171 | ] 172 | 173 | [[package]] 174 | name = "crc32fast" 175 | version = "1.3.2" 176 | source = "registry+https://github.com/rust-lang/crates.io-index" 177 | checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" 178 | dependencies = [ 179 | "cfg-if", 180 | ] 181 | 182 | [[package]] 183 | name = "crossbeam-utils" 184 | version = "0.8.14" 185 | source = "registry+https://github.com/rust-lang/crates.io-index" 186 | checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f" 187 | dependencies = [ 188 | "cfg-if", 189 | ] 190 | 191 | [[package]] 192 | name = "crypto-common" 193 | version = "0.1.6" 194 | source = "registry+https://github.com/rust-lang/crates.io-index" 195 | checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" 196 | dependencies = [ 197 | "generic-array", 198 | "typenum", 199 | ] 200 | 
201 | [[package]] 202 | name = "curl" 203 | version = "0.4.44" 204 | source = "registry+https://github.com/rust-lang/crates.io-index" 205 | checksum = "509bd11746c7ac09ebd19f0b17782eae80aadee26237658a6b4808afb5c11a22" 206 | dependencies = [ 207 | "curl-sys", 208 | "libc", 209 | "openssl-probe", 210 | "openssl-sys", 211 | "schannel", 212 | "socket2", 213 | "winapi", 214 | ] 215 | 216 | [[package]] 217 | name = "curl-sys" 218 | version = "0.4.59+curl-7.86.0" 219 | source = "registry+https://github.com/rust-lang/crates.io-index" 220 | checksum = "6cfce34829f448b08f55b7db6d0009e23e2e86a34e8c2b366269bf5799b4a407" 221 | dependencies = [ 222 | "cc", 223 | "libc", 224 | "libz-sys", 225 | "openssl-sys", 226 | "pkg-config", 227 | "vcpkg", 228 | "winapi", 229 | ] 230 | 231 | [[package]] 232 | name = "digest" 233 | version = "0.10.6" 234 | source = "registry+https://github.com/rust-lang/crates.io-index" 235 | checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" 236 | dependencies = [ 237 | "block-buffer", 238 | "crypto-common", 239 | "subtle", 240 | ] 241 | 242 | [[package]] 243 | name = "either" 244 | version = "1.8.1" 245 | source = "registry+https://github.com/rust-lang/crates.io-index" 246 | checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" 247 | 248 | [[package]] 249 | name = "flatbuffers" 250 | version = "23.1.21" 251 | source = "registry+https://github.com/rust-lang/crates.io-index" 252 | checksum = "77f5399c2c9c50ae9418e522842ad362f61ee48b346ac106807bd355a8a7c619" 253 | dependencies = [ 254 | "bitflags", 255 | "rustc_version", 256 | "serde", 257 | ] 258 | 259 | [[package]] 260 | name = "flate2" 261 | version = "1.0.25" 262 | source = "registry+https://github.com/rust-lang/crates.io-index" 263 | checksum = "a8a2db397cb1c8772f31494cb8917e48cd1e64f0fa7efac59fbd741a0a8ce841" 264 | dependencies = [ 265 | "crc32fast", 266 | "miniz_oxide", 267 | ] 268 | 269 | [[package]] 270 | name = "form_urlencoded" 271 | version = 
"1.1.0" 272 | source = "registry+https://github.com/rust-lang/crates.io-index" 273 | checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8" 274 | dependencies = [ 275 | "percent-encoding", 276 | ] 277 | 278 | [[package]] 279 | name = "generic-array" 280 | version = "0.14.6" 281 | source = "registry+https://github.com/rust-lang/crates.io-index" 282 | checksum = "bff49e947297f3312447abdca79f45f4738097cc82b06e72054d2223f601f1b9" 283 | dependencies = [ 284 | "typenum", 285 | "version_check", 286 | ] 287 | 288 | [[package]] 289 | name = "getrandom" 290 | version = "0.2.8" 291 | source = "registry+https://github.com/rust-lang/crates.io-index" 292 | checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" 293 | dependencies = [ 294 | "cfg-if", 295 | "libc", 296 | "wasi", 297 | ] 298 | 299 | [[package]] 300 | name = "git2" 301 | version = "0.16.1" 302 | source = "registry+https://github.com/rust-lang/crates.io-index" 303 | checksum = "ccf7f68c2995f392c49fffb4f95ae2c873297830eb25c6bc4c114ce8f4562acc" 304 | dependencies = [ 305 | "bitflags", 306 | "libc", 307 | "libgit2-sys", 308 | "log", 309 | "openssl-probe", 310 | "openssl-sys", 311 | "url", 312 | ] 313 | 314 | [[package]] 315 | name = "glob" 316 | version = "0.3.1" 317 | source = "registry+https://github.com/rust-lang/crates.io-index" 318 | checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" 319 | 320 | [[package]] 321 | name = "half" 322 | version = "1.8.2" 323 | source = "registry+https://github.com/rust-lang/crates.io-index" 324 | checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" 325 | 326 | [[package]] 327 | name = "hmac" 328 | version = "0.12.1" 329 | source = "registry+https://github.com/rust-lang/crates.io-index" 330 | checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" 331 | dependencies = [ 332 | "digest", 333 | ] 334 | 335 | [[package]] 336 | name = "idna" 337 | version = "0.3.0" 338 | 
source = "registry+https://github.com/rust-lang/crates.io-index" 339 | checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6" 340 | dependencies = [ 341 | "unicode-bidi", 342 | "unicode-normalization", 343 | ] 344 | 345 | [[package]] 346 | name = "iree-rs" 347 | version = "0.1.1" 348 | dependencies = [ 349 | "iree-sys", 350 | "once_cell", 351 | "serde", 352 | "serde_json", 353 | ] 354 | 355 | [[package]] 356 | name = "iree-sys" 357 | version = "0.1.0" 358 | dependencies = [ 359 | "anyhow", 360 | "bindgen", 361 | "cmake", 362 | "flatbuffers", 363 | "git2", 364 | "once_cell", 365 | "pkg-config", 366 | "serde", 367 | "serde_json", 368 | "tch", 369 | ] 370 | 371 | [[package]] 372 | name = "itoa" 373 | version = "1.0.5" 374 | source = "registry+https://github.com/rust-lang/crates.io-index" 375 | checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" 376 | 377 | [[package]] 378 | name = "jobserver" 379 | version = "0.1.25" 380 | source = "registry+https://github.com/rust-lang/crates.io-index" 381 | checksum = "068b1ee6743e4d11fb9c6a1e6064b3693a1b600e7f5f5988047d98b3dc9fb90b" 382 | dependencies = [ 383 | "libc", 384 | ] 385 | 386 | [[package]] 387 | name = "lazy_static" 388 | version = "1.4.0" 389 | source = "registry+https://github.com/rust-lang/crates.io-index" 390 | checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" 391 | 392 | [[package]] 393 | name = "lazycell" 394 | version = "1.3.0" 395 | source = "registry+https://github.com/rust-lang/crates.io-index" 396 | checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" 397 | 398 | [[package]] 399 | name = "libc" 400 | version = "0.2.139" 401 | source = "registry+https://github.com/rust-lang/crates.io-index" 402 | checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" 403 | 404 | [[package]] 405 | name = "libgit2-sys" 406 | version = "0.14.2+1.5.1" 407 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 408 | checksum = "7f3d95f6b51075fe9810a7ae22c7095f12b98005ab364d8544797a825ce946a4" 409 | dependencies = [ 410 | "cc", 411 | "libc", 412 | "libssh2-sys", 413 | "libz-sys", 414 | "openssl-sys", 415 | "pkg-config", 416 | ] 417 | 418 | [[package]] 419 | name = "libloading" 420 | version = "0.7.4" 421 | source = "registry+https://github.com/rust-lang/crates.io-index" 422 | checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" 423 | dependencies = [ 424 | "cfg-if", 425 | "winapi", 426 | ] 427 | 428 | [[package]] 429 | name = "libssh2-sys" 430 | version = "0.2.23" 431 | source = "registry+https://github.com/rust-lang/crates.io-index" 432 | checksum = "b094a36eb4b8b8c8a7b4b8ae43b2944502be3e59cd87687595cf6b0a71b3f4ca" 433 | dependencies = [ 434 | "cc", 435 | "libc", 436 | "libz-sys", 437 | "openssl-sys", 438 | "pkg-config", 439 | "vcpkg", 440 | ] 441 | 442 | [[package]] 443 | name = "libz-sys" 444 | version = "1.1.8" 445 | source = "registry+https://github.com/rust-lang/crates.io-index" 446 | checksum = "9702761c3935f8cc2f101793272e202c72b99da8f4224a19ddcf1279a6450bbf" 447 | dependencies = [ 448 | "cc", 449 | "libc", 450 | "pkg-config", 451 | "vcpkg", 452 | ] 453 | 454 | [[package]] 455 | name = "log" 456 | version = "0.4.17" 457 | source = "registry+https://github.com/rust-lang/crates.io-index" 458 | checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" 459 | dependencies = [ 460 | "cfg-if", 461 | ] 462 | 463 | [[package]] 464 | name = "matrixmultiply" 465 | version = "0.3.2" 466 | source = "registry+https://github.com/rust-lang/crates.io-index" 467 | checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84" 468 | dependencies = [ 469 | "rawpointer", 470 | ] 471 | 472 | [[package]] 473 | name = "memchr" 474 | version = "2.5.0" 475 | source = "registry+https://github.com/rust-lang/crates.io-index" 476 | checksum = 
"2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" 477 | 478 | [[package]] 479 | name = "minimal-lexical" 480 | version = "0.2.1" 481 | source = "registry+https://github.com/rust-lang/crates.io-index" 482 | checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" 483 | 484 | [[package]] 485 | name = "miniz_oxide" 486 | version = "0.6.2" 487 | source = "registry+https://github.com/rust-lang/crates.io-index" 488 | checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa" 489 | dependencies = [ 490 | "adler", 491 | ] 492 | 493 | [[package]] 494 | name = "ndarray" 495 | version = "0.15.6" 496 | source = "registry+https://github.com/rust-lang/crates.io-index" 497 | checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" 498 | dependencies = [ 499 | "matrixmultiply", 500 | "num-complex", 501 | "num-integer", 502 | "num-traits", 503 | "rawpointer", 504 | ] 505 | 506 | [[package]] 507 | name = "nom" 508 | version = "7.1.3" 509 | source = "registry+https://github.com/rust-lang/crates.io-index" 510 | checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" 511 | dependencies = [ 512 | "memchr", 513 | "minimal-lexical", 514 | ] 515 | 516 | [[package]] 517 | name = "num-complex" 518 | version = "0.4.3" 519 | source = "registry+https://github.com/rust-lang/crates.io-index" 520 | checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" 521 | dependencies = [ 522 | "num-traits", 523 | ] 524 | 525 | [[package]] 526 | name = "num-integer" 527 | version = "0.1.45" 528 | source = "registry+https://github.com/rust-lang/crates.io-index" 529 | checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" 530 | dependencies = [ 531 | "autocfg", 532 | "num-traits", 533 | ] 534 | 535 | [[package]] 536 | name = "num-traits" 537 | version = "0.2.15" 538 | source = "registry+https://github.com/rust-lang/crates.io-index" 539 | checksum = 
"578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" 540 | dependencies = [ 541 | "autocfg", 542 | ] 543 | 544 | [[package]] 545 | name = "once_cell" 546 | version = "1.17.1" 547 | source = "registry+https://github.com/rust-lang/crates.io-index" 548 | checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" 549 | 550 | [[package]] 551 | name = "opaque-debug" 552 | version = "0.3.0" 553 | source = "registry+https://github.com/rust-lang/crates.io-index" 554 | checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" 555 | 556 | [[package]] 557 | name = "openssl-probe" 558 | version = "0.1.5" 559 | source = "registry+https://github.com/rust-lang/crates.io-index" 560 | checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" 561 | 562 | [[package]] 563 | name = "openssl-sys" 564 | version = "0.9.80" 565 | source = "registry+https://github.com/rust-lang/crates.io-index" 566 | checksum = "23bbbf7854cd45b83958ebe919f0e8e516793727652e27fda10a8384cfc790b7" 567 | dependencies = [ 568 | "autocfg", 569 | "cc", 570 | "libc", 571 | "pkg-config", 572 | "vcpkg", 573 | ] 574 | 575 | [[package]] 576 | name = "password-hash" 577 | version = "0.4.2" 578 | source = "registry+https://github.com/rust-lang/crates.io-index" 579 | checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" 580 | dependencies = [ 581 | "base64ct", 582 | "rand_core", 583 | "subtle", 584 | ] 585 | 586 | [[package]] 587 | name = "pbkdf2" 588 | version = "0.11.0" 589 | source = "registry+https://github.com/rust-lang/crates.io-index" 590 | checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" 591 | dependencies = [ 592 | "digest", 593 | "hmac", 594 | "password-hash", 595 | "sha2", 596 | ] 597 | 598 | [[package]] 599 | name = "peeking_take_while" 600 | version = "0.1.2" 601 | source = "registry+https://github.com/rust-lang/crates.io-index" 602 | checksum = 
"19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" 603 | 604 | [[package]] 605 | name = "percent-encoding" 606 | version = "2.2.0" 607 | source = "registry+https://github.com/rust-lang/crates.io-index" 608 | checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" 609 | 610 | [[package]] 611 | name = "pkg-config" 612 | version = "0.3.26" 613 | source = "registry+https://github.com/rust-lang/crates.io-index" 614 | checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" 615 | 616 | [[package]] 617 | name = "ppv-lite86" 618 | version = "0.2.17" 619 | source = "registry+https://github.com/rust-lang/crates.io-index" 620 | checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" 621 | 622 | [[package]] 623 | name = "proc-macro2" 624 | version = "1.0.51" 625 | source = "registry+https://github.com/rust-lang/crates.io-index" 626 | checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" 627 | dependencies = [ 628 | "unicode-ident", 629 | ] 630 | 631 | [[package]] 632 | name = "quote" 633 | version = "1.0.23" 634 | source = "registry+https://github.com/rust-lang/crates.io-index" 635 | checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" 636 | dependencies = [ 637 | "proc-macro2", 638 | ] 639 | 640 | [[package]] 641 | name = "rand" 642 | version = "0.8.5" 643 | source = "registry+https://github.com/rust-lang/crates.io-index" 644 | checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" 645 | dependencies = [ 646 | "libc", 647 | "rand_chacha", 648 | "rand_core", 649 | ] 650 | 651 | [[package]] 652 | name = "rand_chacha" 653 | version = "0.3.1" 654 | source = "registry+https://github.com/rust-lang/crates.io-index" 655 | checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" 656 | dependencies = [ 657 | "ppv-lite86", 658 | "rand_core", 659 | ] 660 | 661 | [[package]] 662 | name = "rand_core" 663 | version = 
"0.6.4" 664 | source = "registry+https://github.com/rust-lang/crates.io-index" 665 | checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" 666 | dependencies = [ 667 | "getrandom", 668 | ] 669 | 670 | [[package]] 671 | name = "rawpointer" 672 | version = "0.2.1" 673 | source = "registry+https://github.com/rust-lang/crates.io-index" 674 | checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" 675 | 676 | [[package]] 677 | name = "regex" 678 | version = "1.7.1" 679 | source = "registry+https://github.com/rust-lang/crates.io-index" 680 | checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" 681 | dependencies = [ 682 | "regex-syntax", 683 | ] 684 | 685 | [[package]] 686 | name = "regex-syntax" 687 | version = "0.6.28" 688 | source = "registry+https://github.com/rust-lang/crates.io-index" 689 | checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" 690 | 691 | [[package]] 692 | name = "rustc-hash" 693 | version = "1.1.0" 694 | source = "registry+https://github.com/rust-lang/crates.io-index" 695 | checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" 696 | 697 | [[package]] 698 | name = "rustc_version" 699 | version = "0.4.0" 700 | source = "registry+https://github.com/rust-lang/crates.io-index" 701 | checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" 702 | dependencies = [ 703 | "semver", 704 | ] 705 | 706 | [[package]] 707 | name = "ryu" 708 | version = "1.0.12" 709 | source = "registry+https://github.com/rust-lang/crates.io-index" 710 | checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" 711 | 712 | [[package]] 713 | name = "schannel" 714 | version = "0.1.21" 715 | source = "registry+https://github.com/rust-lang/crates.io-index" 716 | checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" 717 | dependencies = [ 718 | "windows-sys", 719 | ] 720 | 721 | [[package]] 722 | name = 
"semver" 723 | version = "1.0.16" 724 | source = "registry+https://github.com/rust-lang/crates.io-index" 725 | checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a" 726 | 727 | [[package]] 728 | name = "serde" 729 | version = "1.0.152" 730 | source = "registry+https://github.com/rust-lang/crates.io-index" 731 | checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb" 732 | dependencies = [ 733 | "serde_derive", 734 | ] 735 | 736 | [[package]] 737 | name = "serde_derive" 738 | version = "1.0.152" 739 | source = "registry+https://github.com/rust-lang/crates.io-index" 740 | checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" 741 | dependencies = [ 742 | "proc-macro2", 743 | "quote", 744 | "syn", 745 | ] 746 | 747 | [[package]] 748 | name = "serde_json" 749 | version = "1.0.93" 750 | source = "registry+https://github.com/rust-lang/crates.io-index" 751 | checksum = "cad406b69c91885b5107daf2c29572f6c8cdb3c66826821e286c533490c0bc76" 752 | dependencies = [ 753 | "itoa", 754 | "ryu", 755 | "serde", 756 | ] 757 | 758 | [[package]] 759 | name = "sha1" 760 | version = "0.10.5" 761 | source = "registry+https://github.com/rust-lang/crates.io-index" 762 | checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" 763 | dependencies = [ 764 | "cfg-if", 765 | "cpufeatures", 766 | "digest", 767 | ] 768 | 769 | [[package]] 770 | name = "sha2" 771 | version = "0.10.6" 772 | source = "registry+https://github.com/rust-lang/crates.io-index" 773 | checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" 774 | dependencies = [ 775 | "cfg-if", 776 | "cpufeatures", 777 | "digest", 778 | ] 779 | 780 | [[package]] 781 | name = "shlex" 782 | version = "1.1.0" 783 | source = "registry+https://github.com/rust-lang/crates.io-index" 784 | checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" 785 | 786 | [[package]] 787 | name = "socket2" 788 | version = "0.4.7" 789 
| source = "registry+https://github.com/rust-lang/crates.io-index" 790 | checksum = "02e2d2db9033d13a1567121ddd7a095ee144db4e1ca1b1bda3419bc0da294ebd" 791 | dependencies = [ 792 | "libc", 793 | "winapi", 794 | ] 795 | 796 | [[package]] 797 | name = "subtle" 798 | version = "2.4.1" 799 | source = "registry+https://github.com/rust-lang/crates.io-index" 800 | checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" 801 | 802 | [[package]] 803 | name = "syn" 804 | version = "1.0.107" 805 | source = "registry+https://github.com/rust-lang/crates.io-index" 806 | checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" 807 | dependencies = [ 808 | "proc-macro2", 809 | "quote", 810 | "unicode-ident", 811 | ] 812 | 813 | [[package]] 814 | name = "tch" 815 | version = "0.10.1" 816 | source = "registry+https://github.com/rust-lang/crates.io-index" 817 | checksum = "f4e8ecac1bcd6c92726de9b1e998aa99b3977af0716992382dfd1171289b9575" 818 | dependencies = [ 819 | "half", 820 | "lazy_static", 821 | "libc", 822 | "ndarray", 823 | "rand", 824 | "thiserror", 825 | "torch-sys", 826 | "zip", 827 | ] 828 | 829 | [[package]] 830 | name = "thiserror" 831 | version = "1.0.38" 832 | source = "registry+https://github.com/rust-lang/crates.io-index" 833 | checksum = "6a9cd18aa97d5c45c6603caea1da6628790b37f7a34b6ca89522331c5180fed0" 834 | dependencies = [ 835 | "thiserror-impl", 836 | ] 837 | 838 | [[package]] 839 | name = "thiserror-impl" 840 | version = "1.0.38" 841 | source = "registry+https://github.com/rust-lang/crates.io-index" 842 | checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f" 843 | dependencies = [ 844 | "proc-macro2", 845 | "quote", 846 | "syn", 847 | ] 848 | 849 | [[package]] 850 | name = "time" 851 | version = "0.3.17" 852 | source = "registry+https://github.com/rust-lang/crates.io-index" 853 | checksum = "a561bf4617eebd33bca6434b988f39ed798e527f51a1e797d0ee4f61c0a38376" 854 | dependencies = [ 855 | "serde", 856 
| "time-core", 857 | ] 858 | 859 | [[package]] 860 | name = "time-core" 861 | version = "0.1.0" 862 | source = "registry+https://github.com/rust-lang/crates.io-index" 863 | checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd" 864 | 865 | [[package]] 866 | name = "tinyvec" 867 | version = "1.6.0" 868 | source = "registry+https://github.com/rust-lang/crates.io-index" 869 | checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" 870 | dependencies = [ 871 | "tinyvec_macros", 872 | ] 873 | 874 | [[package]] 875 | name = "tinyvec_macros" 876 | version = "0.1.1" 877 | source = "registry+https://github.com/rust-lang/crates.io-index" 878 | checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" 879 | 880 | [[package]] 881 | name = "torch-sys" 882 | version = "0.10.0" 883 | source = "registry+https://github.com/rust-lang/crates.io-index" 884 | checksum = "877dbdc2732bdb118a71c94d0004333d29f76ebb5e88f193a3abe068f7bd6de9" 885 | dependencies = [ 886 | "anyhow", 887 | "cc", 888 | "curl", 889 | "libc", 890 | "zip", 891 | ] 892 | 893 | [[package]] 894 | name = "typenum" 895 | version = "1.16.0" 896 | source = "registry+https://github.com/rust-lang/crates.io-index" 897 | checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" 898 | 899 | [[package]] 900 | name = "unicode-bidi" 901 | version = "0.3.10" 902 | source = "registry+https://github.com/rust-lang/crates.io-index" 903 | checksum = "d54675592c1dbefd78cbd98db9bacd89886e1ca50692a0692baefffdeb92dd58" 904 | 905 | [[package]] 906 | name = "unicode-ident" 907 | version = "1.0.6" 908 | source = "registry+https://github.com/rust-lang/crates.io-index" 909 | checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" 910 | 911 | [[package]] 912 | name = "unicode-normalization" 913 | version = "0.1.22" 914 | source = "registry+https://github.com/rust-lang/crates.io-index" 915 | checksum = 
"5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" 916 | dependencies = [ 917 | "tinyvec", 918 | ] 919 | 920 | [[package]] 921 | name = "url" 922 | version = "2.3.1" 923 | source = "registry+https://github.com/rust-lang/crates.io-index" 924 | checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643" 925 | dependencies = [ 926 | "form_urlencoded", 927 | "idna", 928 | "percent-encoding", 929 | ] 930 | 931 | [[package]] 932 | name = "vcpkg" 933 | version = "0.2.15" 934 | source = "registry+https://github.com/rust-lang/crates.io-index" 935 | checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" 936 | 937 | [[package]] 938 | name = "version_check" 939 | version = "0.9.4" 940 | source = "registry+https://github.com/rust-lang/crates.io-index" 941 | checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" 942 | 943 | [[package]] 944 | name = "wasi" 945 | version = "0.11.0+wasi-snapshot-preview1" 946 | source = "registry+https://github.com/rust-lang/crates.io-index" 947 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 948 | 949 | [[package]] 950 | name = "which" 951 | version = "4.4.0" 952 | source = "registry+https://github.com/rust-lang/crates.io-index" 953 | checksum = "2441c784c52b289a054b7201fc93253e288f094e2f4be9058343127c4226a269" 954 | dependencies = [ 955 | "either", 956 | "libc", 957 | "once_cell", 958 | ] 959 | 960 | [[package]] 961 | name = "winapi" 962 | version = "0.3.9" 963 | source = "registry+https://github.com/rust-lang/crates.io-index" 964 | checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" 965 | dependencies = [ 966 | "winapi-i686-pc-windows-gnu", 967 | "winapi-x86_64-pc-windows-gnu", 968 | ] 969 | 970 | [[package]] 971 | name = "winapi-i686-pc-windows-gnu" 972 | version = "0.4.0" 973 | source = "registry+https://github.com/rust-lang/crates.io-index" 974 | checksum = 
"ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" 975 | 976 | [[package]] 977 | name = "winapi-x86_64-pc-windows-gnu" 978 | version = "0.4.0" 979 | source = "registry+https://github.com/rust-lang/crates.io-index" 980 | checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" 981 | 982 | [[package]] 983 | name = "windows-sys" 984 | version = "0.42.0" 985 | source = "registry+https://github.com/rust-lang/crates.io-index" 986 | checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" 987 | dependencies = [ 988 | "windows_aarch64_gnullvm", 989 | "windows_aarch64_msvc", 990 | "windows_i686_gnu", 991 | "windows_i686_msvc", 992 | "windows_x86_64_gnu", 993 | "windows_x86_64_gnullvm", 994 | "windows_x86_64_msvc", 995 | ] 996 | 997 | [[package]] 998 | name = "windows_aarch64_gnullvm" 999 | version = "0.42.1" 1000 | source = "registry+https://github.com/rust-lang/crates.io-index" 1001 | checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" 1002 | 1003 | [[package]] 1004 | name = "windows_aarch64_msvc" 1005 | version = "0.42.1" 1006 | source = "registry+https://github.com/rust-lang/crates.io-index" 1007 | checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" 1008 | 1009 | [[package]] 1010 | name = "windows_i686_gnu" 1011 | version = "0.42.1" 1012 | source = "registry+https://github.com/rust-lang/crates.io-index" 1013 | checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" 1014 | 1015 | [[package]] 1016 | name = "windows_i686_msvc" 1017 | version = "0.42.1" 1018 | source = "registry+https://github.com/rust-lang/crates.io-index" 1019 | checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" 1020 | 1021 | [[package]] 1022 | name = "windows_x86_64_gnu" 1023 | version = "0.42.1" 1024 | source = "registry+https://github.com/rust-lang/crates.io-index" 1025 | checksum = 
"c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" 1026 | 1027 | [[package]] 1028 | name = "windows_x86_64_gnullvm" 1029 | version = "0.42.1" 1030 | source = "registry+https://github.com/rust-lang/crates.io-index" 1031 | checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" 1032 | 1033 | [[package]] 1034 | name = "windows_x86_64_msvc" 1035 | version = "0.42.1" 1036 | source = "registry+https://github.com/rust-lang/crates.io-index" 1037 | checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" 1038 | 1039 | [[package]] 1040 | name = "zip" 1041 | version = "0.6.4" 1042 | source = "registry+https://github.com/rust-lang/crates.io-index" 1043 | checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef" 1044 | dependencies = [ 1045 | "aes", 1046 | "byteorder", 1047 | "bzip2", 1048 | "constant_time_eq", 1049 | "crc32fast", 1050 | "crossbeam-utils", 1051 | "flate2", 1052 | "hmac", 1053 | "pbkdf2", 1054 | "sha1", 1055 | "time", 1056 | "zstd", 1057 | ] 1058 | 1059 | [[package]] 1060 | name = "zstd" 1061 | version = "0.11.2+zstd.1.5.2" 1062 | source = "registry+https://github.com/rust-lang/crates.io-index" 1063 | checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" 1064 | dependencies = [ 1065 | "zstd-safe", 1066 | ] 1067 | 1068 | [[package]] 1069 | name = "zstd-safe" 1070 | version = "5.0.2+zstd.1.5.2" 1071 | source = "registry+https://github.com/rust-lang/crates.io-index" 1072 | checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" 1073 | dependencies = [ 1074 | "libc", 1075 | "zstd-sys", 1076 | ] 1077 | 1078 | [[package]] 1079 | name = "zstd-sys" 1080 | version = "2.0.7+zstd.1.5.4" 1081 | source = "registry+https://github.com/rust-lang/crates.io-index" 1082 | checksum = "94509c3ba2fe55294d752b79842c530ccfab760192521df74a081a78d2b3c7f5" 1083 | dependencies = [ 1084 | "cc", 1085 | "libc", 1086 | "pkg-config", 1087 | ] 1088 | 
-------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "iree-rs" 3 | version = "0.1.1" 4 | edition = "2021" 5 | description = "Rustic bindings for the IREE runtime" 6 | license = "MIT" 7 | repository = "https://github.com/SamKG/iree-rs" 8 | readme = "README.md" 9 | 10 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 11 | 12 | [dependencies] 13 | iree-sys = { path = "iree-sys", version = "0.1.0" } 14 | 15 | [dev-dependencies] 16 | serde = { version = "1.0.152", features = ["derive"] } 17 | serde_json = "1.0.93" 18 | once_cell = "1.17.0" 19 | 20 | [workspace] 21 | members = ["iree-sys"] 22 | 23 | 24 | [[example]] 25 | name = "resnet18" 26 | test = true 27 | bench = true 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # iree-rs 2 | 3 | ## What is this? 4 | This crate contains rustic bindings for the [IREE](https://iree-org.github.io/iree/) runtime. 5 | 6 | ## Building 7 | In order to build, iree-rs requires the following to be available on your machine: 8 | - clang/clang++ (tested with v12.0.1, but other versions may also work) 9 | - git 10 | 11 | iree-rs clones and builds the [main branch of the IREE repo](https://github.com/iree-org/iree) at build time, so you don't need to have IREE pre-installed on your machine. 12 | 13 | ## Examples 14 | Examples for iree-rs are available [in the repository](https://github.com/SamKG/iree-rs/tree/main/examples). 15 | 16 | Since some examples require model weights, you may have to run [scripts](https://github.com/SamKG/iree-rs/tree/main/scripts) to get the required data ahead of time.
17 | 18 | -------------------------------------------------------------------------------- /examples/resnet18.rs: -------------------------------------------------------------------------------- 1 | use iree_rs::{ 2 | err::IreeError, 3 | types::{ 4 | allocator::IreeAllocator, 5 | bytespan::IreeConstByteSpan, 6 | hal_buffer::{IreeHalBufferView, IreeHalBufferViewParamsBuilder}, 7 | runtime::{ 8 | instance::{IreeRuntimeInstance, IreeRuntimeInstanceOptionsBuilder}, 9 | session::{IreeRuntimeSession, IreeRuntimeSessionOptionsBuilder}, 10 | }, 11 | }, 12 | }; 13 | use iree_sys::iree::runtime::api::{ 14 | iree_hal_buffer_usage_bits_t_IREE_HAL_BUFFER_USAGE_DEFAULT, 15 | iree_hal_element_types_t_IREE_HAL_ELEMENT_TYPE_FLOAT_32, 16 | iree_hal_encoding_types_t_IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, 17 | iree_hal_memory_type_bits_t_IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, iree_runtime_call_flags_t, 18 | }; 19 | use once_cell::sync::Lazy; 20 | use serde::{Deserialize, Serialize}; 21 | 22 | #[derive(Deserialize, Serialize)] 23 | struct Image { 24 | data: Vec, 25 | shape: Vec, 26 | } 27 | 28 | static RESNET18_VMFB: Lazy> = Lazy::new(|| include_bytes!("resnet18.vmfb").to_vec()); 29 | static TEST_IMAGE: Lazy> = Lazy::new(|| include_bytes!("test_image.json").to_vec()); 30 | 31 | pub fn run_resnet18() -> Result<(), IreeError> { 32 | // create a runtime instance 33 | let instance = IreeRuntimeInstance::try_from_options( 34 | &IreeRuntimeInstanceOptionsBuilder::default() 35 | .use_all_available_drivers() 36 | .build(), 37 | &IreeAllocator::system_allocator(), 38 | )?; 39 | 40 | // create a device 41 | let device = instance.try_create_default_device("local-task")?; 42 | 43 | // get host allocator 44 | let allocator = instance.host_allocator(); 45 | 46 | // create a session 47 | let session = IreeRuntimeSession::create_with_device( 48 | &instance, 49 | &IreeRuntimeSessionOptionsBuilder::default().build(), 50 | &device, 51 | &allocator, 52 | )?; 53 | 54 | // load resnet18 vmfb to session 55 | 
session.append_bytecode_module_from_memory(RESNET18_VMFB.as_slice(), &allocator)?; 56 | 57 | // // get the entry function 58 | let mut call = session.get_call_by_name("module.forward")?; 59 | 60 | // load input image 61 | let j: Image = serde_json::from_slice(&TEST_IMAGE).unwrap(); 62 | 63 | // get device allocator 64 | let device_allocator = session.device_allocator(); 65 | 66 | // convert image to const byte span 67 | let bytespan = IreeConstByteSpan::from_slice(&j.data); 68 | let image_shape = j.shape; 69 | let buffer_params = IreeHalBufferViewParamsBuilder::default() 70 | .type_(iree_hal_memory_type_bits_t_IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL.0) 71 | .access(0) 72 | .usage(iree_hal_buffer_usage_bits_t_IREE_HAL_BUFFER_USAGE_DEFAULT.0) 73 | .build(); 74 | 75 | // create hal buffer view 76 | let input = IreeHalBufferView::allocate_buffer( 77 | &device_allocator, 78 | &image_shape, 79 | iree_hal_element_types_t_IREE_HAL_ELEMENT_TYPE_FLOAT_32, 80 | iree_hal_encoding_types_t_IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, 81 | &buffer_params, 82 | &bytespan, 83 | )?; 84 | 85 | // push input to call 86 | call.inputs_push_back_buffer_view(&input)?; 87 | 88 | // invoke call 89 | call.invoke(iree_runtime_call_flags_t::default())?; 90 | 91 | // pop output from call 92 | let output = call.outputs_pop_front_buffer_view()?; 93 | 94 | println!("output: {}", output); 95 | 96 | Ok(()) 97 | } 98 | 99 | pub fn main() { 100 | run_resnet18().unwrap(); 101 | } 102 | 103 | #[cfg(test)] 104 | mod tests { 105 | use super::*; 106 | #[test] 107 | fn test_resnet18() { 108 | run_resnet18().unwrap(); 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /iree-sys/Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | version = 3 4 | 5 | [[package]] 6 | name = "adler" 7 | version = "1.0.2" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" 10 | 11 | [[package]] 12 | name = "aes" 13 | version = "0.7.5" 14 | source = "registry+https://github.com/rust-lang/crates.io-index" 15 | checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8" 16 | dependencies = [ 17 | "cfg-if", 18 | "cipher", 19 | "cpufeatures", 20 | "opaque-debug", 21 | ] 22 | 23 | [[package]] 24 | name = "anyhow" 25 | version = "1.0.69" 26 | source = "registry+https://github.com/rust-lang/crates.io-index" 27 | checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" 28 | 29 | [[package]] 30 | name = "autocfg" 31 | version = "1.1.0" 32 | source = "registry+https://github.com/rust-lang/crates.io-index" 33 | checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" 34 | 35 | [[package]] 36 | name = "base64ct" 37 | version = "1.5.3" 38 | source = "registry+https://github.com/rust-lang/crates.io-index" 39 | checksum = "b645a089122eccb6111b4f81cbc1a49f5900ac4666bb93ac027feaecf15607bf" 40 | 41 | [[package]] 42 | name = "bindgen" 43 | version = "0.63.0" 44 | source = "registry+https://github.com/rust-lang/crates.io-index" 45 | checksum = "36d860121800b2a9a94f9b5604b332d5cffb234ce17609ea479d723dbc9d3885" 46 | dependencies = [ 47 | "bitflags", 48 | "cexpr", 49 | "clang-sys", 50 | "lazy_static", 51 | "lazycell", 52 | "log", 53 | "peeking_take_while", 54 | "proc-macro2", 55 | "quote", 56 | "regex", 57 | "rustc-hash", 58 | "shlex", 59 | "syn", 60 | "which", 61 | ] 62 | 63 | [[package]] 64 | name = "bitflags" 65 | version = "1.3.2" 66 | source = "registry+https://github.com/rust-lang/crates.io-index" 67 | checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" 68 | 69 | [[package]] 70 | name = "block-buffer" 71 | version = "0.10.3" 72 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 73 | checksum = "69cce20737498f97b993470a6e536b8523f0af7892a4f928cceb1ac5e52ebe7e" 74 | dependencies = [ 75 | "generic-array", 76 | ] 77 | 78 | [[package]] 79 | name = "byteorder" 80 | version = "1.4.3" 81 | source = "registry+https://github.com/rust-lang/crates.io-index" 82 | checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" 83 | 84 | [[package]] 85 | name = "bzip2" 86 | version = "0.4.4" 87 | source = "registry+https://github.com/rust-lang/crates.io-index" 88 | checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" 89 | dependencies = [ 90 | "bzip2-sys", 91 | "libc", 92 | ] 93 | 94 | [[package]] 95 | name = "bzip2-sys" 96 | version = "0.1.11+1.0.8" 97 | source = "registry+https://github.com/rust-lang/crates.io-index" 98 | checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" 99 | dependencies = [ 100 | "cc", 101 | "libc", 102 | "pkg-config", 103 | ] 104 | 105 | [[package]] 106 | name = "cc" 107 | version = "1.0.79" 108 | source = "registry+https://github.com/rust-lang/crates.io-index" 109 | checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" 110 | dependencies = [ 111 | "jobserver", 112 | ] 113 | 114 | [[package]] 115 | name = "cexpr" 116 | version = "0.6.0" 117 | source = "registry+https://github.com/rust-lang/crates.io-index" 118 | checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" 119 | dependencies = [ 120 | "nom", 121 | ] 122 | 123 | [[package]] 124 | name = "cfg-if" 125 | version = "1.0.0" 126 | source = "registry+https://github.com/rust-lang/crates.io-index" 127 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 128 | 129 | [[package]] 130 | name = "cipher" 131 | version = "0.3.0" 132 | source = "registry+https://github.com/rust-lang/crates.io-index" 133 | checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7" 134 | 
dependencies = [ 135 | "generic-array", 136 | ] 137 | 138 | [[package]] 139 | name = "clang-sys" 140 | version = "1.4.0" 141 | source = "registry+https://github.com/rust-lang/crates.io-index" 142 | checksum = "fa2e27ae6ab525c3d369ded447057bca5438d86dc3a68f6faafb8269ba82ebf3" 143 | dependencies = [ 144 | "glob", 145 | "libc", 146 | "libloading", 147 | ] 148 | 149 | [[package]] 150 | name = "constant_time_eq" 151 | version = "0.1.5" 152 | source = "registry+https://github.com/rust-lang/crates.io-index" 153 | checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" 154 | 155 | [[package]] 156 | name = "cpufeatures" 157 | version = "0.2.5" 158 | source = "registry+https://github.com/rust-lang/crates.io-index" 159 | checksum = "28d997bd5e24a5928dd43e46dc529867e207907fe0b239c3477d924f7f2ca320" 160 | dependencies = [ 161 | "libc", 162 | ] 163 | 164 | [[package]] 165 | name = "crc32fast" 166 | version = "1.3.2" 167 | source = "registry+https://github.com/rust-lang/crates.io-index" 168 | checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" 169 | dependencies = [ 170 | "cfg-if", 171 | ] 172 | 173 | [[package]] 174 | name = "crossbeam-utils" 175 | version = "0.8.14" 176 | source = "registry+https://github.com/rust-lang/crates.io-index" 177 | checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f" 178 | dependencies = [ 179 | "cfg-if", 180 | ] 181 | 182 | [[package]] 183 | name = "crypto-common" 184 | version = "0.1.6" 185 | source = "registry+https://github.com/rust-lang/crates.io-index" 186 | checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" 187 | dependencies = [ 188 | "generic-array", 189 | "typenum", 190 | ] 191 | 192 | [[package]] 193 | name = "curl" 194 | version = "0.4.44" 195 | source = "registry+https://github.com/rust-lang/crates.io-index" 196 | checksum = "509bd11746c7ac09ebd19f0b17782eae80aadee26237658a6b4808afb5c11a22" 197 | dependencies = [ 198 | "curl-sys", 199 | 
"libc", 200 | "openssl-probe", 201 | "openssl-sys", 202 | "schannel", 203 | "socket2", 204 | "winapi", 205 | ] 206 | 207 | [[package]] 208 | name = "curl-sys" 209 | version = "0.4.59+curl-7.86.0" 210 | source = "registry+https://github.com/rust-lang/crates.io-index" 211 | checksum = "6cfce34829f448b08f55b7db6d0009e23e2e86a34e8c2b366269bf5799b4a407" 212 | dependencies = [ 213 | "cc", 214 | "libc", 215 | "libz-sys", 216 | "openssl-sys", 217 | "pkg-config", 218 | "vcpkg", 219 | "winapi", 220 | ] 221 | 222 | [[package]] 223 | name = "digest" 224 | version = "0.10.6" 225 | source = "registry+https://github.com/rust-lang/crates.io-index" 226 | checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" 227 | dependencies = [ 228 | "block-buffer", 229 | "crypto-common", 230 | "subtle", 231 | ] 232 | 233 | [[package]] 234 | name = "either" 235 | version = "1.8.1" 236 | source = "registry+https://github.com/rust-lang/crates.io-index" 237 | checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" 238 | 239 | [[package]] 240 | name = "flatbuffers" 241 | version = "23.1.21" 242 | source = "registry+https://github.com/rust-lang/crates.io-index" 243 | checksum = "77f5399c2c9c50ae9418e522842ad362f61ee48b346ac106807bd355a8a7c619" 244 | dependencies = [ 245 | "bitflags", 246 | "rustc_version", 247 | "serde", 248 | ] 249 | 250 | [[package]] 251 | name = "flate2" 252 | version = "1.0.25" 253 | source = "registry+https://github.com/rust-lang/crates.io-index" 254 | checksum = "a8a2db397cb1c8772f31494cb8917e48cd1e64f0fa7efac59fbd741a0a8ce841" 255 | dependencies = [ 256 | "crc32fast", 257 | "miniz_oxide", 258 | ] 259 | 260 | [[package]] 261 | name = "generic-array" 262 | version = "0.14.6" 263 | source = "registry+https://github.com/rust-lang/crates.io-index" 264 | checksum = "bff49e947297f3312447abdca79f45f4738097cc82b06e72054d2223f601f1b9" 265 | dependencies = [ 266 | "typenum", 267 | "version_check", 268 | ] 269 | 270 | [[package]] 271 | name = 
"getrandom" 272 | version = "0.2.8" 273 | source = "registry+https://github.com/rust-lang/crates.io-index" 274 | checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" 275 | dependencies = [ 276 | "cfg-if", 277 | "libc", 278 | "wasi", 279 | ] 280 | 281 | [[package]] 282 | name = "glob" 283 | version = "0.3.1" 284 | source = "registry+https://github.com/rust-lang/crates.io-index" 285 | checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" 286 | 287 | [[package]] 288 | name = "half" 289 | version = "1.8.2" 290 | source = "registry+https://github.com/rust-lang/crates.io-index" 291 | checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" 292 | 293 | [[package]] 294 | name = "hmac" 295 | version = "0.12.1" 296 | source = "registry+https://github.com/rust-lang/crates.io-index" 297 | checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" 298 | dependencies = [ 299 | "digest", 300 | ] 301 | 302 | [[package]] 303 | name = "iree-sys" 304 | version = "0.1.0" 305 | dependencies = [ 306 | "anyhow", 307 | "bindgen", 308 | "flatbuffers", 309 | "once_cell", 310 | "serde", 311 | "serde_json", 312 | "tch", 313 | ] 314 | 315 | [[package]] 316 | name = "itoa" 317 | version = "1.0.5" 318 | source = "registry+https://github.com/rust-lang/crates.io-index" 319 | checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" 320 | 321 | [[package]] 322 | name = "jobserver" 323 | version = "0.1.25" 324 | source = "registry+https://github.com/rust-lang/crates.io-index" 325 | checksum = "068b1ee6743e4d11fb9c6a1e6064b3693a1b600e7f5f5988047d98b3dc9fb90b" 326 | dependencies = [ 327 | "libc", 328 | ] 329 | 330 | [[package]] 331 | name = "lazy_static" 332 | version = "1.4.0" 333 | source = "registry+https://github.com/rust-lang/crates.io-index" 334 | checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" 335 | 336 | [[package]] 337 | name = "lazycell" 338 | version = 
"1.3.0" 339 | source = "registry+https://github.com/rust-lang/crates.io-index" 340 | checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" 341 | 342 | [[package]] 343 | name = "libc" 344 | version = "0.2.139" 345 | source = "registry+https://github.com/rust-lang/crates.io-index" 346 | checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" 347 | 348 | [[package]] 349 | name = "libloading" 350 | version = "0.7.4" 351 | source = "registry+https://github.com/rust-lang/crates.io-index" 352 | checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" 353 | dependencies = [ 354 | "cfg-if", 355 | "winapi", 356 | ] 357 | 358 | [[package]] 359 | name = "libz-sys" 360 | version = "1.1.8" 361 | source = "registry+https://github.com/rust-lang/crates.io-index" 362 | checksum = "9702761c3935f8cc2f101793272e202c72b99da8f4224a19ddcf1279a6450bbf" 363 | dependencies = [ 364 | "cc", 365 | "libc", 366 | "pkg-config", 367 | "vcpkg", 368 | ] 369 | 370 | [[package]] 371 | name = "log" 372 | version = "0.4.17" 373 | source = "registry+https://github.com/rust-lang/crates.io-index" 374 | checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" 375 | dependencies = [ 376 | "cfg-if", 377 | ] 378 | 379 | [[package]] 380 | name = "matrixmultiply" 381 | version = "0.3.2" 382 | source = "registry+https://github.com/rust-lang/crates.io-index" 383 | checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84" 384 | dependencies = [ 385 | "rawpointer", 386 | ] 387 | 388 | [[package]] 389 | name = "memchr" 390 | version = "2.5.0" 391 | source = "registry+https://github.com/rust-lang/crates.io-index" 392 | checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" 393 | 394 | [[package]] 395 | name = "minimal-lexical" 396 | version = "0.2.1" 397 | source = "registry+https://github.com/rust-lang/crates.io-index" 398 | checksum = 
"68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" 399 | 400 | [[package]] 401 | name = "miniz_oxide" 402 | version = "0.6.2" 403 | source = "registry+https://github.com/rust-lang/crates.io-index" 404 | checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa" 405 | dependencies = [ 406 | "adler", 407 | ] 408 | 409 | [[package]] 410 | name = "ndarray" 411 | version = "0.15.6" 412 | source = "registry+https://github.com/rust-lang/crates.io-index" 413 | checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" 414 | dependencies = [ 415 | "matrixmultiply", 416 | "num-complex", 417 | "num-integer", 418 | "num-traits", 419 | "rawpointer", 420 | ] 421 | 422 | [[package]] 423 | name = "nom" 424 | version = "7.1.3" 425 | source = "registry+https://github.com/rust-lang/crates.io-index" 426 | checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" 427 | dependencies = [ 428 | "memchr", 429 | "minimal-lexical", 430 | ] 431 | 432 | [[package]] 433 | name = "num-complex" 434 | version = "0.4.3" 435 | source = "registry+https://github.com/rust-lang/crates.io-index" 436 | checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" 437 | dependencies = [ 438 | "num-traits", 439 | ] 440 | 441 | [[package]] 442 | name = "num-integer" 443 | version = "0.1.45" 444 | source = "registry+https://github.com/rust-lang/crates.io-index" 445 | checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" 446 | dependencies = [ 447 | "autocfg", 448 | "num-traits", 449 | ] 450 | 451 | [[package]] 452 | name = "num-traits" 453 | version = "0.2.15" 454 | source = "registry+https://github.com/rust-lang/crates.io-index" 455 | checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" 456 | dependencies = [ 457 | "autocfg", 458 | ] 459 | 460 | [[package]] 461 | name = "once_cell" 462 | version = "1.17.0" 463 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 464 | checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" 465 | 466 | [[package]] 467 | name = "opaque-debug" 468 | version = "0.3.0" 469 | source = "registry+https://github.com/rust-lang/crates.io-index" 470 | checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" 471 | 472 | [[package]] 473 | name = "openssl-probe" 474 | version = "0.1.5" 475 | source = "registry+https://github.com/rust-lang/crates.io-index" 476 | checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" 477 | 478 | [[package]] 479 | name = "openssl-sys" 480 | version = "0.9.80" 481 | source = "registry+https://github.com/rust-lang/crates.io-index" 482 | checksum = "23bbbf7854cd45b83958ebe919f0e8e516793727652e27fda10a8384cfc790b7" 483 | dependencies = [ 484 | "autocfg", 485 | "cc", 486 | "libc", 487 | "pkg-config", 488 | "vcpkg", 489 | ] 490 | 491 | [[package]] 492 | name = "password-hash" 493 | version = "0.4.2" 494 | source = "registry+https://github.com/rust-lang/crates.io-index" 495 | checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" 496 | dependencies = [ 497 | "base64ct", 498 | "rand_core", 499 | "subtle", 500 | ] 501 | 502 | [[package]] 503 | name = "pbkdf2" 504 | version = "0.11.0" 505 | source = "registry+https://github.com/rust-lang/crates.io-index" 506 | checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" 507 | dependencies = [ 508 | "digest", 509 | "hmac", 510 | "password-hash", 511 | "sha2", 512 | ] 513 | 514 | [[package]] 515 | name = "peeking_take_while" 516 | version = "0.1.2" 517 | source = "registry+https://github.com/rust-lang/crates.io-index" 518 | checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" 519 | 520 | [[package]] 521 | name = "pkg-config" 522 | version = "0.3.26" 523 | source = "registry+https://github.com/rust-lang/crates.io-index" 524 | checksum = 
"6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" 525 | 526 | [[package]] 527 | name = "ppv-lite86" 528 | version = "0.2.17" 529 | source = "registry+https://github.com/rust-lang/crates.io-index" 530 | checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" 531 | 532 | [[package]] 533 | name = "proc-macro2" 534 | version = "1.0.51" 535 | source = "registry+https://github.com/rust-lang/crates.io-index" 536 | checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" 537 | dependencies = [ 538 | "unicode-ident", 539 | ] 540 | 541 | [[package]] 542 | name = "quote" 543 | version = "1.0.23" 544 | source = "registry+https://github.com/rust-lang/crates.io-index" 545 | checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" 546 | dependencies = [ 547 | "proc-macro2", 548 | ] 549 | 550 | [[package]] 551 | name = "rand" 552 | version = "0.8.5" 553 | source = "registry+https://github.com/rust-lang/crates.io-index" 554 | checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" 555 | dependencies = [ 556 | "libc", 557 | "rand_chacha", 558 | "rand_core", 559 | ] 560 | 561 | [[package]] 562 | name = "rand_chacha" 563 | version = "0.3.1" 564 | source = "registry+https://github.com/rust-lang/crates.io-index" 565 | checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" 566 | dependencies = [ 567 | "ppv-lite86", 568 | "rand_core", 569 | ] 570 | 571 | [[package]] 572 | name = "rand_core" 573 | version = "0.6.4" 574 | source = "registry+https://github.com/rust-lang/crates.io-index" 575 | checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" 576 | dependencies = [ 577 | "getrandom", 578 | ] 579 | 580 | [[package]] 581 | name = "rawpointer" 582 | version = "0.2.1" 583 | source = "registry+https://github.com/rust-lang/crates.io-index" 584 | checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" 585 | 586 | [[package]] 
587 | name = "regex" 588 | version = "1.7.1" 589 | source = "registry+https://github.com/rust-lang/crates.io-index" 590 | checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" 591 | dependencies = [ 592 | "regex-syntax", 593 | ] 594 | 595 | [[package]] 596 | name = "regex-syntax" 597 | version = "0.6.28" 598 | source = "registry+https://github.com/rust-lang/crates.io-index" 599 | checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" 600 | 601 | [[package]] 602 | name = "rustc-hash" 603 | version = "1.1.0" 604 | source = "registry+https://github.com/rust-lang/crates.io-index" 605 | checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" 606 | 607 | [[package]] 608 | name = "rustc_version" 609 | version = "0.4.0" 610 | source = "registry+https://github.com/rust-lang/crates.io-index" 611 | checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" 612 | dependencies = [ 613 | "semver", 614 | ] 615 | 616 | [[package]] 617 | name = "ryu" 618 | version = "1.0.12" 619 | source = "registry+https://github.com/rust-lang/crates.io-index" 620 | checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" 621 | 622 | [[package]] 623 | name = "schannel" 624 | version = "0.1.21" 625 | source = "registry+https://github.com/rust-lang/crates.io-index" 626 | checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" 627 | dependencies = [ 628 | "windows-sys", 629 | ] 630 | 631 | [[package]] 632 | name = "semver" 633 | version = "1.0.16" 634 | source = "registry+https://github.com/rust-lang/crates.io-index" 635 | checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a" 636 | 637 | [[package]] 638 | name = "serde" 639 | version = "1.0.152" 640 | source = "registry+https://github.com/rust-lang/crates.io-index" 641 | checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb" 642 | dependencies = [ 643 | "serde_derive", 644 | 
] 645 | 646 | [[package]] 647 | name = "serde_derive" 648 | version = "1.0.152" 649 | source = "registry+https://github.com/rust-lang/crates.io-index" 650 | checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" 651 | dependencies = [ 652 | "proc-macro2", 653 | "quote", 654 | "syn", 655 | ] 656 | 657 | [[package]] 658 | name = "serde_json" 659 | version = "1.0.93" 660 | source = "registry+https://github.com/rust-lang/crates.io-index" 661 | checksum = "cad406b69c91885b5107daf2c29572f6c8cdb3c66826821e286c533490c0bc76" 662 | dependencies = [ 663 | "itoa", 664 | "ryu", 665 | "serde", 666 | ] 667 | 668 | [[package]] 669 | name = "sha1" 670 | version = "0.10.5" 671 | source = "registry+https://github.com/rust-lang/crates.io-index" 672 | checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" 673 | dependencies = [ 674 | "cfg-if", 675 | "cpufeatures", 676 | "digest", 677 | ] 678 | 679 | [[package]] 680 | name = "sha2" 681 | version = "0.10.6" 682 | source = "registry+https://github.com/rust-lang/crates.io-index" 683 | checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" 684 | dependencies = [ 685 | "cfg-if", 686 | "cpufeatures", 687 | "digest", 688 | ] 689 | 690 | [[package]] 691 | name = "shlex" 692 | version = "1.1.0" 693 | source = "registry+https://github.com/rust-lang/crates.io-index" 694 | checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" 695 | 696 | [[package]] 697 | name = "socket2" 698 | version = "0.4.7" 699 | source = "registry+https://github.com/rust-lang/crates.io-index" 700 | checksum = "02e2d2db9033d13a1567121ddd7a095ee144db4e1ca1b1bda3419bc0da294ebd" 701 | dependencies = [ 702 | "libc", 703 | "winapi", 704 | ] 705 | 706 | [[package]] 707 | name = "subtle" 708 | version = "2.4.1" 709 | source = "registry+https://github.com/rust-lang/crates.io-index" 710 | checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" 711 | 712 | [[package]] 713 | 
name = "syn" 714 | version = "1.0.107" 715 | source = "registry+https://github.com/rust-lang/crates.io-index" 716 | checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" 717 | dependencies = [ 718 | "proc-macro2", 719 | "quote", 720 | "unicode-ident", 721 | ] 722 | 723 | [[package]] 724 | name = "tch" 725 | version = "0.10.1" 726 | source = "registry+https://github.com/rust-lang/crates.io-index" 727 | checksum = "f4e8ecac1bcd6c92726de9b1e998aa99b3977af0716992382dfd1171289b9575" 728 | dependencies = [ 729 | "half", 730 | "lazy_static", 731 | "libc", 732 | "ndarray", 733 | "rand", 734 | "thiserror", 735 | "torch-sys", 736 | "zip", 737 | ] 738 | 739 | [[package]] 740 | name = "thiserror" 741 | version = "1.0.38" 742 | source = "registry+https://github.com/rust-lang/crates.io-index" 743 | checksum = "6a9cd18aa97d5c45c6603caea1da6628790b37f7a34b6ca89522331c5180fed0" 744 | dependencies = [ 745 | "thiserror-impl", 746 | ] 747 | 748 | [[package]] 749 | name = "thiserror-impl" 750 | version = "1.0.38" 751 | source = "registry+https://github.com/rust-lang/crates.io-index" 752 | checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f" 753 | dependencies = [ 754 | "proc-macro2", 755 | "quote", 756 | "syn", 757 | ] 758 | 759 | [[package]] 760 | name = "time" 761 | version = "0.3.17" 762 | source = "registry+https://github.com/rust-lang/crates.io-index" 763 | checksum = "a561bf4617eebd33bca6434b988f39ed798e527f51a1e797d0ee4f61c0a38376" 764 | dependencies = [ 765 | "serde", 766 | "time-core", 767 | ] 768 | 769 | [[package]] 770 | name = "time-core" 771 | version = "0.1.0" 772 | source = "registry+https://github.com/rust-lang/crates.io-index" 773 | checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd" 774 | 775 | [[package]] 776 | name = "torch-sys" 777 | version = "0.10.0" 778 | source = "registry+https://github.com/rust-lang/crates.io-index" 779 | checksum = 
"877dbdc2732bdb118a71c94d0004333d29f76ebb5e88f193a3abe068f7bd6de9" 780 | dependencies = [ 781 | "anyhow", 782 | "cc", 783 | "curl", 784 | "libc", 785 | "zip", 786 | ] 787 | 788 | [[package]] 789 | name = "typenum" 790 | version = "1.16.0" 791 | source = "registry+https://github.com/rust-lang/crates.io-index" 792 | checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" 793 | 794 | [[package]] 795 | name = "unicode-ident" 796 | version = "1.0.6" 797 | source = "registry+https://github.com/rust-lang/crates.io-index" 798 | checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" 799 | 800 | [[package]] 801 | name = "vcpkg" 802 | version = "0.2.15" 803 | source = "registry+https://github.com/rust-lang/crates.io-index" 804 | checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" 805 | 806 | [[package]] 807 | name = "version_check" 808 | version = "0.9.4" 809 | source = "registry+https://github.com/rust-lang/crates.io-index" 810 | checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" 811 | 812 | [[package]] 813 | name = "wasi" 814 | version = "0.11.0+wasi-snapshot-preview1" 815 | source = "registry+https://github.com/rust-lang/crates.io-index" 816 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 817 | 818 | [[package]] 819 | name = "which" 820 | version = "4.4.0" 821 | source = "registry+https://github.com/rust-lang/crates.io-index" 822 | checksum = "2441c784c52b289a054b7201fc93253e288f094e2f4be9058343127c4226a269" 823 | dependencies = [ 824 | "either", 825 | "libc", 826 | "once_cell", 827 | ] 828 | 829 | [[package]] 830 | name = "winapi" 831 | version = "0.3.9" 832 | source = "registry+https://github.com/rust-lang/crates.io-index" 833 | checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" 834 | dependencies = [ 835 | "winapi-i686-pc-windows-gnu", 836 | "winapi-x86_64-pc-windows-gnu", 837 | ] 838 | 839 | [[package]] 840 | name 
= "winapi-i686-pc-windows-gnu" 841 | version = "0.4.0" 842 | source = "registry+https://github.com/rust-lang/crates.io-index" 843 | checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" 844 | 845 | [[package]] 846 | name = "winapi-x86_64-pc-windows-gnu" 847 | version = "0.4.0" 848 | source = "registry+https://github.com/rust-lang/crates.io-index" 849 | checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" 850 | 851 | [[package]] 852 | name = "windows-sys" 853 | version = "0.42.0" 854 | source = "registry+https://github.com/rust-lang/crates.io-index" 855 | checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" 856 | dependencies = [ 857 | "windows_aarch64_gnullvm", 858 | "windows_aarch64_msvc", 859 | "windows_i686_gnu", 860 | "windows_i686_msvc", 861 | "windows_x86_64_gnu", 862 | "windows_x86_64_gnullvm", 863 | "windows_x86_64_msvc", 864 | ] 865 | 866 | [[package]] 867 | name = "windows_aarch64_gnullvm" 868 | version = "0.42.1" 869 | source = "registry+https://github.com/rust-lang/crates.io-index" 870 | checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" 871 | 872 | [[package]] 873 | name = "windows_aarch64_msvc" 874 | version = "0.42.1" 875 | source = "registry+https://github.com/rust-lang/crates.io-index" 876 | checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" 877 | 878 | [[package]] 879 | name = "windows_i686_gnu" 880 | version = "0.42.1" 881 | source = "registry+https://github.com/rust-lang/crates.io-index" 882 | checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" 883 | 884 | [[package]] 885 | name = "windows_i686_msvc" 886 | version = "0.42.1" 887 | source = "registry+https://github.com/rust-lang/crates.io-index" 888 | checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" 889 | 890 | [[package]] 891 | name = "windows_x86_64_gnu" 892 | version = "0.42.1" 893 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 894 | checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" 895 | 896 | [[package]] 897 | name = "windows_x86_64_gnullvm" 898 | version = "0.42.1" 899 | source = "registry+https://github.com/rust-lang/crates.io-index" 900 | checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" 901 | 902 | [[package]] 903 | name = "windows_x86_64_msvc" 904 | version = "0.42.1" 905 | source = "registry+https://github.com/rust-lang/crates.io-index" 906 | checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" 907 | 908 | [[package]] 909 | name = "zip" 910 | version = "0.6.4" 911 | source = "registry+https://github.com/rust-lang/crates.io-index" 912 | checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef" 913 | dependencies = [ 914 | "aes", 915 | "byteorder", 916 | "bzip2", 917 | "constant_time_eq", 918 | "crc32fast", 919 | "crossbeam-utils", 920 | "flate2", 921 | "hmac", 922 | "pbkdf2", 923 | "sha1", 924 | "time", 925 | "zstd", 926 | ] 927 | 928 | [[package]] 929 | name = "zstd" 930 | version = "0.11.2+zstd.1.5.2" 931 | source = "registry+https://github.com/rust-lang/crates.io-index" 932 | checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" 933 | dependencies = [ 934 | "zstd-safe", 935 | ] 936 | 937 | [[package]] 938 | name = "zstd-safe" 939 | version = "5.0.2+zstd.1.5.2" 940 | source = "registry+https://github.com/rust-lang/crates.io-index" 941 | checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" 942 | dependencies = [ 943 | "libc", 944 | "zstd-sys", 945 | ] 946 | 947 | [[package]] 948 | name = "zstd-sys" 949 | version = "2.0.7+zstd.1.5.4" 950 | source = "registry+https://github.com/rust-lang/crates.io-index" 951 | checksum = "94509c3ba2fe55294d752b79842c530ccfab760192521df74a081a78d2b3c7f5" 952 | dependencies = [ 953 | "cc", 954 | "libc", 955 | "pkg-config", 956 | ] 957 | 
-------------------------------------------------------------------------------- /iree-sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "iree-sys" 3 | version = "0.1.0" 4 | edition = "2021" 5 | links = "iree" 6 | description = "Rust FFI bindings for IREE" 7 | license = "MIT" 8 | 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | 11 | [build-dependencies] 12 | bindgen = "0.63.0" 13 | cmake = "0.1.49" 14 | git2 = "0.16.1" 15 | pkg-config = "0.3.26" 16 | 17 | [dependencies] 18 | anyhow = "1.0.69" 19 | 20 | [dev-dependencies] 21 | flatbuffers = { version = "23.1.21", features = ["serde"] } 22 | once_cell = "1.17.0" 23 | serde = { version = "1.0.152", features = ["derive"] } 24 | serde_json = "1.0.93" 25 | tch = "0.10.1" 26 | 27 | 28 | [[example]] 29 | name = "simple_mul" 30 | test = true 31 | 32 | [[example]] 33 | name = "resnet18" 34 | test = true 35 | -------------------------------------------------------------------------------- /iree-sys/build.rs: -------------------------------------------------------------------------------- 1 | extern crate bindgen; 2 | 3 | use std::env; 4 | use std::path::{Path, PathBuf}; 5 | use std::process::{Command, Output}; 6 | 7 | use git2::Repository; 8 | 9 | static IREE_SAMPLES_REPO: &str = "https://github.com/iree-org/iree-samples"; 10 | static IREE_REPO: &str = "https://github.com/iree-org/iree"; 11 | 12 | fn shallow_clone(path: &Path, repo: &str) -> Repository { 13 | let mut child = Command::new("git") 14 | .args(&[ 15 | "clone", 16 | "--depth", 17 | "1", 18 | "--recurse-submodules", 19 | "--shallow-submodules", 20 | "-j10", 21 | repo, 22 | path.to_str().unwrap(), 23 | ]) 24 | .spawn() 25 | .expect("failed to execute process"); 26 | child.wait().unwrap(); 27 | 28 | git2::Repository::open(path).unwrap() 29 | } 30 | 31 | /// use cached repo if it exists, otherwise clone it 32 | fn get_repo(path: &Path, repo: 
&str) -> git2::Repository { 33 | println!("Checking for cached repo at: {}", path.to_str().unwrap()); 34 | if path.exists() { 35 | git2::Repository::open(path).unwrap() 36 | } else { 37 | // shallow clone 38 | shallow_clone(path, repo) 39 | } 40 | } 41 | 42 | /// Clones the IREE repository and builds it. 43 | fn clone_and_build_iree(out_dir: &Path) -> PathBuf { 44 | // clone IREE repo 45 | let iree_dir = out_dir.join("iree"); 46 | let iree = get_repo(iree_dir.as_path(), IREE_REPO); 47 | 48 | // clone IREE samples repo 49 | let iree_samples = get_repo(&out_dir.join("iree-samples"), IREE_SAMPLES_REPO); 50 | 51 | // make build directory 52 | let mut iree_samples_build_path = out_dir.join("iree-samples-build"); 53 | if iree_samples_build_path.exists() { 54 | // already built! 55 | return iree_samples_build_path; 56 | } 57 | std::fs::create_dir_all(iree_samples_build_path.clone()).unwrap(); 58 | 59 | // build iree-samples 60 | cmake::Config::new(out_dir.join("iree-samples/runtime-library")) 61 | .define("BUILD_SHARED_LIBS", "OFF") 62 | .define("CMAKE_C_COMPILER", "clang") 63 | .define("CMAKE_CXX_COMPILER", "clang++") 64 | .define( 65 | "IREE_ROOT_DIR", 66 | out_dir 67 | .join("iree") 68 | .canonicalize() 69 | .unwrap() 70 | .to_str() 71 | .unwrap(), 72 | ) 73 | .out_dir(iree_samples_build_path.clone()) 74 | .build(); 75 | 76 | // add library path to linker 77 | 78 | iree_samples_build_path 79 | } 80 | 81 | fn main() { 82 | let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); 83 | let iree_build_dir = clone_and_build_iree(out_path.as_path()); 84 | println!( 85 | "cargo:rustc-link-search={}", 86 | iree_build_dir.join("build/lib").to_str().unwrap() 87 | ); 88 | 89 | // add built third party libraries to linker 90 | // cpuinfo 91 | println!( 92 | "cargo:rustc-link-search={}", 93 | iree_build_dir 94 | .join("build/iree_core/third_party/cpuinfo/") 95 | .to_str() 96 | .unwrap() 97 | ); 98 | 99 | // flatcc 100 | println!( 101 | "cargo:rustc-link-search={}", 102 | 
iree_build_dir 103 | .join("build/iree_core/build_tools/third_party/flatcc/") 104 | .to_str() 105 | .unwrap() 106 | ); 107 | 108 | // clog 109 | println!( 110 | "cargo:rustc-link-search={}", 111 | iree_build_dir 112 | .join("build/iree_core/third_party/cpuinfo/deps/clog/") 113 | .to_str() 114 | .unwrap() 115 | ); 116 | let iree_include_dir = iree_build_dir.as_path().join("build/include"); 117 | 118 | println!("cargo:rustc-link-lib=iree"); 119 | 120 | // third party libraries 121 | println!("cargo:rustc-link-lib=cpuinfo"); 122 | println!("cargo:rustc-link-lib=flatcc_parsing"); 123 | println!("cargo:rustc-link-lib=clog"); 124 | println!("cargo:rustc-link-lib=stdc++"); 125 | 126 | // gather all api headers we want 127 | let iree_api_headers = ["iree/runtime/api.h"]; 128 | 129 | for &header in iree_api_headers.iter() { 130 | let header_out = Path::new(header) 131 | .to_str() 132 | .and_then(|s| s.strip_suffix(".h")) 133 | .and_then(|s| Some(format!("{}.rs", s))) 134 | .unwrap(); 135 | 136 | if out_path.join(header_out.clone()).exists() { 137 | // already generated 138 | continue; 139 | } 140 | let header_buf = iree_include_dir.join(header); 141 | let header_path = header_buf.as_path(); 142 | 143 | let dir = out_path.join(Path::new(header).parent().unwrap()); 144 | 145 | if !dir.exists() { 146 | std::fs::create_dir_all(&dir).expect("Unable to create directory"); 147 | } 148 | 149 | let bindings = bindgen::Builder::default() 150 | .header(header_path.to_str().unwrap()) 151 | .clang_arg(format!("-I{}", iree_include_dir.to_str().unwrap())) 152 | .default_enum_style(bindgen::EnumVariation::NewType { 153 | is_bitfield: true, 154 | is_global: true, 155 | }) 156 | .generate_inline_functions(false) 157 | .derive_default(true) 158 | .parse_callbacks(Box::new(bindgen::CargoCallbacks)) 159 | .generate() 160 | .expect("Unable to generate bindings"); 161 | 162 | bindings 163 | .write_to_file(out_path.join(header_out)) 164 | .expect("Couldn't write bindings!"); 165 | } 166 | 167 | 
println!("cargo:rerun-if-changed=build.rs"); 168 | } 169 | -------------------------------------------------------------------------------- /iree-sys/examples/resnet18.rs: -------------------------------------------------------------------------------- 1 | use iree_sys::{helper::*, iree::runtime::api::*}; 2 | use once_cell::sync::Lazy; 3 | use serde::{Deserialize, Serialize}; 4 | use std::{ffi::CString, ptr::null_mut, time::Instant}; 5 | use tch; 6 | 7 | #[derive(Deserialize, Serialize)] 8 | struct Image { 9 | data: Vec, 10 | shape: Vec, 11 | } 12 | 13 | fn load_image() -> Result { 14 | let j = serde_json::from_str::(include_str!("test_image.json")).unwrap(); 15 | Ok(tch::Tensor::of_slice(&j.data).reshape(j.shape.as_slice())) 16 | } 17 | 18 | static mut FLATBUFFER_DATA: Lazy> = Lazy::new(|| include_bytes!("resnet18.vmfb").to_vec()); 19 | 20 | unsafe fn iree_runtime_demo_run_session(instance: *mut iree_runtime_instance_t) { 21 | // TODO(#5724): move device selection into the compiled modules. 22 | let mut device: *mut iree_hal_device_t = null_mut(); 23 | 24 | let s_str = CString::new("local-task").unwrap(); 25 | let string_view = iree_string_view_t { 26 | data: s_str.as_ptr() as *const i8, 27 | size: s_str.as_bytes().len(), 28 | }; 29 | 30 | let status = 31 | iree_runtime_instance_try_create_default_device(instance, string_view, &mut device as _); 32 | assert!( 33 | IREE_CHECK_OK(status), 34 | "status: {}", 35 | IREE_STATUS_TO_STRING(status) 36 | ); 37 | 38 | let allocator = iree_runtime_instance_host_allocator(instance); 39 | 40 | // Create one session per loaded module to hold the module state. 
41 | let mut session_options = iree_runtime_session_options_t::default(); 42 | 43 | iree_runtime_session_options_initialize(&mut session_options as _); 44 | 45 | let mut session: *mut iree_runtime_session_t = null_mut(); 46 | let status = iree_runtime_session_create_with_device( 47 | instance, 48 | &session_options as _, 49 | device, 50 | allocator, 51 | &mut session as _, 52 | ); 53 | 54 | assert!( 55 | IREE_CHECK_OK(status), 56 | "status: {}", 57 | IREE_STATUS_TO_STRING(status) 58 | ); 59 | iree_hal_device_release(device); 60 | 61 | // Load your user module into the session (from memory, from file, etc). 62 | 63 | FLATBUFFER_DATA.push(0); 64 | let status = iree_runtime_session_append_bytecode_module_from_memory( 65 | session, 66 | iree_const_byte_span_t { 67 | data: FLATBUFFER_DATA.as_ptr() as _, 68 | data_length: FLATBUFFER_DATA.len(), 69 | }, 70 | iree_runtime_session_host_allocator(session), 71 | ); 72 | 73 | // let fpath = CString::new("examples/resnet18.vmfb").unwrap(); 74 | // let status = iree_runtime_session_append_bytecode_module_from_file( 75 | // session, 76 | // fpath.as_ptr() as *const c_char, 77 | // ); 78 | 79 | assert!( 80 | IREE_CHECK_OK(status), 81 | "status: {}", 82 | IREE_STATUS_TO_STRING(status) 83 | ); 84 | 85 | // Run your functions; you should reuse the session to make multiple calls. 86 | iree_runtime_demo_perform_mul(session); 87 | 88 | iree_runtime_session_release(session); 89 | } 90 | 91 | //===----------------------------------------------------------------------===// 92 | // 3. 
Call a function within a module with buffer views 93 | //===----------------------------------------------------------------------===// 94 | 95 | // func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> 96 | // tensor<4xf32> 97 | unsafe fn iree_runtime_demo_perform_mul(session: *mut iree_runtime_session_t) { 98 | let mut call = iree_runtime_call_t::default(); 99 | let mut str_n = "module.forward"; 100 | let status = iree_runtime_call_initialize_by_name( 101 | session, 102 | iree_string_view_t { 103 | data: str_n.as_ptr() as *const i8, 104 | size: str_n.len(), 105 | }, 106 | &mut call as _, 107 | ); 108 | 109 | assert!( 110 | IREE_CHECK_OK(status), 111 | "status: {}", 112 | IREE_STATUS_TO_STRING(status) 113 | ); 114 | 115 | // %arg0: tensor<4xf32> 116 | let mut arg0: *mut iree_hal_buffer_view_t = null_mut(); 117 | let img = load_image().unwrap(); 118 | 119 | let arg0_shape: Vec = img.size().iter().map(|x| *x as _).collect(); 120 | let arg0_data = img.to_kind(tch::Kind::Float); 121 | 122 | let allocator = iree_runtime_session_device_allocator(session); 123 | 124 | let byte_span = iree_const_byte_span_t { 125 | data: arg0_data.data_ptr() as _, 126 | data_length: arg0_data.flatten(0, -1).size()[0] as usize * std::mem::size_of::(), 127 | }; 128 | println!("byte_span: {:?}", byte_span); 129 | 130 | let mut buff_params = iree_hal_buffer_params_t::default(); 131 | buff_params.type_ = iree_hal_memory_type_bits_t_IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL.0; 132 | buff_params.access = 0; // fixme: incorrect type? 
133 | buff_params.usage = iree_hal_buffer_usage_bits_t_IREE_HAL_BUFFER_USAGE_DEFAULT.0; 134 | 135 | let status = iree_hal_buffer_view_allocate_buffer( 136 | allocator, 137 | arg0_shape.len(), 138 | arg0_shape.as_ptr() as _, 139 | iree_hal_element_types_t_IREE_HAL_ELEMENT_TYPE_FLOAT_32.0, 140 | iree_hal_encoding_types_t_IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR.0, 141 | buff_params, 142 | byte_span, 143 | &mut arg0 as _, 144 | ); 145 | 146 | assert!( 147 | IREE_CHECK_OK(status), 148 | "status: {}", 149 | IREE_STATUS_TO_STRING(status) 150 | ); 151 | 152 | iree_hal_buffer_view_fprint( 153 | stdout, 154 | arg0, 155 | /*max_element_count=*/ 10, 156 | iree_runtime_session_host_allocator(session), 157 | ); 158 | iree_runtime_call_inputs_push_back_buffer_view(&mut call as _, arg0); 159 | iree_hal_buffer_view_release(arg0); 160 | 161 | let start_t = Instant::now(); 162 | let status = iree_runtime_call_invoke(&mut call as _, /*flags=*/ 0); 163 | let end_t = Instant::now(); 164 | 165 | println!("invoke time: {:?}", end_t - start_t); 166 | assert!( 167 | IREE_CHECK_OK(status), 168 | "status: {}", 169 | IREE_STATUS_TO_STRING(status) 170 | ); 171 | 172 | // -> tensor<4xf32> 173 | let mut ret0: *mut iree_hal_buffer_view_t = null_mut(); 174 | let status = iree_runtime_call_outputs_pop_front_buffer_view(&mut call as _, &mut ret0 as _); 175 | assert!( 176 | IREE_CHECK_OK(status), 177 | "status: {}", 178 | IREE_STATUS_TO_STRING(status) 179 | ); 180 | 181 | iree_hal_buffer_view_fprint( 182 | stdout, 183 | ret0, 184 | /*max_element_count=*/ 10, 185 | iree_runtime_session_host_allocator(session), 186 | ); 187 | iree_hal_buffer_view_release(ret0); 188 | 189 | iree_runtime_call_deinitialize(&mut call as _); 190 | } 191 | 192 | #[cfg(test)] 193 | pub mod test { 194 | 195 | use iree_sys::iree::runtime::api::*; 196 | use std::ptr::null_mut; 197 | 198 | use crate::*; 199 | 200 | #[test] 201 | fn main() { 202 | unsafe { 203 | let mut instance_options = iree_runtime_instance_options_t::default(); 
204 | iree_runtime_instance_options_initialize(&mut instance_options as *mut _); 205 | iree_runtime_instance_options_use_all_available_drivers( 206 | &mut instance_options as *mut _, 207 | ); 208 | let mut instance: *mut iree_runtime_instance_t = null_mut(); 209 | 210 | let allocator = iree_allocator_t { 211 | self_: null_mut(), 212 | ctl: Some(iree_allocator_system_ctl as _), 213 | }; 214 | 215 | iree_runtime_instance_create(&instance_options, allocator, &mut instance as _); 216 | 217 | // All sessions should share the same instance. 218 | iree_runtime_demo_run_session(instance); 219 | 220 | iree_runtime_instance_release(instance); 221 | } 222 | } 223 | } 224 | 225 | fn main() {} 226 | -------------------------------------------------------------------------------- /iree-sys/examples/simple_mul.rs: -------------------------------------------------------------------------------- 1 | use iree_sys::{helper::*, iree::runtime::api::*}; 2 | use std::{ffi::CString, os::raw::c_char, ptr::null_mut}; 3 | 4 | unsafe fn iree_runtime_demo_run_session(instance: *mut iree_runtime_instance_t) { 5 | // TODO(#5724): move device selection into the compiled modules. 6 | let mut device: *mut iree_hal_device_t = null_mut(); 7 | 8 | let s_str = CString::new("local-task").unwrap(); 9 | let string_view = iree_string_view_t { 10 | data: s_str.as_ptr() as *const i8, 11 | size: s_str.as_bytes().len(), 12 | }; 13 | 14 | let status = 15 | iree_runtime_instance_try_create_default_device(instance, string_view, &mut device as _); 16 | assert!( 17 | IREE_CHECK_OK(status), 18 | "status: {}", 19 | IREE_STATUS_TO_STRING(status) 20 | ); 21 | 22 | let allocator = iree_runtime_instance_host_allocator(instance); 23 | 24 | // Create one session per loaded module to hold the module state. 
25 | let mut session_options = iree_runtime_session_options_t::default(); 26 | 27 | iree_runtime_session_options_initialize(&mut session_options as _); 28 | 29 | let mut session: *mut iree_runtime_session_t = null_mut(); 30 | let status = iree_runtime_session_create_with_device( 31 | instance, 32 | &session_options as _, 33 | device, 34 | allocator, 35 | &mut session as _, 36 | ); 37 | 38 | assert!( 39 | IREE_CHECK_OK(status), 40 | "status: {}", 41 | IREE_STATUS_TO_STRING(status) 42 | ); 43 | iree_hal_device_release(device); 44 | 45 | // Load your user module into the session (from memory, from file, etc). 46 | 47 | let fpath = CString::new("examples/simple_mul_module.vmfb").unwrap(); 48 | let status = iree_runtime_session_append_bytecode_module_from_file( 49 | session, 50 | fpath.as_ptr() as *const c_char, 51 | ); 52 | 53 | assert!( 54 | IREE_CHECK_OK(status), 55 | "status: {}", 56 | IREE_STATUS_TO_STRING(status) 57 | ); 58 | 59 | // Run your functions; you should reuse the session to make multiple calls. 60 | iree_runtime_demo_perform_mul(session); 61 | 62 | iree_runtime_session_release(session); 63 | } 64 | 65 | //===----------------------------------------------------------------------===// 66 | // 3. 
Call a function within a module with buffer views 67 | //===----------------------------------------------------------------------===// 68 | 69 | // func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> 70 | // tensor<4xf32> 71 | unsafe fn iree_runtime_demo_perform_mul(session: *mut iree_runtime_session_t) { 72 | let mut call = iree_runtime_call_t::default(); 73 | let status = iree_runtime_call_initialize_by_name( 74 | session, 75 | iree_string_view_t { 76 | data: "module.simple_mul".as_ptr() as *const i8, 77 | size: 17, 78 | }, 79 | &mut call as _, 80 | ); 81 | 82 | assert!( 83 | IREE_CHECK_OK(status), 84 | "status: {}", 85 | IREE_STATUS_TO_STRING(status) 86 | ); 87 | 88 | // %arg0: tensor<4xf32> 89 | let mut arg0: *mut iree_hal_buffer_view_t = null_mut(); 90 | let arg0_shape: [iree_hal_dim_t; 1] = [4]; 91 | let arg0_data: [f32; 4] = [1.0, 1.1, 1.2, 1.3]; 92 | 93 | let allocator = iree_runtime_session_device_allocator(session); 94 | 95 | let byte_span = iree_const_byte_span_t { 96 | data: arg0_data.as_ptr() as _, 97 | data_length: arg0_data.len() * std::mem::size_of::(), 98 | }; 99 | println!("byte_span: {:?}", byte_span); 100 | 101 | let mut buff_params = iree_hal_buffer_params_t::default(); 102 | buff_params.type_ = iree_hal_memory_type_bits_t_IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL.0; 103 | buff_params.access = 0; // fixme: incorrect type? 
104 | buff_params.usage = iree_hal_buffer_usage_bits_t_IREE_HAL_BUFFER_USAGE_DEFAULT.0; 105 | 106 | let status = iree_hal_buffer_view_allocate_buffer( 107 | allocator, 108 | arg0_shape.len(), 109 | arg0_shape.as_ptr() as _, 110 | iree_hal_element_types_t_IREE_HAL_ELEMENT_TYPE_FLOAT_32.0, 111 | iree_hal_encoding_types_t_IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR.0, 112 | buff_params, 113 | byte_span, 114 | &mut arg0 as _, 115 | ); 116 | 117 | assert!(IREE_CHECK_OK(status)); 118 | 119 | iree_hal_buffer_view_fprint( 120 | stdout, 121 | arg0, 122 | /*max_element_count=*/ 4096, 123 | iree_runtime_session_host_allocator(session), 124 | ); 125 | iree_runtime_call_inputs_push_back_buffer_view(&mut call as _, arg0); 126 | iree_hal_buffer_view_release(arg0); 127 | 128 | // %arg1: tensor<4xf32> 129 | let mut arg1: *mut iree_hal_buffer_view_t = null_mut(); 130 | let arg1_shape: [iree_hal_dim_t; 1] = [4]; 131 | let arg1_data: [f32; 4] = [1.0, 10.0, 100.0, 1000.0]; 132 | 133 | let allocator = iree_runtime_session_device_allocator(session); 134 | 135 | let byte_span = iree_const_byte_span_t { 136 | data: arg1_data.as_ptr() as _, 137 | data_length: arg1_data.len() * std::mem::size_of::(), 138 | }; 139 | println!("byte_span: {:?}", byte_span); 140 | 141 | let mut buff_params = iree_hal_buffer_params_t::default(); 142 | buff_params.type_ = iree_hal_memory_type_bits_t_IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL.0; 143 | buff_params.access = 0; // fixme: incorrect type? 
144 | buff_params.usage = iree_hal_buffer_usage_bits_t_IREE_HAL_BUFFER_USAGE_DEFAULT.0; 145 | 146 | let status = iree_hal_buffer_view_allocate_buffer( 147 | allocator, 148 | arg1_shape.len(), 149 | arg1_shape.as_ptr() as _, 150 | iree_hal_element_types_t_IREE_HAL_ELEMENT_TYPE_FLOAT_32.0, 151 | iree_hal_encoding_types_t_IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR.0, 152 | buff_params, 153 | byte_span, 154 | &mut arg1 as _, 155 | ); 156 | 157 | assert!(IREE_CHECK_OK(status)); 158 | 159 | iree_hal_buffer_view_fprint( 160 | stdout, 161 | arg1, 162 | /*max_element_count=*/ 4096, 163 | iree_runtime_session_host_allocator(session), 164 | ); 165 | iree_runtime_call_inputs_push_back_buffer_view(&mut call as _, arg1); 166 | iree_hal_buffer_view_release(arg1); 167 | 168 | let status = iree_runtime_call_invoke(&mut call as _, /*flags=*/ 0); 169 | assert!( 170 | IREE_CHECK_OK(status), 171 | "status: {}", 172 | IREE_STATUS_TO_STRING(status) 173 | ); 174 | 175 | // -> tensor<4xf32> 176 | let mut ret0: *mut iree_hal_buffer_view_t = null_mut(); 177 | let status = iree_runtime_call_outputs_pop_front_buffer_view(&mut call as _, &mut ret0 as _); 178 | assert!( 179 | IREE_CHECK_OK(status), 180 | "status: {}", 181 | IREE_STATUS_TO_STRING(status) 182 | ); 183 | 184 | iree_hal_buffer_view_fprint( 185 | stdout, 186 | ret0, 187 | /*max_element_count=*/ 4096, 188 | iree_runtime_session_host_allocator(session), 189 | ); 190 | iree_hal_buffer_view_release(ret0); 191 | 192 | iree_runtime_call_deinitialize(&mut call as _); 193 | } 194 | 195 | #[cfg(test)] 196 | pub mod test { 197 | 198 | use iree_sys::iree::runtime::api::*; 199 | use std::ptr::null_mut; 200 | 201 | use crate::*; 202 | 203 | #[test] 204 | fn main() { 205 | unsafe { 206 | let mut instance_options = iree_runtime_instance_options_t::default(); 207 | iree_runtime_instance_options_initialize(&mut instance_options as *mut _); 208 | iree_runtime_instance_options_use_all_available_drivers( 209 | &mut instance_options as *mut _, 210 | ); 211 | 
let mut instance: *mut iree_runtime_instance_t = null_mut(); 212 | 213 | let allocator = iree_allocator_t { 214 | self_: null_mut(), 215 | ctl: Some(iree_allocator_system_ctl as _), 216 | }; 217 | 218 | iree_runtime_instance_create(&instance_options, allocator, &mut instance as _); 219 | 220 | // All sessions should share the same instance. 221 | iree_runtime_demo_run_session(instance); 222 | 223 | iree_runtime_instance_release(instance); 224 | } 225 | } 226 | } 227 | 228 | fn main() {} 229 | -------------------------------------------------------------------------------- /iree-sys/examples/simple_mul_module.vmfb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SamKG/iree-rs/0c80c6ed4308e65a888f1f8022ab0e1bd31c3b17/iree-sys/examples/simple_mul_module.vmfb -------------------------------------------------------------------------------- /iree-sys/src/helper.rs: -------------------------------------------------------------------------------- 1 | use crate::iree::runtime::api::*; 2 | use std::{ffi::CString, ptr::null_mut}; 3 | 4 | pub unsafe fn IREE_CHECK_OK(status: *mut iree_status_handle_t) -> bool { 5 | return status == iree_status_code_e_IREE_STATUS_OK.0 as _; 6 | } 7 | 8 | pub unsafe fn IREE_STATUS_TO_STRING(status: *mut iree_status_handle_t) -> String { 9 | let host_allocator = iree_allocator_t { 10 | self_: null_mut(), 11 | ctl: Some(iree_allocator_system_ctl as _), 12 | }; 13 | let mut out_buffer: *mut i8 = null_mut(); 14 | let mut out_buffer_length: usize = 0; 15 | 16 | iree_status_to_string( 17 | status, 18 | &host_allocator as _, 19 | &mut out_buffer as _, 20 | &mut out_buffer_length as _, 21 | ); 22 | return CString::from_raw(out_buffer).to_str().unwrap().to_string(); 23 | } 24 | -------------------------------------------------------------------------------- /iree-sys/src/iree/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod runtime; 2 | 
-------------------------------------------------------------------------------- /iree-sys/src/iree/runtime/api.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_upper_case_globals)] 2 | #![allow(non_camel_case_types)] 3 | #![allow(non_snake_case)] 4 | #![allow(dead_code)] 5 | 6 | include!(concat!(env!("OUT_DIR"), "/iree/runtime/api.rs")); 7 | -------------------------------------------------------------------------------- /iree-sys/src/iree/runtime/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod api; 2 | -------------------------------------------------------------------------------- /iree-sys/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod helper; 2 | pub mod iree; 3 | -------------------------------------------------------------------------------- /scripts/dump_mlir.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 51, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch_mlir" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 52, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from torchvision.models import resnet18, ResNet18_Weights\n", 20 | "model = resnet18(weights=ResNet18_Weights.DEFAULT)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 54, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "model.eval()\n", 30 | "\n", 31 | "compiled_module = torch_mlir.compile(model, example_args=[torch.ones((1, 3, 224, 224))], output_type=torch_mlir.OutputType.TOSA)\n", 32 | "\n", 33 | "from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend\n", 34 | "from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import 
RefBackendLinalgOnTensorsBackend\n", 35 | "# backend = LinalgOnTensorsTosaBackend()\n", 36 | "backend = RefBackendLinalgOnTensorsBackend()\n", 37 | "compiled_module = torch_mlir.compile(model, example_args=[torch.ones((1, 3, 224, 224))], output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n", 38 | "runnable = backend.compile(compiled_module)\n", 39 | "jit_module = backend.load(runnable)\n" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 55, 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "data": { 49 | "text/plain": [ 50 | "94063946" 51 | ] 52 | }, 53 | "execution_count": 55, 54 | "metadata": {}, 55 | "output_type": "execute_result" 56 | } 57 | ], 58 | "source": [ 59 | "outfile = open(\"resnet18.mlir\", \"w\")\n", 60 | "outfile.write(str(compiled_module))" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 56, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "torch.Size([1, 3, 224, 224])\n" 73 | ] 74 | } 75 | ], 76 | "source": [ 77 | "# get example image\n", 78 | "from PIL import Image\n", 79 | "import requests\n", 80 | "from io import BytesIO\n", 81 | "from torchvision import transforms\n", 82 | "\n", 83 | "def load_and_preprocess_image(url: str):\n", 84 | " headers = {\n", 85 | " 'User-Agent':\n", 86 | " 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36'\n", 87 | " }\n", 88 | " img = Image.open(requests.get(url, headers=headers,\n", 89 | " stream=True).raw).convert(\"RGB\")\n", 90 | " # preprocessing pipeline\n", 91 | " preprocess = transforms.Compose([\n", 92 | " transforms.Resize(256),\n", 93 | " transforms.CenterCrop(224),\n", 94 | " transforms.ToTensor(),\n", 95 | " transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 96 | " std=[0.229, 0.224, 0.225]),\n", 97 | " ])\n", 98 | " img_preprocessed = preprocess(img)\n", 99 | " return torch.unsqueeze(img_preprocessed, 0)\n", 
100 | "\n", 101 | "image_url = \"https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg\"\n", 102 | "\n", 103 | "\n", 104 | "img = load_and_preprocess_image(image_url)\n", 105 | "\n", 106 | "print(img.shape)\n", 107 | "# write img to file\n", 108 | "torch.save(img.numpy(), \"../examples/test_image.pt\")\n", 109 | "# write to json \n", 110 | "import json\n", 111 | "with open(\"../examples/test_image.json\", \"w\") as f:\n", 112 | " json.dump({\n", 113 | " \"data\": img.numpy().flatten().tolist(), \n", 114 | " \"shape\": list(img.shape)},f)\n", 115 | "arg = img.numpy()" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 57, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "text/plain": [ 126 | "'2.0.0.dev20230209+cu117'" 127 | ] 128 | }, 129 | "execution_count": 57, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "torch.__version__" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 58, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "data": { 145 | "text/plain": [ 146 | "16.245480597004644" 147 | ] 148 | }, 149 | "execution_count": 58, 150 | "metadata": {}, 151 | "output_type": "execute_result" 152 | } 153 | ], 154 | "source": [ 155 | "import timeit\n", 156 | "\n", 157 | "timeit.timeit(lambda: jit_module.forward(arg), number=1)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 59, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "data": { 167 | "text/plain": [ 168 | "2.6393603989999974" 169 | ] 170 | }, 171 | "execution_count": 59, 172 | "metadata": {}, 173 | "output_type": "execute_result" 174 | } 175 | ], 176 | "source": [ 177 | "\n", 178 | "timeit.timeit(lambda: jit_module.forward(arg), number=1)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "data": { 188 | "text/plain": [ 189 | 
"0.06698099699860904" 190 | ] 191 | }, 192 | "execution_count": 50, 193 | "metadata": {}, 194 | "output_type": "execute_result" 195 | } 196 | ], 197 | "source": [ 198 | "\n", 199 | "timeit.timeit(lambda: model(img), number=1)" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [] 208 | } 209 | ], 210 | "metadata": { 211 | "kernelspec": { 212 | "display_name": "Python 3", 213 | "language": "python", 214 | "name": "python3" 215 | }, 216 | "language_info": { 217 | "codemirror_mode": { 218 | "name": "ipython", 219 | "version": 3 220 | }, 221 | "file_extension": ".py", 222 | "mimetype": "text/x-python", 223 | "name": "python", 224 | "nbconvert_exporter": "python", 225 | "pygments_lexer": "ipython3", 226 | "version": "3.10.9" 227 | }, 228 | "orig_nbformat": 4, 229 | "vscode": { 230 | "interpreter": { 231 | "hash": "18824bd52c2964eef04022de1082fbe6ca8a05a9cab1618bc6c06c0883c4df04" 232 | } 233 | } 234 | }, 235 | "nbformat": 4, 236 | "nbformat_minor": 2 237 | } 238 | -------------------------------------------------------------------------------- /scripts/torchscript_resnet18.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 
5 | 6 | import sys 7 | 8 | from PIL import Image 9 | import requests 10 | 11 | import torch 12 | import torchvision.models as models 13 | from torchvision import transforms 14 | 15 | import torch_mlir 16 | from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend 17 | 18 | 19 | 20 | 21 | def load_labels(): 22 | classes_text = requests.get( 23 | "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt", 24 | stream=True, 25 | ).text 26 | labels = [line.strip() for line in classes_text.splitlines()] 27 | return labels 28 | 29 | 30 | def top3_possibilities(res): 31 | _, indexes = torch.sort(res, descending=True) 32 | percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100 33 | top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]] 34 | return top3 35 | 36 | 37 | def predictions(torch_func, jit_func, img, labels): 38 | golden_prediction = top3_possibilities(torch_func(img)) 39 | print("PyTorch prediction") 40 | print(golden_prediction) 41 | prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy()))) 42 | print("torch-mlir prediction") 43 | print(prediction) 44 | 45 | 46 | print("load image from " + image_url, file=sys.stderr) 47 | img = load_and_preprocess_image(image_url) 48 | labels = load_labels() 49 | 50 | resnet18 = models.resnet18(pretrained=True) 51 | resnet18.train(False) 52 | module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors") 53 | backend = refbackend.RefBackendLinalgOnTensorsBackend() 54 | compiled = backend.compile(module) 55 | jit_module = backend.load(compiled) 56 | 57 | predictions(resnet18.forward, jit_module.forward, img, labels) 58 | -------------------------------------------------------------------------------- /src/err/mod.rs: -------------------------------------------------------------------------------- 1 | use std::{error, ffi::NulError, fmt::Display, string::FromUtf8Error}; 2 | 3 | use crate::types::{allocator::IreeAllocator, 
status::IreeStatus}; 4 | 5 | /// Represents an error returned by IREE. 6 | /// IREE functions return a status code, which is a `u32` value. The IreeError struct assumes the status code is an error code. 7 | #[derive(Debug)] 8 | pub struct IreeError { 9 | kind: IreeErrorKind, 10 | } 11 | 12 | #[derive(Debug)] 13 | pub enum IreeErrorKind { 14 | Status(IreeStatus, String), // For when the function that returned the status code allocated a string for the error message 15 | UnallocatedStatus(IreeStatus), // For when the function that returned the status code did not allocate a string for the error message (e.g. when it doesn't have an allocator) 16 | Other(Box), // For external errors 17 | Unknown(String), 18 | } 19 | 20 | impl error::Error for IreeError {} 21 | 22 | impl From for IreeError { 23 | fn from(s: String) -> Self { 24 | Self { 25 | kind: IreeErrorKind::Unknown(s), 26 | } 27 | } 28 | } 29 | 30 | impl From for IreeError { 31 | fn from(e: FromUtf8Error) -> Self { 32 | Self { 33 | kind: IreeErrorKind::Other(Box::new(e)), 34 | } 35 | } 36 | } 37 | impl From for IreeError { 38 | fn from(e: NulError) -> Self { 39 | Self { 40 | kind: IreeErrorKind::Other(Box::new(e)), 41 | } 42 | } 43 | } 44 | 45 | impl IreeError { 46 | pub fn new(kind: IreeErrorKind) -> Self { 47 | Self { kind } 48 | } 49 | pub fn from_status(status: IreeStatus, allocator: &IreeAllocator) -> Self { 50 | Self { 51 | kind: IreeErrorKind::Status(status, status.to_string(allocator).unwrap()), 52 | } 53 | } 54 | } 55 | 56 | impl Display for IreeError { 57 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 58 | match &self.kind { 59 | IreeErrorKind::Status(status, msg) => write!(f, "IREE status: {:?} {}", status, msg), 60 | IreeErrorKind::UnallocatedStatus(status) => write!(f, "IREE unallocated status: {:?} (try allocating the error message string using an allocator!)", status), 61 | IreeErrorKind::Unknown(msg) => write!(f, "IREE unknown error: {}", msg), 62 | IreeErrorKind::Other(err) 
=> write!(f, "IREE other error: {}", err), 63 | } 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /src/ffi/mod.rs: -------------------------------------------------------------------------------- 1 | pub use iree_sys::*; 2 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod err; 2 | pub mod ffi; 3 | pub mod types; 4 | -------------------------------------------------------------------------------- /src/types/allocator.rs: -------------------------------------------------------------------------------- 1 | use iree_sys::{self, iree::runtime::api::iree_allocator_t}; 2 | 3 | pub struct IreeAllocator { 4 | pub(crate) allocator: iree_allocator_t, 5 | } 6 | 7 | impl IreeAllocator { 8 | /// Creates a default allocator that uses the system allocator (typically malloc). 9 | pub fn system_allocator() -> Self { 10 | // FIXME: This emulates the functionality of the `iree_system_allocator` macro. We should ideally be able to use that macro directly. 
11 | Self { 12 | allocator: iree_allocator_t { 13 | self_: std::ptr::null_mut(), 14 | ctl: Some(iree_sys::iree::runtime::api::iree_allocator_system_ctl as _), 15 | }, 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/types/bytespan.rs: -------------------------------------------------------------------------------- 1 | use iree_sys::iree::runtime::api::iree_const_byte_span_t; 2 | 3 | pub struct IreeConstByteSpan<'a, T> { 4 | pub(crate) span: iree_const_byte_span_t, 5 | pub(crate) _data: &'a [T], // keep the data alive 6 | } 7 | 8 | impl<'a, T> IreeConstByteSpan<'a, T> { 9 | pub fn from_slice(data: &'a [T]) -> Self { 10 | Self { 11 | span: iree_const_byte_span_t { 12 | data: data.as_ptr() as *const _, 13 | data_length: data.len() * std::mem::size_of::(), 14 | }, 15 | _data: data, 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/types/hal_allocator.rs: -------------------------------------------------------------------------------- 1 | use iree_sys::iree::runtime::api::iree_hal_allocator_t; 2 | 3 | #[derive(Clone)] 4 | pub struct IreeHalAllocator { 5 | pub(crate) allocator_ptr: *mut iree_hal_allocator_t, 6 | } 7 | -------------------------------------------------------------------------------- /src/types/hal_buffer.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{Display, Error}; 2 | 3 | use iree_sys::{ 4 | helper::IREE_CHECK_OK, 5 | iree::runtime::api::{ 6 | iree_hal_buffer_params_t, iree_hal_buffer_usage_t, iree_hal_buffer_view_allocate_buffer, 7 | iree_hal_buffer_view_format, iree_hal_buffer_view_release, iree_hal_buffer_view_shape, 8 | iree_hal_buffer_view_t, iree_hal_dim_t, iree_hal_element_types_t, 9 | iree_hal_encoding_types_t, iree_hal_memory_access_t, iree_hal_memory_type_t, 10 | }, 11 | }; 12 | 13 | use crate::err::IreeError; 14 | 15 | use super::{ 16 | allocator::IreeAllocator, 
bytespan::IreeConstByteSpan, hal_allocator::IreeHalAllocator, 17 | status::IreeStatus, 18 | }; 19 | 20 | pub type IreeHalBufferShape = Vec; 21 | 22 | pub struct IreeHalBufferParams { 23 | params: iree_hal_buffer_params_t, 24 | } 25 | 26 | pub struct IreeHalBufferViewParamsBuilder { 27 | params: iree_hal_buffer_params_t, 28 | } 29 | 30 | impl Default for IreeHalBufferViewParamsBuilder { 31 | fn default() -> Self { 32 | let params = iree_hal_buffer_params_t::default(); 33 | Self { params } 34 | } 35 | } 36 | 37 | impl IreeHalBufferViewParamsBuilder { 38 | pub fn build(&self) -> IreeHalBufferParams { 39 | IreeHalBufferParams { 40 | params: self.params, 41 | } 42 | } 43 | 44 | pub fn type_(&mut self, type_: iree_hal_memory_type_t) -> &mut Self { 45 | self.params.type_ |= type_; 46 | self 47 | } 48 | 49 | pub fn access(&mut self, access: iree_hal_memory_access_t) -> &mut Self { 50 | self.params.access |= access; 51 | self 52 | } 53 | 54 | pub fn usage(&mut self, usage: iree_hal_buffer_usage_t) -> &mut Self { 55 | self.params.usage |= usage; 56 | self 57 | } 58 | } 59 | 60 | pub struct IreeHalBufferView { 61 | pub(crate) buffer_view_ptr: *mut iree_hal_buffer_view_t, 62 | } 63 | 64 | impl IreeHalBufferView { 65 | pub fn allocate_buffer( 66 | allocator: &IreeHalAllocator, 67 | shape: &IreeHalBufferShape, 68 | element_type: iree_hal_element_types_t, 69 | encoding_type: iree_hal_encoding_types_t, 70 | params: &IreeHalBufferParams, 71 | byte_span: &IreeConstByteSpan, 72 | ) -> Result { 73 | let mut buffer_view_ptr = std::mem::MaybeUninit::<*mut iree_hal_buffer_view_t>::uninit(); 74 | unsafe { 75 | let status = iree_hal_buffer_view_allocate_buffer( 76 | allocator.allocator_ptr, 77 | shape.len(), 78 | shape.as_ptr(), 79 | element_type.0, 80 | encoding_type.0, 81 | params.params, 82 | byte_span.span, 83 | buffer_view_ptr.as_mut_ptr(), 84 | ); 85 | if !IREE_CHECK_OK(status) { 86 | // FIXME: We don't have the host allocator here, so we can't allocate the error message! 
87 | return Err(IreeError::from_status( 88 | IreeStatus { status }, 89 | &IreeAllocator::system_allocator(), 90 | )); 91 | } 92 | } 93 | Ok(Self { 94 | buffer_view_ptr: unsafe { buffer_view_ptr.assume_init() }, 95 | }) 96 | } 97 | pub fn try_to_string(&self, max_element_count: usize) -> Result { 98 | let mut buffer = vec![0i8; max_element_count * 24]; // assume 24 bytes per element (maybe overkill) 99 | let mut out_buffer_length = std::mem::MaybeUninit::::uninit(); 100 | unsafe { 101 | let status = iree_hal_buffer_view_format( 102 | self.buffer_view_ptr, 103 | max_element_count, 104 | buffer.len(), 105 | buffer.as_mut_ptr(), 106 | out_buffer_length.as_mut_ptr(), 107 | ); 108 | 109 | if !IREE_CHECK_OK(status) { 110 | return Err(IreeError::from_status( 111 | IreeStatus { status }, 112 | &IreeAllocator::system_allocator(), 113 | )); 114 | } 115 | let buffer_u8 = buffer.drain(..).map(|b| b as u8).collect::>(); 116 | Ok(String::from_utf8_lossy(&buffer_u8[..out_buffer_length.assume_init()]).to_string()) 117 | } 118 | } 119 | pub fn shape(&self) -> Result { 120 | let mut out_shape: Vec = vec![0; 32]; // assume max rank is 32 (probably overkill!) 
121 | let mut out_shape_rank = std::mem::MaybeUninit::::uninit(); 122 | unsafe { 123 | let status = iree_hal_buffer_view_shape( 124 | self.buffer_view_ptr, 125 | out_shape.len(), 126 | out_shape.as_mut_ptr(), 127 | out_shape_rank.as_mut_ptr(), 128 | ); 129 | if !IREE_CHECK_OK(status) { 130 | return Err(IreeError::from_status( 131 | IreeStatus { status }, 132 | &IreeAllocator::system_allocator(), 133 | )); 134 | } 135 | out_shape.truncate(out_shape_rank.assume_init()); 136 | } 137 | return Ok(out_shape); 138 | } 139 | } 140 | 141 | impl Drop for IreeHalBufferView { 142 | fn drop(&mut self) { 143 | unsafe { 144 | iree_hal_buffer_view_release(self.buffer_view_ptr); 145 | } 146 | } 147 | } 148 | 149 | impl Display for IreeHalBufferView { 150 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 151 | let shape = self.shape(); 152 | let n_elements = shape.map(|s| s.iter().fold(1, |a, b| a * b)); 153 | let m = n_elements.and_then(|n| self.try_to_string(n)); 154 | match m { 155 | Ok(m) => write!(f, "{}", m), 156 | Err(_e) => return Err(Error {}), 157 | } 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /src/types/hal_device.rs: -------------------------------------------------------------------------------- 1 | use iree_sys::iree::runtime::api::{iree_hal_device_release, iree_hal_device_t}; 2 | 3 | pub struct IreeHalDevice { 4 | pub(crate) device_ptr: *mut iree_hal_device_t, 5 | } 6 | 7 | impl Drop for IreeHalDevice { 8 | fn drop(&mut self) { 9 | unsafe { 10 | iree_hal_device_release(self.device_ptr); 11 | } 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/types/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod allocator; 2 | pub mod bytespan; 3 | pub mod hal_allocator; 4 | pub mod hal_buffer; 5 | pub mod hal_device; 6 | pub mod runtime; 7 | pub mod status; 8 | 
-------------------------------------------------------------------------------- /src/types/runtime/call.rs: -------------------------------------------------------------------------------- 1 | use iree_sys::{ 2 | helper::IREE_CHECK_OK, 3 | iree::runtime::api::{ 4 | iree_hal_buffer_view_t, iree_runtime_call_deinitialize, iree_runtime_call_flags_t, 5 | iree_runtime_call_initialize_by_name, iree_runtime_call_inputs_push_back_buffer_view, 6 | iree_runtime_call_invoke, iree_runtime_call_outputs_pop_front_buffer_view, 7 | iree_runtime_call_t, iree_string_view_t, 8 | }, 9 | }; 10 | 11 | use crate::{ 12 | err::IreeError, 13 | types::{allocator::IreeAllocator, hal_buffer::IreeHalBufferView, status::IreeStatus}, 14 | }; 15 | 16 | use super::session::IreeRuntimeSession; 17 | 18 | pub struct IreeRuntimeCall { 19 | pub(crate) call: iree_runtime_call_t, 20 | } 21 | impl IreeRuntimeCall { 22 | pub fn initialize_by_name( 23 | session: &IreeRuntimeSession, 24 | full_name: &String, 25 | ) -> Result { 26 | let mut call = iree_runtime_call_t::default(); 27 | 28 | unsafe { 29 | let status = iree_runtime_call_initialize_by_name( 30 | session.session_ptr, 31 | iree_string_view_t { 32 | data: full_name.as_ptr() as *const i8, 33 | size: full_name.len(), 34 | }, 35 | &mut call, 36 | ); 37 | if !IREE_CHECK_OK(status) { 38 | return Err(IreeError::from_status( 39 | IreeStatus { status }, 40 | &IreeAllocator::system_allocator(), 41 | )); 42 | } 43 | } 44 | 45 | Ok(Self { call }) 46 | } 47 | 48 | pub fn inputs_push_back_buffer_view( 49 | &mut self, 50 | buffer_view: &IreeHalBufferView, 51 | ) -> Result<(), IreeError> { 52 | unsafe { 53 | let status = iree_runtime_call_inputs_push_back_buffer_view( 54 | &mut self.call, 55 | buffer_view.buffer_view_ptr, 56 | ); 57 | if !IREE_CHECK_OK(status) { 58 | return Err(IreeError::from_status( 59 | IreeStatus { status }, 60 | &IreeAllocator::system_allocator(), 61 | )); 62 | } 63 | Ok(()) 64 | } 65 | } 66 | 67 | pub fn outputs_pop_front_buffer_view(&mut 
self) -> Result { 68 | let mut ret = std::mem::MaybeUninit::<*mut iree_hal_buffer_view_t>::uninit(); 69 | unsafe { 70 | let status = 71 | iree_runtime_call_outputs_pop_front_buffer_view(&mut self.call, ret.as_mut_ptr()); 72 | 73 | if !IREE_CHECK_OK(status) { 74 | return Err(IreeError::from_status( 75 | IreeStatus { status }, 76 | &IreeAllocator::system_allocator(), 77 | )); 78 | } 79 | 80 | Ok(IreeHalBufferView { 81 | buffer_view_ptr: ret.assume_init(), 82 | }) 83 | } 84 | } 85 | 86 | pub fn invoke(&mut self, flags: iree_runtime_call_flags_t) -> Result<(), IreeError> { 87 | unsafe { 88 | let status = iree_runtime_call_invoke(&mut self.call, flags); 89 | if !IREE_CHECK_OK(status) { 90 | return Err(IreeError::from_status( 91 | IreeStatus { status }, 92 | &IreeAllocator::system_allocator(), 93 | )); 94 | } 95 | } 96 | Ok(()) 97 | } 98 | } 99 | 100 | impl Drop for IreeRuntimeCall { 101 | fn drop(&mut self) { 102 | unsafe { 103 | iree_runtime_call_deinitialize(&mut self.call); 104 | } 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /src/types/runtime/instance.rs: -------------------------------------------------------------------------------- 1 | use iree_sys::{ 2 | helper::IREE_CHECK_OK, 3 | iree::runtime::api::{ 4 | iree_hal_device_t, iree_runtime_instance_create, iree_runtime_instance_host_allocator, 5 | iree_runtime_instance_options_initialize, iree_runtime_instance_options_t, 6 | iree_runtime_instance_options_use_all_available_drivers, iree_runtime_instance_release, 7 | iree_runtime_instance_t, iree_runtime_instance_try_create_default_device, 8 | iree_string_view_t, 9 | }, 10 | }; 11 | 12 | use crate::{ 13 | err::IreeError, 14 | types::{allocator::IreeAllocator, hal_device::IreeHalDevice, status::IreeStatus}, 15 | }; 16 | 17 | pub struct IreeRuntimeInstanceOptions { 18 | options: iree_runtime_instance_options_t, 19 | } 20 | 21 | pub struct IreeRuntimeInstanceOptionsBuilder { 22 | options: 
iree_runtime_instance_options_t, 23 | } 24 | 25 | impl Default for IreeRuntimeInstanceOptionsBuilder { 26 | fn default() -> Self { 27 | let mut options = iree_runtime_instance_options_t::default(); 28 | unsafe { 29 | iree_runtime_instance_options_initialize(&mut options); 30 | } 31 | Self { options } 32 | } 33 | } 34 | 35 | impl IreeRuntimeInstanceOptionsBuilder { 36 | pub fn use_all_available_drivers(&mut self) -> &mut Self { 37 | unsafe { 38 | iree_runtime_instance_options_use_all_available_drivers(&mut self.options); 39 | } 40 | self 41 | } 42 | pub fn build(&self) -> IreeRuntimeInstanceOptions { 43 | IreeRuntimeInstanceOptions { 44 | options: self.options, 45 | } 46 | } 47 | } 48 | 49 | pub struct IreeRuntimeInstance { 50 | pub(crate) instance_ptr: *mut iree_runtime_instance_t, 51 | } 52 | 53 | impl IreeRuntimeInstance { 54 | pub fn try_from_options( 55 | options: &IreeRuntimeInstanceOptions, 56 | allocator: &IreeAllocator, 57 | ) -> Result { 58 | let mut instance_ptr = std::mem::MaybeUninit::<*mut iree_runtime_instance_t>::uninit(); 59 | unsafe { 60 | let status = iree_runtime_instance_create( 61 | &options.options, 62 | allocator.allocator, 63 | instance_ptr.as_mut_ptr(), 64 | ); 65 | if !IREE_CHECK_OK(status) { 66 | return Err(IreeError::from_status(IreeStatus { status }, allocator)); 67 | } 68 | } 69 | Ok(Self { 70 | instance_ptr: unsafe { instance_ptr.assume_init() }, 71 | }) 72 | } 73 | 74 | pub fn host_allocator(&self) -> IreeAllocator { 75 | let allocator = unsafe { iree_runtime_instance_host_allocator(self.instance_ptr) }; 76 | IreeAllocator { allocator } 77 | } 78 | 79 | pub fn try_create_default_device(&self, driver_name: &str) -> Result { 80 | let driver_name = iree_string_view_t { 81 | data: driver_name.as_ptr() as _, 82 | size: driver_name.len() as _, 83 | }; 84 | let mut device_ptr = std::mem::MaybeUninit::<*mut iree_hal_device_t>::uninit(); 85 | unsafe { 86 | let status = iree_runtime_instance_try_create_default_device( 87 | self.instance_ptr, 
88 | driver_name, 89 | device_ptr.as_mut_ptr(), 90 | ); 91 | if !IREE_CHECK_OK(status) { 92 | return Err(IreeError::from_status( 93 | IreeStatus { status }, 94 | &self.host_allocator(), 95 | )); 96 | } 97 | } 98 | Ok(IreeHalDevice { 99 | device_ptr: unsafe { device_ptr.assume_init() }, 100 | }) 101 | } 102 | } 103 | 104 | impl Drop for IreeRuntimeInstance { 105 | fn drop(&mut self) { 106 | unsafe { 107 | iree_runtime_instance_release(self.instance_ptr); 108 | } 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /src/types/runtime/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod call; 2 | pub mod instance; 3 | pub mod session; 4 | -------------------------------------------------------------------------------- /src/types/runtime/session.rs: -------------------------------------------------------------------------------- 1 | use iree_sys::{ 2 | helper::IREE_CHECK_OK, 3 | iree::runtime::api::{ 4 | iree_const_byte_span_t, iree_runtime_call_initialize_by_name, iree_runtime_call_t, 5 | iree_runtime_session_append_bytecode_module_from_memory, 6 | iree_runtime_session_create_with_device, iree_runtime_session_device_allocator, 7 | iree_runtime_session_options_initialize, iree_runtime_session_options_t, 8 | iree_runtime_session_release, iree_runtime_session_t, iree_string_view_t, 9 | }, 10 | }; 11 | 12 | use crate::{ 13 | err::IreeError, 14 | types::{ 15 | allocator::IreeAllocator, hal_allocator::IreeHalAllocator, hal_device::IreeHalDevice, 16 | status::IreeStatus, 17 | }, 18 | }; 19 | 20 | use super::{call::IreeRuntimeCall, instance::IreeRuntimeInstance}; 21 | 22 | pub struct IreeRuntimeSessionOptions { 23 | options: iree_runtime_session_options_t, 24 | } 25 | 26 | pub struct IreeRuntimeSessionOptionsBuilder { 27 | options: iree_runtime_session_options_t, 28 | } 29 | 30 | impl Default for IreeRuntimeSessionOptionsBuilder { 31 | fn default() -> Self { 32 | let mut options 
= iree_runtime_session_options_t::default(); 33 | unsafe { 34 | iree_runtime_session_options_initialize(&mut options); 35 | } 36 | Self { options } 37 | } 38 | } 39 | 40 | impl IreeRuntimeSessionOptionsBuilder { 41 | pub fn build(&self) -> IreeRuntimeSessionOptions { 42 | IreeRuntimeSessionOptions { 43 | options: self.options, 44 | } 45 | } 46 | } 47 | 48 | pub struct IreeRuntimeSession { 49 | pub(crate) session_ptr: *mut iree_runtime_session_t, 50 | } 51 | 52 | impl IreeRuntimeSession { 53 | pub fn create_with_device( 54 | instance: &IreeRuntimeInstance, 55 | options: &IreeRuntimeSessionOptions, 56 | device: &IreeHalDevice, 57 | allocator: &IreeAllocator, 58 | ) -> Result { 59 | let mut session_ptr = std::mem::MaybeUninit::<*mut iree_runtime_session_t>::uninit(); 60 | 61 | unsafe { 62 | let status = iree_runtime_session_create_with_device( 63 | instance.instance_ptr, 64 | &options.options, 65 | device.device_ptr, 66 | allocator.allocator, 67 | session_ptr.as_mut_ptr(), 68 | ); 69 | if !IREE_CHECK_OK(status) { 70 | return Err(IreeError::from_status( 71 | IreeStatus { status }, 72 | &instance.host_allocator(), 73 | )); 74 | } 75 | } 76 | 77 | Ok(Self { 78 | session_ptr: unsafe { session_ptr.assume_init() }, 79 | }) 80 | } 81 | 82 | pub fn device_allocator(&self) -> IreeHalAllocator { 83 | let allocator_ptr = unsafe { iree_runtime_session_device_allocator(self.session_ptr) }; 84 | IreeHalAllocator { allocator_ptr } 85 | } 86 | 87 | pub fn get_call_by_name(&self, full_name: &str) -> Result { 88 | let mut call = iree_runtime_call_t::default(); 89 | unsafe { 90 | let status = iree_runtime_call_initialize_by_name( 91 | self.session_ptr, 92 | iree_string_view_t { 93 | data: full_name.as_ptr() as *const i8, 94 | size: full_name.len(), 95 | }, 96 | &mut call, 97 | ); 98 | 99 | if !IREE_CHECK_OK(status) { 100 | return Err(IreeError::from_status( 101 | IreeStatus { status }, 102 | &IreeAllocator::system_allocator(), 103 | )); 104 | } 105 | 106 | Ok(IreeRuntimeCall { call }) 
107 | } 108 | } 109 | 110 | pub fn append_bytecode_module_from_memory( 111 | &self, 112 | module_data: &[u8], 113 | allocator: &IreeAllocator, 114 | ) -> Result<(), IreeError> { 115 | let module_data = iree_const_byte_span_t { 116 | data: module_data.as_ptr() as _, 117 | data_length: module_data.len() as _, 118 | }; 119 | unsafe { 120 | let status = iree_runtime_session_append_bytecode_module_from_memory( 121 | self.session_ptr, 122 | module_data, 123 | allocator.allocator, 124 | ); 125 | if !IREE_CHECK_OK(status) { 126 | return Err(IreeError::from_status(IreeStatus { status }, allocator)); 127 | } 128 | } 129 | Ok(()) 130 | } 131 | } 132 | 133 | impl Drop for IreeRuntimeSession { 134 | fn drop(&mut self) { 135 | unsafe { 136 | iree_runtime_session_release(self.session_ptr); 137 | } 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /src/types/status.rs: -------------------------------------------------------------------------------- 1 | use iree_sys::{ 2 | helper::IREE_CHECK_OK, 3 | iree::runtime::api::{iree_status_t, iree_status_to_string}, 4 | }; 5 | 6 | use crate::err::IreeError; 7 | 8 | use super::allocator::IreeAllocator; 9 | 10 | #[derive(Clone, Copy, Debug)] 11 | pub struct IreeStatus { 12 | pub(crate) status: iree_status_t, 13 | } 14 | 15 | impl From for IreeStatus { 16 | fn from(status: iree_status_t) -> Self { 17 | Self { status } 18 | } 19 | } 20 | 21 | impl IreeStatus { 22 | pub fn is_ok(&self) -> bool { 23 | unsafe { IREE_CHECK_OK(self.status) } 24 | } 25 | pub fn to_string(&self, allocator: &IreeAllocator) -> Result { 26 | let mut out_buffer = std::mem::MaybeUninit::<*mut u8>::uninit(); 27 | let mut out_buffer_length = std::mem::MaybeUninit::::uninit(); 28 | unsafe { 29 | let tostr_success = iree_status_to_string( 30 | self.status, 31 | &allocator.allocator, 32 | out_buffer.as_mut_ptr() as _, 33 | out_buffer_length.as_mut_ptr(), 34 | ); 35 | if !tostr_success { 36 | return Err("Failed to convert 
status to string".to_string().into()); 37 | } 38 | 39 | let out_buffer = out_buffer.assume_init(); 40 | let out_buffer_length = out_buffer_length.assume_init(); 41 | let buffer = std::slice::from_raw_parts(out_buffer, out_buffer_length); 42 | 43 | Ok(String::from_utf8(buffer.to_vec())?) 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /tests/test_hal.rs: -------------------------------------------------------------------------------- 1 | #[cfg(test)] 2 | mod tests { 3 | use iree_rs::types::{ 4 | allocator::IreeAllocator, 5 | bytespan::IreeConstByteSpan, 6 | hal_buffer::{IreeHalBufferView, IreeHalBufferViewParamsBuilder}, 7 | runtime::{ 8 | instance::{IreeRuntimeInstance, IreeRuntimeInstanceOptionsBuilder}, 9 | session::{IreeRuntimeSession, IreeRuntimeSessionOptionsBuilder}, 10 | }, 11 | }; 12 | use iree_sys::iree::runtime::api::{ 13 | iree_hal_buffer_usage_bits_t_IREE_HAL_BUFFER_USAGE_DEFAULT, 14 | iree_hal_element_types_t_IREE_HAL_ELEMENT_TYPE_FLOAT_64, 15 | iree_hal_encoding_types_t_IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, 16 | iree_hal_memory_type_bits_t_IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, 17 | }; 18 | 19 | #[test] 20 | fn test_hal_buffer_view() { 21 | let allocator = IreeAllocator::system_allocator(); 22 | let options = IreeRuntimeInstanceOptionsBuilder::default() 23 | .use_all_available_drivers() 24 | .build(); 25 | let instance = IreeRuntimeInstance::try_from_options(&options, &allocator).unwrap(); 26 | let device = instance.try_create_default_device("local-task").unwrap(); 27 | let session_options = IreeRuntimeSessionOptionsBuilder::default().build(); 28 | let session = IreeRuntimeSession::create_with_device( 29 | &instance, 30 | &session_options, 31 | &device, 32 | &allocator, 33 | ) 34 | .unwrap(); 35 | 36 | let data = [1.0, 2.0, 3.0, 4.0]; 37 | let device_allocator = session.device_allocator(); 38 | let byte_span = IreeConstByteSpan::from_slice(&data); 39 | 40 | let buffer_params = 
IreeHalBufferViewParamsBuilder::default() 41 | .type_(iree_hal_memory_type_bits_t_IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL.0) 42 | .usage(iree_hal_buffer_usage_bits_t_IREE_HAL_BUFFER_USAGE_DEFAULT.0) 43 | .build(); 44 | 45 | let buffer = IreeHalBufferView::allocate_buffer( 46 | &device_allocator, 47 | &vec![data.len()], 48 | iree_hal_element_types_t_IREE_HAL_ELEMENT_TYPE_FLOAT_64, 49 | iree_hal_encoding_types_t_IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, 50 | &buffer_params, 51 | &byte_span, 52 | ) 53 | .unwrap(); 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /tests/test_runtime.rs: -------------------------------------------------------------------------------- 1 | #[cfg(test)] 2 | mod tests { 3 | use iree_rs::types::{ 4 | allocator::IreeAllocator, 5 | runtime::{ 6 | instance::{IreeRuntimeInstance, IreeRuntimeInstanceOptionsBuilder}, 7 | session::{IreeRuntimeSession, IreeRuntimeSessionOptionsBuilder}, 8 | }, 9 | }; 10 | 11 | #[test] 12 | fn test_runtime_instance() { 13 | let allocator = IreeAllocator::system_allocator(); 14 | let options = IreeRuntimeInstanceOptionsBuilder::default() 15 | .use_all_available_drivers() 16 | .build(); 17 | let instance = IreeRuntimeInstance::try_from_options(&options, &allocator).unwrap(); 18 | } 19 | 20 | #[test] 21 | fn test_runtime_instance_try_create_default_device() { 22 | let allocator = IreeAllocator::system_allocator(); 23 | let options = IreeRuntimeInstanceOptionsBuilder::default() 24 | .use_all_available_drivers() 25 | .build(); 26 | let instance = IreeRuntimeInstance::try_from_options(&options, &allocator).unwrap(); 27 | let device = instance.try_create_default_device("local-task").unwrap(); 28 | } 29 | 30 | #[test] 31 | fn test_runtime_session() { 32 | let allocator = IreeAllocator::system_allocator(); 33 | let options = IreeRuntimeInstanceOptionsBuilder::default() 34 | .use_all_available_drivers() 35 | .build(); 36 | let instance = IreeRuntimeInstance::try_from_options(&options, 
&allocator).unwrap(); 37 | let device = instance.try_create_default_device("local-task").unwrap(); 38 | let session_options = IreeRuntimeSessionOptionsBuilder::default().build(); 39 | let session = IreeRuntimeSession::create_with_device( 40 | &instance, 41 | &session_options, 42 | &device, 43 | &allocator, 44 | ) 45 | .unwrap(); 46 | } 47 | } 48 | --------------------------------------------------------------------------------