├── .gitignore ├── .travis.yml ├── .vscode └── launch.json ├── Cargo.lock ├── Cargo.toml ├── README.md ├── misc ├── gen.py ├── raw_fizzbuzz.py ├── test.py ├── tf_fizzbuzz.py └── xor.py ├── trsc ├── Cargo.lock ├── Cargo.toml ├── src │ ├── codegen │ │ ├── mod.rs │ │ └── pytorch.rs │ ├── core │ │ ├── conv.rs │ │ ├── lin.rs │ │ ├── mod.rs │ │ ├── nonlin.rs │ │ ├── prelude.rs │ │ └── reg.rs │ ├── errors │ │ ├── diagnostic.rs │ │ ├── emitter.rs │ │ └── mod.rs │ ├── main.rs │ ├── parsing │ │ ├── ast_builder.rs │ │ ├── grammar.rs │ │ ├── macros.rs │ │ ├── mod.rs │ │ └── term.rs │ ├── span.rs │ ├── tensorscript.pest │ └── typing │ │ ├── annotate.rs │ │ ├── constraint.rs │ │ ├── inferred_ast.rs │ │ ├── mod.rs │ │ ├── type_env.rs │ │ ├── typed_term.rs │ │ ├── types.rs │ │ └── unifier.rs └── tests │ ├── input │ ├── gan.trs │ ├── mnist.trs │ └── xor.trs │ ├── integration_test.rs │ └── output │ ├── gan.py │ ├── mnist.py │ └── xor.py ├── trsc_core_derive ├── Cargo.toml └── src │ ├── attrs.rs │ ├── lib.rs │ └── parser.rs └── vscode-syntax-ext ├── .gitignore ├── .vscode └── launch.json ├── .vscodeignore ├── CHANGELOG.md ├── README.md ├── language-configuration.json ├── package.json ├── syntaxes └── tensorscript.tmLanguage.json └── vsc-extension-quickstart.md /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | /target 3 | **/*.rs.bk 4 | rls -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: rust 2 | rust: 3 | - nightly 4 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "type": "lldb", 9 | "request": "launch", 10 | "name": "Debug", 11 | "program": "${workspaceFolder}/target/debug/trsc", 12 | "args": ["--in", "trsc/tests/input/mnist.trs"], 13 | "cwd": "${workspaceFolder}" 14 | } 15 | ] 16 | } -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | name = "ansi_term" 3 | version = "0.11.0" 4 | source = "registry+https://github.com/rust-lang/crates.io-index" 5 | dependencies = [ 6 | "winapi 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", 7 | ] 8 | 9 | [[package]] 10 | name = "assert_cli" 11 | version = "0.6.0" 12 | source = "registry+https://github.com/rust-lang/crates.io-index" 13 | dependencies = [ 14 | "colored 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)", 15 | "difference 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)", 16 | "environment 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", 17 | "failure 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", 18 | "failure_derive 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", 19 | "serde_json 1.0.17 (registry+https://github.com/rust-lang/crates.io-index)", 20 | ] 21 | 22 | [[package]] 23 | name = "atty" 24 | version = "0.2.10" 25 | source = "registry+https://github.com/rust-lang/crates.io-index" 26 | dependencies = [ 27 | "libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)", 28 | "termion 1.5.1 (registry+https://github.com/rust-lang/crates.io-index)", 29 | "winapi 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", 30 | ] 31 | 32 | [[package]] 33 | name = "backtrace" 34 | version = "0.3.7" 35 | source = "registry+https://github.com/rust-lang/crates.io-index" 36 | dependencies = [ 37 | "backtrace-sys
0.1.16 (registry+https://github.com/rust-lang/crates.io-index)", 38 | "cfg-if 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", 39 | "libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)", 40 | "rustc-demangle 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", 41 | "winapi 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", 42 | ] 43 | 44 | [[package]] 45 | name = "backtrace-sys" 46 | version = "0.1.16" 47 | source = "registry+https://github.com/rust-lang/crates.io-index" 48 | dependencies = [ 49 | "cc 1.0.15 (registry+https://github.com/rust-lang/crates.io-index)", 50 | "libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)", 51 | ] 52 | 53 | [[package]] 54 | name = "bitflags" 55 | version = "1.0.3" 56 | source = "registry+https://github.com/rust-lang/crates.io-index" 57 | 58 | [[package]] 59 | name = "cc" 60 | version = "1.0.15" 61 | source = "registry+https://github.com/rust-lang/crates.io-index" 62 | 63 | [[package]] 64 | name = "cfg-if" 65 | version = "0.1.3" 66 | source = "registry+https://github.com/rust-lang/crates.io-index" 67 | 68 | [[package]] 69 | name = "clap" 70 | version = "2.31.2" 71 | source = "registry+https://github.com/rust-lang/crates.io-index" 72 | dependencies = [ 73 | "ansi_term 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)", 74 | "atty 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)", 75 | "bitflags 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)", 76 | "strsim 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", 77 | "textwrap 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)", 78 | "unicode-width 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)", 79 | "vec_map 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)", 80 | ] 81 | 82 | [[package]] 83 | name = "codespan" 84 | version = "0.1.2" 85 | source = "registry+https://github.com/rust-lang/crates.io-index" 86 | dependencies = [ 87 | 
"failure 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", 88 | "itertools 0.7.8 (registry+https://github.com/rust-lang/crates.io-index)", 89 | ] 90 | 91 | [[package]] 92 | name = "codespan-reporting" 93 | version = "0.1.3" 94 | source = "registry+https://github.com/rust-lang/crates.io-index" 95 | dependencies = [ 96 | "codespan 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", 97 | "termcolor 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)", 98 | ] 99 | 100 | [[package]] 101 | name = "colored" 102 | version = "1.6.0" 103 | source = "registry+https://github.com/rust-lang/crates.io-index" 104 | dependencies = [ 105 | "lazy_static 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", 106 | ] 107 | 108 | [[package]] 109 | name = "difference" 110 | version = "2.0.0" 111 | source = "registry+https://github.com/rust-lang/crates.io-index" 112 | 113 | [[package]] 114 | name = "dtoa" 115 | version = "0.4.2" 116 | source = "registry+https://github.com/rust-lang/crates.io-index" 117 | 118 | [[package]] 119 | name = "either" 120 | version = "1.5.0" 121 | source = "registry+https://github.com/rust-lang/crates.io-index" 122 | 123 | [[package]] 124 | name = "environment" 125 | version = "0.1.1" 126 | source = "registry+https://github.com/rust-lang/crates.io-index" 127 | 128 | [[package]] 129 | name = "failure" 130 | version = "0.1.1" 131 | source = "registry+https://github.com/rust-lang/crates.io-index" 132 | dependencies = [ 133 | "backtrace 0.3.7 (registry+https://github.com/rust-lang/crates.io-index)", 134 | "failure_derive 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", 135 | ] 136 | 137 | [[package]] 138 | name = "failure_derive" 139 | version = "0.1.1" 140 | source = "registry+https://github.com/rust-lang/crates.io-index" 141 | dependencies = [ 142 | "quote 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)", 143 | "syn 0.11.11 (registry+https://github.com/rust-lang/crates.io-index)", 144 | 
"synstructure 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)", 145 | ] 146 | 147 | [[package]] 148 | name = "itertools" 149 | version = "0.7.8" 150 | source = "registry+https://github.com/rust-lang/crates.io-index" 151 | dependencies = [ 152 | "either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)", 153 | ] 154 | 155 | [[package]] 156 | name = "itoa" 157 | version = "0.4.1" 158 | source = "registry+https://github.com/rust-lang/crates.io-index" 159 | 160 | [[package]] 161 | name = "lazy_static" 162 | version = "0.2.11" 163 | source = "registry+https://github.com/rust-lang/crates.io-index" 164 | 165 | [[package]] 166 | name = "lazy_static" 167 | version = "1.0.0" 168 | source = "registry+https://github.com/rust-lang/crates.io-index" 169 | 170 | [[package]] 171 | name = "libc" 172 | version = "0.2.40" 173 | source = "registry+https://github.com/rust-lang/crates.io-index" 174 | 175 | [[package]] 176 | name = "maplit" 177 | version = "1.0.1" 178 | source = "registry+https://github.com/rust-lang/crates.io-index" 179 | 180 | [[package]] 181 | name = "pest" 182 | version = "1.0.6" 183 | source = "registry+https://github.com/rust-lang/crates.io-index" 184 | 185 | [[package]] 186 | name = "pest_derive" 187 | version = "1.0.7" 188 | source = "registry+https://github.com/rust-lang/crates.io-index" 189 | dependencies = [ 190 | "pest 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)", 191 | "quote 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)", 192 | "syn 0.11.11 (registry+https://github.com/rust-lang/crates.io-index)", 193 | ] 194 | 195 | [[package]] 196 | name = "quote" 197 | version = "0.3.15" 198 | source = "registry+https://github.com/rust-lang/crates.io-index" 199 | 200 | [[package]] 201 | name = "redox_syscall" 202 | version = "0.1.37" 203 | source = "registry+https://github.com/rust-lang/crates.io-index" 204 | 205 | [[package]] 206 | name = "redox_termios" 207 | version = "0.1.1" 208 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 209 | dependencies = [ 210 | "redox_syscall 0.1.37 (registry+https://github.com/rust-lang/crates.io-index)", 211 | ] 212 | 213 | [[package]] 214 | name = "rustc-demangle" 215 | version = "0.1.8" 216 | source = "registry+https://github.com/rust-lang/crates.io-index" 217 | 218 | [[package]] 219 | name = "serde" 220 | version = "1.0.57" 221 | source = "registry+https://github.com/rust-lang/crates.io-index" 222 | 223 | [[package]] 224 | name = "serde_json" 225 | version = "1.0.17" 226 | source = "registry+https://github.com/rust-lang/crates.io-index" 227 | dependencies = [ 228 | "dtoa 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)", 229 | "itoa 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", 230 | "serde 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)", 231 | ] 232 | 233 | [[package]] 234 | name = "strsim" 235 | version = "0.7.0" 236 | source = "registry+https://github.com/rust-lang/crates.io-index" 237 | 238 | [[package]] 239 | name = "syn" 240 | version = "0.11.11" 241 | source = "registry+https://github.com/rust-lang/crates.io-index" 242 | dependencies = [ 243 | "quote 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)", 244 | "synom 0.11.3 (registry+https://github.com/rust-lang/crates.io-index)", 245 | "unicode-xid 0.0.4 (registry+https://github.com/rust-lang/crates.io-index)", 246 | ] 247 | 248 | [[package]] 249 | name = "synom" 250 | version = "0.11.3" 251 | source = "registry+https://github.com/rust-lang/crates.io-index" 252 | dependencies = [ 253 | "unicode-xid 0.0.4 (registry+https://github.com/rust-lang/crates.io-index)", 254 | ] 255 | 256 | [[package]] 257 | name = "synstructure" 258 | version = "0.6.1" 259 | source = "registry+https://github.com/rust-lang/crates.io-index" 260 | dependencies = [ 261 | "quote 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)", 262 | "syn 0.11.11 
(registry+https://github.com/rust-lang/crates.io-index)", 263 | ] 264 | 265 | [[package]] 266 | name = "termcolor" 267 | version = "0.3.6" 268 | source = "registry+https://github.com/rust-lang/crates.io-index" 269 | dependencies = [ 270 | "wincolor 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", 271 | ] 272 | 273 | [[package]] 274 | name = "termion" 275 | version = "1.5.1" 276 | source = "registry+https://github.com/rust-lang/crates.io-index" 277 | dependencies = [ 278 | "libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)", 279 | "redox_syscall 0.1.37 (registry+https://github.com/rust-lang/crates.io-index)", 280 | "redox_termios 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", 281 | ] 282 | 283 | [[package]] 284 | name = "textwrap" 285 | version = "0.9.0" 286 | source = "registry+https://github.com/rust-lang/crates.io-index" 287 | dependencies = [ 288 | "unicode-width 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)", 289 | ] 290 | 291 | [[package]] 292 | name = "trsc" 293 | version = "0.1.0" 294 | dependencies = [ 295 | "assert_cli 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", 296 | "clap 2.31.2 (registry+https://github.com/rust-lang/crates.io-index)", 297 | "codespan 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", 298 | "codespan-reporting 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", 299 | "lazy_static 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", 300 | "maplit 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", 301 | "pest 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)", 302 | "pest_derive 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)", 303 | "trsc_core_derive 0.1.0", 304 | ] 305 | 306 | [[package]] 307 | name = "trsc_core_derive" 308 | version = "0.1.0" 309 | dependencies = [ 310 | "quote 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)", 311 | "syn 0.11.11 
(registry+https://github.com/rust-lang/crates.io-index)", 312 | ] 313 | 314 | [[package]] 315 | name = "unicode-width" 316 | version = "0.1.4" 317 | source = "registry+https://github.com/rust-lang/crates.io-index" 318 | 319 | [[package]] 320 | name = "unicode-xid" 321 | version = "0.0.4" 322 | source = "registry+https://github.com/rust-lang/crates.io-index" 323 | 324 | [[package]] 325 | name = "vec_map" 326 | version = "0.8.1" 327 | source = "registry+https://github.com/rust-lang/crates.io-index" 328 | 329 | [[package]] 330 | name = "winapi" 331 | version = "0.3.4" 332 | source = "registry+https://github.com/rust-lang/crates.io-index" 333 | dependencies = [ 334 | "winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", 335 | "winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", 336 | ] 337 | 338 | [[package]] 339 | name = "winapi-i686-pc-windows-gnu" 340 | version = "0.4.0" 341 | source = "registry+https://github.com/rust-lang/crates.io-index" 342 | 343 | [[package]] 344 | name = "winapi-x86_64-pc-windows-gnu" 345 | version = "0.4.0" 346 | source = "registry+https://github.com/rust-lang/crates.io-index" 347 | 348 | [[package]] 349 | name = "wincolor" 350 | version = "0.1.6" 351 | source = "registry+https://github.com/rust-lang/crates.io-index" 352 | dependencies = [ 353 | "winapi 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", 354 | ] 355 | 356 | [metadata] 357 | "checksum ansi_term 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b" 358 | "checksum assert_cli 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5da59dbd8df54562665b925b427221ceda9b771408cb8a6cbd2125d3b001330b" 359 | "checksum atty 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "2fc4a1aa4c24c0718a250f0681885c1af91419d242f29eb8f2ab28502d80dbd1" 360 | "checksum backtrace 0.3.7 
(registry+https://github.com/rust-lang/crates.io-index)" = "8ea58cd16fd6c9d120b5bcb01d63883ae4cc7ba2aed35c1841b862a3c7ef6639" 361 | "checksum backtrace-sys 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "44585761d6161b0f57afc49482ab6bd067e4edef48c12a152c237eb0203f7661" 362 | "checksum bitflags 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "d0c54bb8f454c567f21197eefcdbf5679d0bd99f2ddbe52e84c77061952e6789" 363 | "checksum cc 1.0.15 (registry+https://github.com/rust-lang/crates.io-index)" = "0ebb87d1116151416c0cf66a0e3fb6430cccd120fd6300794b4dfaa050ac40ba" 364 | "checksum cfg-if 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "405216fd8fe65f718daa7102ea808a946b6ce40c742998fbfd3463645552de18" 365 | "checksum clap 2.31.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f0f16b89cbb9ee36d87483dc939fe9f1e13c05898d56d7b230a0d4dff033a536" 366 | "checksum codespan 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "b81a69fe4f90ef3238d50a1e312c224634986013f37b32f944e369f55439b961" 367 | "checksum codespan-reporting 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6a5c1ae63259648f1cd349d0913570828d7e098de98cc685837dacad1367f428" 368 | "checksum colored 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b0aa3473e85a3161b59845d6096b289bb577874cafeaf75ea1b1beaa6572c7fc" 369 | "checksum difference 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "524cbf6897b527295dff137cec09ecf3a05f4fddffd7dfcd1585403449e74198" 370 | "checksum dtoa 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)" = "09c3753c3db574d215cba4ea76018483895d7bff25a31b49ba45db21c48e50ab" 371 | "checksum either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3be565ca5c557d7f59e7cfcf1844f9e3033650c929c6566f511e8005f205c1d0" 372 | "checksum environment 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = 
"1f4b14e20978669064c33b4c1e0fb4083412e40fe56cbea2eae80fd7591503ee" 373 | "checksum failure 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "934799b6c1de475a012a02dab0ace1ace43789ee4b99bcfbf1a2e3e8ced5de82" 374 | "checksum failure_derive 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "c7cdda555bb90c9bb67a3b670a0f42de8e73f5981524123ad8578aafec8ddb8b" 375 | "checksum itertools 0.7.8 (registry+https://github.com/rust-lang/crates.io-index)" = "f58856976b776fedd95533137617a02fb25719f40e7d9b01c7043cd65474f450" 376 | "checksum itoa 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)" = "c069bbec61e1ca5a596166e55dfe4773ff745c3d16b700013bcaff9a6df2c682" 377 | "checksum lazy_static 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)" = "76f033c7ad61445c5b347c7382dd1237847eb1bce590fe50365dcb33d546be73" 378 | "checksum lazy_static 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c8f31047daa365f19be14b47c29df4f7c3b581832407daabe6ae77397619237d" 379 | "checksum libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)" = "6fd41f331ac7c5b8ac259b8bf82c75c0fb2e469bbf37d2becbba9a6a2221965b" 380 | "checksum maplit 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "08cbb6b4fef96b6d77bfc40ec491b1690c779e77b05cd9f07f787ed376fd4c43" 381 | "checksum pest 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "0fce5d8b5cc33983fc74f78ad552b5522ab41442c4ca91606e4236eb4b5ceefc" 382 | "checksum pest_derive 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)" = "ab94faafeb93f4c5e3ce81ca0e5a779529a602ad5d09ae6d21996bfb8b6a52bf" 383 | "checksum quote 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)" = "7a6e920b65c65f10b2ae65c831a81a073a89edd28c7cce89475bff467ab4167a" 384 | "checksum redox_syscall 0.1.37 (registry+https://github.com/rust-lang/crates.io-index)" = "0d92eecebad22b767915e4d529f89f28ee96dbbf5a4810d2b844373f136417fd" 385 | "checksum redox_termios 0.1.1 
(registry+https://github.com/rust-lang/crates.io-index)" = "7e891cfe48e9100a70a3b6eb652fef28920c117d366339687bd5576160db0f76" 386 | "checksum rustc-demangle 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "76d7ba1feafada44f2d38eed812bd2489a03c0f5abb975799251518b68848649" 387 | "checksum serde 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)" = "9478f147957b713a156ce5e4529d77275bbcfddc29563b794939b36230df8ca8" 388 | "checksum serde_json 1.0.17 (registry+https://github.com/rust-lang/crates.io-index)" = "f3ad6d546e765177cf3dded3c2e424a8040f870083a0e64064746b958ece9cb1" 389 | "checksum strsim 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "bb4f380125926a99e52bc279241539c018323fab05ad6368b56f93d9369ff550" 390 | "checksum syn 0.11.11 (registry+https://github.com/rust-lang/crates.io-index)" = "d3b891b9015c88c576343b9b3e41c2c11a51c219ef067b264bd9c8aa9b441dad" 391 | "checksum synom 0.11.3 (registry+https://github.com/rust-lang/crates.io-index)" = "a393066ed9010ebaed60b9eafa373d4b1baac186dd7e008555b0f702b51945b6" 392 | "checksum synstructure 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)" = "3a761d12e6d8dcb4dcf952a7a89b475e3a9d69e4a69307e01a470977642914bd" 393 | "checksum termcolor 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "adc4587ead41bf016f11af03e55a624c06568b5a19db4e90fde573d805074f83" 394 | "checksum termion 1.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "689a3bdfaab439fd92bc87df5c4c78417d3cbe537487274e9b0b2dce76e92096" 395 | "checksum textwrap 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c0b59b6b4b44d867f1370ef1bd91bfb262bf07bf0ae65c202ea2fbc16153b693" 396 | "checksum unicode-width 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "bf3a113775714a22dcb774d8ea3655c53a32debae63a063acc00a91cc586245f" 397 | "checksum unicode-xid 0.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = 
"8c1f860d7d29cf02cb2f3f359fd35991af3d30bac52c57d265a3c461074cb4dc" 398 | "checksum vec_map 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "05c78687fb1a80548ae3250346c3db86a80a7cdd77bda190189f2d0a0987c81a" 399 | "checksum winapi 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "04e3bd221fcbe8a271359c04f21a76db7d0c6028862d1bb5512d85e1e2eb5bb3" 400 | "checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" 401 | "checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" 402 | "checksum wincolor 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "eeb06499a3a4d44302791052df005d5232b927ed1a9658146d842165c4de7767" 403 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "trsc", 4 | "trsc_core_derive", 5 | ] 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorScript 2 | 3 | Dependently-typed tensor computation. 4 | 5 | ## Features 6 | 7 | * Parametric polymorphism 8 | * Compile time type checking 9 | * Dependently typed tensors 10 | * Multiple targets(Tensorflow, PyTorch, more to come!) 11 | * Pipes operator 12 | 13 | ### Pipes operator 14 | 15 | Pipes operator is a syntax sugar for chained function calls inspired by F#, Elixir and R. 
16 | For example, 17 | 18 | ```rust 19 | x |> lin1 |> leaky_relu(p=0.2) |> sigmoid 20 | ``` 21 | 22 | compiles to 23 | 24 | ```python 25 | x = lin1(x) 26 | x = leaky_relu(x, p=0.2) 27 | x = sigmoid(x) 28 | ``` 29 | 30 | ## Development 31 | 32 | [![Build Status](https://travis-ci.org/rickyhan/tensorscript.svg?branch=master)](https://travis-ci.org/rickyhan/tensorscript) 33 | 34 | The language is not yet usable for production or development work. 35 | 36 | ### Todo 37 | 38 | 1. [x] implement module pattern matching 39 | 2. [x] type-level computation (resolved tensor dimensions) 40 | 3. [x] BUG: dimension mismatch for mnist example 41 | need to create fresh type variables for different static forward functions 42 | 4. [x] BUG: non-determinism 43 | 5. [x] BUG: impl Hash, Eq for Type 44 | 6. [x] set up examples and tests 45 | 7. [x] set up command-line interface 46 | 8. [x] more examples 47 | 9. [x] better parser error messages 48 | 10. [ ] code gen: PyTorch 49 | 11. [ ] add more examples 50 | 12. [x] lift dim and tsr to top level 51 | 13. [ ] add dimension-level computation (e.g. dim1 * dim1) 52 | 14. [ ] support aliasing (e.g. use Linear as L) 53 | 15. [ ] add binary ops (+, -, *, /, %) 54 | 16. [ ] add if/else expressions 55 | 17. [ ] add let bindings 56 | 18.
[ ] add more tests 57 | -------------------------------------------------------------------------------- /misc/gen.py: -------------------------------------------------------------------------------- 1 | import astunparse, ast, astpretty 2 | from ast import * 3 | 4 | fname = "./raw_fizzbuzz.py" 5 | with open(fname) as f: 6 | txt = f.read() 7 | 8 | class RewriteName(NodeTransformer): 9 | def visit_BoolOp(self, node): 10 | # print astpretty.pprint(node) 11 | if isinstance(node.op, And): 12 | funcname = "tf.logical_and" 13 | return copy_location(Call( 14 | func=Name(id=funcname, ctx=Load()), 15 | args=( 16 | self.visit(node.values[0]), 17 | self.visit(node.values[1]) 18 | ), 19 | keywords=(), 20 | starargs=(), 21 | kwargs=(), 22 | ), node) 23 | else: 24 | return node 25 | 26 | def visit_BinOp(self, node): 27 | # print astpretty.pprint(node) 28 | if isinstance(node.op, Mod): 29 | return copy_location(Call( 30 | func=Name(id="tf.mod", ctx=Load()), 31 | args=(node.left, node.right), 32 | keywords=(), 33 | starargs=(), 34 | kwargs=(), 35 | ), node) 36 | return node 37 | 38 | def visit_Compare(self, node): 39 | # print astpretty.pprint(node) 40 | if len(node.ops) == 1 and isinstance(node.ops[0], Eq): 41 | return copy_location(Call( 42 | func=Name(id="tf.equal", ctx=Load()), 43 | args=(self.visit(node.left), map(lambda i: self.visit(i), node.comparators)), 44 | keywords=(), 45 | starargs=(), 46 | kwargs=(), 47 | ), node) 48 | elif len(node.ops) == 1 and isinstance(node.ops[0], Lt): 49 | return copy_location(Call( 50 | func=Name(id="tf.less", ctx=Load()), 51 | args=(self.visit(node.left), map(lambda i: self.visit(i), node.comparators)), 52 | keywords=(), 53 | starargs=(), 54 | kwargs=(), 55 | ), node) 56 | else: 57 | return node 58 | 59 | def visit_If(self, node): 60 | # print astpretty.pprint(node) 61 | return copy_location(Call( 62 | func=Name(id="tf.cond", ctx=Load()), 63 | args=( 64 | self.visit(node.test), 65 | Lambda( 66 | 
args=arguments(args=[],defaults=[],vararg=[],kwarg=[]), 67 | body=self.visit(node.body[0]) 68 | ), 69 | Lambda( 70 | args=arguments(args=[],defaults=[],vararg=[],kwarg=[]), 71 | body=self.visit(node.orelse[0]) 72 | ), 73 | ), 74 | keywords=(), 75 | starargs=(), 76 | kwargs=(), 77 | ), node) 78 | 79 | def visit_Assign(self, node): 80 | # print astpretty.pprint(node) 81 | return copy_location(Call( 82 | func=Name(id="tf.assign", ctx=Load()), 83 | args=( 84 | self.visit(node.targets[0]), 85 | self.visit(node.value), 86 | ), 87 | keywords=(), 88 | starargs=(), 89 | kwargs=(), 90 | ), node) 91 | 92 | def visit_While(self, node): 93 | # print astpretty.pprint(node) 94 | return copy_location(Call( 95 | func=Name(id="tf.while_loop", ctx=Load()), 96 | args=( 97 | Lambda( 98 | args=arguments(args=[],defaults=[],vararg=[],kwarg=[]), 99 | body=self.visit(node.test), 100 | ), 101 | map(self.visit, node.body), 102 | ), 103 | keywords=(), 104 | starargs=(), 105 | kwargs=(), 106 | ), node) 107 | 108 | 109 | myast = ast.parse(txt) 110 | myast = RewriteName().visit(myast) 111 | # print astpretty.pprint(myast) 112 | print(astunparse.unparse(myast)) 113 | -------------------------------------------------------------------------------- /misc/raw_fizzbuzz.py: -------------------------------------------------------------------------------- 1 | length = 99 2 | arr = [0] * (length-1) 3 | i = 1 4 | while i < length: 5 | if n % 3 == 0 and n % 5 == 0: 6 | arr[n-1] = 'FizzBuzz' 7 | elif n % 3 == 0: 8 | arr[n-1] = 'Fizz' 9 | elif n % 5 == 0: 10 | arr[n-1] = 'Buzz' 11 | else: 12 | arr -------------------------------------------------------------------------------- /misc/test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | length = 100 3 | arr = tf.Variable([str(i) for i in range(1, length+1)]) 4 | graph = tf.while_loop( 5 | lambda i, _: tf.less(i, length+1), 6 | lambda i, _: (tf.add(i,1), tf.cond( 7 | 
tf.logical_and(tf.equal(tf.mod(i, 3), 0), tf.equal(tf.mod(i, 5), 0)), 8 | (lambda : tf.assign(arr[(i - 1)], 'FizzBuzz')), 9 | (lambda : tf.cond(tf.equal(tf.mod(i, 3), 0), 10 | (lambda : tf.assign(arr[(i - 1)], 'Fizz')), 11 | (lambda : tf.cond(tf.equal(tf.mod(i, 5), 0), 12 | (lambda : tf.assign(arr[(i - 1)], 'Buzz')), 13 | (lambda : arr))))))), 14 | [1, arr]) 15 | with tf.Session() as sess: 16 | tf.global_variables_initializer().run() 17 | idx, array = sess.run(graph) 18 | print array 19 | -------------------------------------------------------------------------------- /misc/tf_fizzbuzz.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import tensorflow as tf 7 | 8 | 9 | class FizzBuzz(): 10 | """FizzBuzz""" 11 | length = 30 12 | def __init__(self): 13 | with tf.name_scope("fizzbuzz"): 14 | self.array = tf.Variable([str(i) for i in range(1, self.length+1)], dtype=tf.string, trainable=False) 15 | self.graph = tf.while_loop(self.cond, self.body, [1, self.array], 16 | shape_invariants=[tf.TensorShape([]), tf.TensorShape(self.length)], 17 | back_prop=False) 18 | 19 | def run(self): 20 | with tf.Session() as sess: 21 | tf.global_variables_initializer().run() 22 | return sess.run(self.graph) 23 | 24 | def cond(self, i, _): 25 | return (tf.less(i, self.length+1)) 26 | 27 | def body(self, i, _): 28 | flow = tf.cond( 29 | tf.logical_and(tf.equal(tf.mod(i, 3), 0), tf.equal(tf.mod(i, 5), 0)), 30 | lambda: tf.assign(self.array[i - 1], 'FizzBuzz'), 31 | lambda: tf.cond(tf.equal(tf.mod(i, 3), 0), 32 | lambda: tf.assign(self.array[i - 1], 'Fizz'), 33 | lambda: tf.cond(tf.equal(tf.mod(i, 5), 0), 34 | lambda: tf.assign(self.array[i - 1], 'Buzz'), 35 | lambda: self.array 36 | ) 37 | ) 38 | ) 39 | return (tf.add(i, 1), flow) 40 | 41 | 42 | if __name__ == '__main__': 43 | fizzbuzz = FizzBuzz() 44 | 
ix, array = fizzbuzz.run() 45 | print(array) 46 | 47 | -------------------------------------------------------------------------------- /misc/xor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | 8 | class Xor(nn.Module): 9 | '''Xor::forward([!1, <2>] -> [!1, <1>])''' 10 | def __init__(self): 11 | super(Xor, self).__init__() 12 | self.fc1 = nn.Linear(in_features=2, out_features=3) 13 | self.fc2 = nn.Linear(in_features=3, out_features=1) 14 | def forward(self, x): 15 | x = self.fc1(x) 16 | x = F.sigmoid(x) 17 | return self.fc2(x) 18 | 19 | net = Xor() 20 | 21 | inputs = list(map(lambda s: Variable(torch.Tensor([s])), [ 22 | [0, 0], 23 | [0, 1], 24 | [1, 0], 25 | [1, 1] 26 | ])) 27 | targets = list(map(lambda s: Variable(torch.Tensor([s])), [ 28 | [0], 29 | [1], 30 | [1], 31 | [0] 32 | ])) 33 | 34 | criterion = nn.MSELoss() 35 | optimizer = optim.SGD(net.parameters(), lr=0.1) 36 | 37 | EPOCHS_TO_TRAIN = 2000 38 | print("Training loop:") 39 | for idx in range(0, EPOCHS_TO_TRAIN): 40 | for input, target in zip(inputs, targets): 41 | optimizer.zero_grad() # zero the gradient buffers 42 | output = net(input) 43 | loss = criterion(output, target) 44 | loss.backward() 45 | optimizer.step() # update 46 | if idx % 500 == 0: 47 | print("Epoch:\t", idx, "\tloss:\t", loss.data.numpy()) 48 | 49 | 50 | print("") 51 | print("Final results:") 52 | for input, target in zip(inputs, targets): 53 | output = net(input) 54 | print("Input:[{},{}] Target:[{}] Predicted:[{}] Error:[{}]".format( 55 | int(input.data.numpy()[0][0]), 56 | int(input.data.numpy()[0][1]), 57 | int(target.data.numpy()[0]), 58 | round(float(output.data.numpy()[0]), 4), 59 | round(float(abs(target.data.numpy()[0] - output.data.numpy()[0])), 4) 60 | )) 61 | 
-------------------------------------------------------------------------------- /trsc/Cargo.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | name = "ansi_term" 3 | version = "0.11.0" 4 | source = "registry+https://github.com/rust-lang/crates.io-index" 5 | dependencies = [ 6 | "winapi 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", 7 | ] 8 | 9 | [[package]] 10 | name = "atty" 11 | version = "0.2.8" 12 | source = "registry+https://github.com/rust-lang/crates.io-index" 13 | dependencies = [ 14 | "libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)", 15 | "termion 1.5.1 (registry+https://github.com/rust-lang/crates.io-index)", 16 | "winapi 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", 17 | ] 18 | 19 | [[package]] 20 | name = "backtrace" 21 | version = "0.3.6" 22 | source = "registry+https://github.com/rust-lang/crates.io-index" 23 | dependencies = [ 24 | "backtrace-sys 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)", 25 | "cfg-if 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", 26 | "libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)", 27 | "rustc-demangle 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)", 28 | "winapi 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", 29 | ] 30 | 31 | [[package]] 32 | name = "backtrace-sys" 33 | version = "0.1.16" 34 | source = "registry+https://github.com/rust-lang/crates.io-index" 35 | dependencies = [ 36 | "cc 1.0.10 (registry+https://github.com/rust-lang/crates.io-index)", 37 | "libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)", 38 | ] 39 | 40 | [[package]] 41 | name = "bitflags" 42 | version = "1.0.1" 43 | source = "registry+https://github.com/rust-lang/crates.io-index" 44 | 45 | [[package]] 46 | name = "cc" 47 | version = "1.0.10" 48 | source = "registry+https://github.com/rust-lang/crates.io-index" 49 | 50 | [[package]] 51 | name 
= "cfg-if" 52 | version = "0.1.2" 53 | source = "registry+https://github.com/rust-lang/crates.io-index" 54 | 55 | [[package]] 56 | name = "clap" 57 | version = "2.31.2" 58 | source = "registry+https://github.com/rust-lang/crates.io-index" 59 | dependencies = [ 60 | "ansi_term 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)", 61 | "atty 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)", 62 | "bitflags 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", 63 | "strsim 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", 64 | "textwrap 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)", 65 | "unicode-width 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)", 66 | "vec_map 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", 67 | ] 68 | 69 | [[package]] 70 | name = "codespan" 71 | version = "0.1.1" 72 | source = "registry+https://github.com/rust-lang/crates.io-index" 73 | dependencies = [ 74 | "failure 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", 75 | ] 76 | 77 | [[package]] 78 | name = "codespan-reporting" 79 | version = "0.1.3" 80 | source = "registry+https://github.com/rust-lang/crates.io-index" 81 | dependencies = [ 82 | "codespan 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", 83 | "termcolor 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)", 84 | ] 85 | 86 | [[package]] 87 | name = "failure" 88 | version = "0.1.1" 89 | source = "registry+https://github.com/rust-lang/crates.io-index" 90 | dependencies = [ 91 | "backtrace 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)", 92 | "failure_derive 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", 93 | ] 94 | 95 | [[package]] 96 | name = "failure_derive" 97 | version = "0.1.1" 98 | source = "registry+https://github.com/rust-lang/crates.io-index" 99 | dependencies = [ 100 | "quote 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)", 101 | "syn 0.11.11 
(registry+https://github.com/rust-lang/crates.io-index)", 102 | "synstructure 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)", 103 | ] 104 | 105 | [[package]] 106 | name = "lazy_static" 107 | version = "1.0.0" 108 | source = "registry+https://github.com/rust-lang/crates.io-index" 109 | 110 | [[package]] 111 | name = "libc" 112 | version = "0.2.40" 113 | source = "registry+https://github.com/rust-lang/crates.io-index" 114 | 115 | [[package]] 116 | name = "maplit" 117 | version = "1.0.1" 118 | source = "registry+https://github.com/rust-lang/crates.io-index" 119 | 120 | [[package]] 121 | name = "pest" 122 | version = "1.0.6" 123 | source = "registry+https://github.com/rust-lang/crates.io-index" 124 | 125 | [[package]] 126 | name = "pest_derive" 127 | version = "1.0.7" 128 | source = "registry+https://github.com/rust-lang/crates.io-index" 129 | dependencies = [ 130 | "pest 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)", 131 | "quote 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)", 132 | "syn 0.11.11 (registry+https://github.com/rust-lang/crates.io-index)", 133 | ] 134 | 135 | [[package]] 136 | name = "quote" 137 | version = "0.3.15" 138 | source = "registry+https://github.com/rust-lang/crates.io-index" 139 | 140 | [[package]] 141 | name = "redox_syscall" 142 | version = "0.1.37" 143 | source = "registry+https://github.com/rust-lang/crates.io-index" 144 | 145 | [[package]] 146 | name = "redox_termios" 147 | version = "0.1.1" 148 | source = "registry+https://github.com/rust-lang/crates.io-index" 149 | dependencies = [ 150 | "redox_syscall 0.1.37 (registry+https://github.com/rust-lang/crates.io-index)", 151 | ] 152 | 153 | [[package]] 154 | name = "rustc-demangle" 155 | version = "0.1.7" 156 | source = "registry+https://github.com/rust-lang/crates.io-index" 157 | 158 | [[package]] 159 | name = "strsim" 160 | version = "0.7.0" 161 | source = "registry+https://github.com/rust-lang/crates.io-index" 162 | 163 | [[package]] 164 | 
name = "syn" 165 | version = "0.11.11" 166 | source = "registry+https://github.com/rust-lang/crates.io-index" 167 | dependencies = [ 168 | "quote 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)", 169 | "synom 0.11.3 (registry+https://github.com/rust-lang/crates.io-index)", 170 | "unicode-xid 0.0.4 (registry+https://github.com/rust-lang/crates.io-index)", 171 | ] 172 | 173 | [[package]] 174 | name = "synom" 175 | version = "0.11.3" 176 | source = "registry+https://github.com/rust-lang/crates.io-index" 177 | dependencies = [ 178 | "unicode-xid 0.0.4 (registry+https://github.com/rust-lang/crates.io-index)", 179 | ] 180 | 181 | [[package]] 182 | name = "synstructure" 183 | version = "0.6.1" 184 | source = "registry+https://github.com/rust-lang/crates.io-index" 185 | dependencies = [ 186 | "quote 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)", 187 | "syn 0.11.11 (registry+https://github.com/rust-lang/crates.io-index)", 188 | ] 189 | 190 | [[package]] 191 | name = "termcolor" 192 | version = "0.3.6" 193 | source = "registry+https://github.com/rust-lang/crates.io-index" 194 | dependencies = [ 195 | "wincolor 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", 196 | ] 197 | 198 | [[package]] 199 | name = "termion" 200 | version = "1.5.1" 201 | source = "registry+https://github.com/rust-lang/crates.io-index" 202 | dependencies = [ 203 | "libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)", 204 | "redox_syscall 0.1.37 (registry+https://github.com/rust-lang/crates.io-index)", 205 | "redox_termios 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", 206 | ] 207 | 208 | [[package]] 209 | name = "textwrap" 210 | version = "0.9.0" 211 | source = "registry+https://github.com/rust-lang/crates.io-index" 212 | dependencies = [ 213 | "unicode-width 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)", 214 | ] 215 | 216 | [[package]] 217 | name = "tsrc" 218 | version = "0.1.0" 219 | dependencies = [ 220 | 
"clap 2.31.2 (registry+https://github.com/rust-lang/crates.io-index)", 221 | "codespan 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", 222 | "codespan-reporting 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", 223 | "lazy_static 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", 224 | "maplit 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", 225 | "pest 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)", 226 | "pest_derive 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)", 227 | "quote 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)", 228 | "syn 0.11.11 (registry+https://github.com/rust-lang/crates.io-index)", 229 | ] 230 | 231 | [[package]] 232 | name = "unicode-width" 233 | version = "0.1.4" 234 | source = "registry+https://github.com/rust-lang/crates.io-index" 235 | 236 | [[package]] 237 | name = "unicode-xid" 238 | version = "0.0.4" 239 | source = "registry+https://github.com/rust-lang/crates.io-index" 240 | 241 | [[package]] 242 | name = "vec_map" 243 | version = "0.8.0" 244 | source = "registry+https://github.com/rust-lang/crates.io-index" 245 | 246 | [[package]] 247 | name = "winapi" 248 | version = "0.3.4" 249 | source = "registry+https://github.com/rust-lang/crates.io-index" 250 | dependencies = [ 251 | "winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", 252 | "winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", 253 | ] 254 | 255 | [[package]] 256 | name = "winapi-i686-pc-windows-gnu" 257 | version = "0.4.0" 258 | source = "registry+https://github.com/rust-lang/crates.io-index" 259 | 260 | [[package]] 261 | name = "winapi-x86_64-pc-windows-gnu" 262 | version = "0.4.0" 263 | source = "registry+https://github.com/rust-lang/crates.io-index" 264 | 265 | [[package]] 266 | name = "wincolor" 267 | version = "0.1.6" 268 | source = "registry+https://github.com/rust-lang/crates.io-index" 
269 | dependencies = [ 270 | "winapi 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", 271 | ] 272 | 273 | [metadata] 274 | "checksum ansi_term 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b" 275 | "checksum atty 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)" = "af80143d6f7608d746df1520709e5d141c96f240b0e62b0aa41bdfb53374d9d4" 276 | "checksum backtrace 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "ebbe525f66f42d207968308ee86bc2dd60aa5fab535b22e616323a173d097d8e" 277 | "checksum backtrace-sys 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "44585761d6161b0f57afc49482ab6bd067e4edef48c12a152c237eb0203f7661" 278 | "checksum bitflags 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b3c30d3802dfb7281680d6285f2ccdaa8c2d8fee41f93805dba5c4cf50dc23cf" 279 | "checksum cc 1.0.10 (registry+https://github.com/rust-lang/crates.io-index)" = "8b9d2900f78631a5876dc5d6c9033ede027253efcd33dd36b1309fc6cab97ee0" 280 | "checksum cfg-if 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c819a1287eb618df47cc647173c5c4c66ba19d888a6e50d605672aed3140de" 281 | "checksum clap 2.31.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f0f16b89cbb9ee36d87483dc939fe9f1e13c05898d56d7b230a0d4dff033a536" 282 | "checksum codespan 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "39dd28877b17f6a0b94dca0c2930f1875f5c8b017cc069a8c067ad1bb39d882c" 283 | "checksum codespan-reporting 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6a5c1ae63259648f1cd349d0913570828d7e098de98cc685837dacad1367f428" 284 | "checksum failure 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "934799b6c1de475a012a02dab0ace1ace43789ee4b99bcfbf1a2e3e8ced5de82" 285 | "checksum failure_derive 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = 
"c7cdda555bb90c9bb67a3b670a0f42de8e73f5981524123ad8578aafec8ddb8b" 286 | "checksum lazy_static 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c8f31047daa365f19be14b47c29df4f7c3b581832407daabe6ae77397619237d" 287 | "checksum libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)" = "6fd41f331ac7c5b8ac259b8bf82c75c0fb2e469bbf37d2becbba9a6a2221965b" 288 | "checksum maplit 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "08cbb6b4fef96b6d77bfc40ec491b1690c779e77b05cd9f07f787ed376fd4c43" 289 | "checksum pest 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "0fce5d8b5cc33983fc74f78ad552b5522ab41442c4ca91606e4236eb4b5ceefc" 290 | "checksum pest_derive 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)" = "ab94faafeb93f4c5e3ce81ca0e5a779529a602ad5d09ae6d21996bfb8b6a52bf" 291 | "checksum quote 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)" = "7a6e920b65c65f10b2ae65c831a81a073a89edd28c7cce89475bff467ab4167a" 292 | "checksum redox_syscall 0.1.37 (registry+https://github.com/rust-lang/crates.io-index)" = "0d92eecebad22b767915e4d529f89f28ee96dbbf5a4810d2b844373f136417fd" 293 | "checksum redox_termios 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7e891cfe48e9100a70a3b6eb652fef28920c117d366339687bd5576160db0f76" 294 | "checksum rustc-demangle 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "11fb43a206a04116ffd7cfcf9bcb941f8eb6cc7ff667272246b0a1c74259a3cb" 295 | "checksum strsim 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "bb4f380125926a99e52bc279241539c018323fab05ad6368b56f93d9369ff550" 296 | "checksum syn 0.11.11 (registry+https://github.com/rust-lang/crates.io-index)" = "d3b891b9015c88c576343b9b3e41c2c11a51c219ef067b264bd9c8aa9b441dad" 297 | "checksum synom 0.11.3 (registry+https://github.com/rust-lang/crates.io-index)" = "a393066ed9010ebaed60b9eafa373d4b1baac186dd7e008555b0f702b51945b6" 298 | "checksum synstructure 0.6.1 
(registry+https://github.com/rust-lang/crates.io-index)" = "3a761d12e6d8dcb4dcf952a7a89b475e3a9d69e4a69307e01a470977642914bd" 299 | "checksum termcolor 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "adc4587ead41bf016f11af03e55a624c06568b5a19db4e90fde573d805074f83" 300 | "checksum termion 1.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "689a3bdfaab439fd92bc87df5c4c78417d3cbe537487274e9b0b2dce76e92096" 301 | "checksum textwrap 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c0b59b6b4b44d867f1370ef1bd91bfb262bf07bf0ae65c202ea2fbc16153b693" 302 | "checksum unicode-width 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "bf3a113775714a22dcb774d8ea3655c53a32debae63a063acc00a91cc586245f" 303 | "checksum unicode-xid 0.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "8c1f860d7d29cf02cb2f3f359fd35991af3d30bac52c57d265a3c461074cb4dc" 304 | "checksum vec_map 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "887b5b631c2ad01628bbbaa7dd4c869f80d3186688f8d0b6f58774fbe324988c" 305 | "checksum winapi 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "04e3bd221fcbe8a271359c04f21a76db7d0c6028862d1bb5512d85e1e2eb5bb3" 306 | "checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" 307 | "checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" 308 | "checksum wincolor 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "eeb06499a3a4d44302791052df005d5232b927ed1a9658146d842165c4de7767" 309 | -------------------------------------------------------------------------------- /trsc/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "trsc" 3 | version = "0.1.0" 4 | authors = 
["ricky han "] 5 | 6 | [dependencies] 7 | pest = "^1.0.0-beta" 8 | pest_derive = "^1.0.0-beta" 9 | maplit = "1.0.1" 10 | codespan = "0.1.1" 11 | codespan-reporting = "0.1.3" 12 | clap = "2.31.2" 13 | lazy_static = "1.0" 14 | trsc_core_derive = { path = "../trsc_core_derive" } 15 | 16 | [dev-dependencies] 17 | assert_cli = "0.6" 18 | -------------------------------------------------------------------------------- /trsc/src/codegen/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod pytorch; 2 | -------------------------------------------------------------------------------- /trsc/src/codegen/pytorch.rs: -------------------------------------------------------------------------------- 1 | #[allow(unused_imports)] 2 | use codespan::ByteSpan; 3 | #[allow(unused_imports)] 4 | use span::CSpan; 5 | use typing::type_env::{Alias, ModName, TypeEnv}; 6 | #[allow(unused_imports)] 7 | use typing::typed_term::ArgsVecInto; 8 | #[allow(unused_imports)] 9 | use typing::typed_term::{TyDecl, TyFieldAccess, TyFnApp, TyFnAppArg, TyFnDecl, TyFnDeclParam, 10 | TyGraphDecl, TyNodeDecl, TyTerm, TyUseStmt, TyWeightsAssign, 11 | TyWeightsDecl, TyAliasAssign}; 12 | use typing::Type; 13 | use std::rc::Rc; 14 | use std::cell::RefCell; 15 | use std::fmt::Write; 16 | use std::collections::{ BTreeSet, BTreeMap, VecDeque }; 17 | use errors::{Diag, Emitter}; 18 | use core::Core; 19 | 20 | type VarName = String; 21 | type FnName = String; 22 | 23 | enum Item { 24 | FnApp(Option, FnName, Vec, bool, Option), 25 | SelfFnApp(Option, FnName, Vec), 26 | Ident(bool, String), 27 | ViewFn(Option, Type), 28 | } 29 | 30 | pub struct Module { 31 | core: Rc>, 32 | pub tenv: Rc>, 33 | pub name: String, 34 | pub ty: Type, 35 | pub fns: Rc>>, 36 | pub inits: Rc>>, 37 | pub buf: String, 38 | pub indent: usize, 39 | codegen_stack: VecDeque, 40 | } 41 | 42 | impl Module { 43 | pub fn new(tenv: Rc>, ty: Type, name: &str, core: Rc>) -> Self { 44 | Self { 45 | core, 46 | 
tenv, 47 | ty, 48 | name: name.to_owned(), 49 | buf: String::new(), 50 | fns: Rc::new(RefCell::new(BTreeMap::new())), 51 | inits: Rc::new(RefCell::new(vec![])), 52 | indent: 0, 53 | codegen_stack: VecDeque::new(), 54 | } 55 | } 56 | 57 | #[inline(always)] 58 | fn indent(&mut self) -> Result<(), Diag> { 59 | write!(self.buf, "{}", " ".repeat(self.indent*4))?; 60 | Ok(()) 61 | } 62 | #[inline(always)] 63 | fn tab(&mut self) { 64 | self.indent += 1; 65 | } 66 | #[inline(always)] 67 | fn shift_tab(&mut self) { 68 | self.indent -= 1; 69 | } 70 | 71 | pub fn set_fns(&mut self, fns: &Vec) -> Result<(), Diag> { 72 | for f in fns.iter() { 73 | self.fns.borrow_mut().insert( 74 | f.name.as_str().to_owned(), 75 | f.clone() 76 | ); 77 | } 78 | Ok(()) 79 | } 80 | 81 | pub fn generate(&mut self) -> Result<(), Diag> { 82 | self.generate_class_head()?; 83 | let fns_clone = self.fns.clone(); 84 | for (fn_name, f) in fns_clone.borrow().iter().rev() { 85 | if fn_name == "new" { 86 | self.tab(); 87 | self.generate_init_fn(f)?; 88 | self.shift_tab(); 89 | } 90 | else { 91 | self.tab(); 92 | self.generate_fn_decl(f)?; 93 | self.shift_tab(); 94 | } 95 | } 96 | Ok(()) 97 | } 98 | 99 | pub fn set_inits(&mut self, inits: &Vec) -> Result<(), Diag> { 100 | self.inits = Rc::new(RefCell::new(inits.clone())); 101 | Ok(()) 102 | } 103 | 104 | fn collect_term(&mut self, term: &TyTerm, var: Option, is_stmt: bool) -> Result<(), Diag> { 105 | use self::TyTerm::*; 106 | match term { 107 | TyBlock{stmts, ret, ..} => { 108 | self.collect_term(&stmts, var.clone(), is_stmt)?; 109 | self.collect_term(&ret, var, is_stmt)?; 110 | } 111 | TyList(terms) => terms 112 | .iter() 113 | .map(|t| self.collect_term(t, var.clone(), is_stmt)) 114 | .collect::>()?, 115 | TyExpr(t,..) => { 116 | self.collect_term(t, var, is_stmt)?; 117 | } 118 | TyFnApp(box fn_app) => { 119 | self.collect_fn_app(fn_app, var, is_stmt)?; 120 | }, 121 | TyIdent(_t,i,..) 
=> self.codegen_stack 122 | .push_back(Item::Ident(var.is_none(), i.as_str().to_owned())), 123 | TyInteger(..) => (), 124 | TyFloat(..) => (), 125 | TyStmt(t, _) => self.collect_term(t, var, true)?, 126 | TyNone => (), 127 | _ => panic!("{:#?}", term), 128 | } 129 | Ok(()) 130 | } 131 | 132 | fn generate_fn(&mut self) -> Result<(), Diag> { 133 | while let Some(item) = self.codegen_stack.pop_back() { 134 | match item { 135 | Item::FnApp(var_name, fn_name, args, is_stmt, mod_name) => { 136 | self.indent()?; 137 | let mut is_global = false; 138 | let module_name = match mod_name { 139 | None => 140 | self.tenv.borrow() 141 | .resolve_type( 142 | &ModName::Named(self.name.to_owned()), 143 | &Alias::Variable(fn_name.to_owned()), 144 | ) 145 | .unwrap_or_else(|| { 146 | is_global = true; 147 | self.tenv.borrow() 148 | .resolve_type( 149 | &ModName::Global, 150 | &Alias::Variable(fn_name.to_owned()), 151 | ).unwrap() 152 | }) 153 | .as_string(), 154 | Some(ref mn) => mn.as_str().to_owned(), 155 | }; 156 | 157 | match (var_name, is_stmt) { 158 | (Some(v), false) => write!(self.buf, "{} = ", v)?, // assignment 159 | (None, false) => write!(self.buf, "return ")?, // return 160 | (None, true) => write!(self.buf, "")?, // 161 | (Some(_), true) => write!(self.buf, "")?, // nn.init.normal_(..) 162 | }; 163 | 164 | let core_cloned = self.core.clone(); 165 | let core = core_cloned.borrow(); 166 | let op = core.find_mod(&module_name).unwrap(); 167 | let out = { 168 | if mod_name.is_some() { 169 | op.gen_fn_app(&fn_name, args.as_slice())? 170 | } else { 171 | op.gen_fn_app("forward", args.as_slice())? 172 | } 173 | }; 174 | 175 | match (is_global, is_stmt) { 176 | (true, false) => { 177 | write!(self.buf, "{}(", op.pytorch_name())?; 178 | writeln!(self.buf, "{})", out)? 179 | } 180 | (false, false) => { 181 | write!(self.buf, "self.{}(", fn_name)?; 182 | writeln!(self.buf, "{})", out)? 
183 | } 184 | (true, true) => write!(self.buf, "")?, 185 | (false, true) => { 186 | write!(self.buf, "")?; 187 | writeln!(self.buf, "{}", out)?; 188 | } 189 | } 190 | } 191 | Item::SelfFnApp(var_name, fn_name, args) => { 192 | self.indent()?; 193 | match var_name { 194 | Some(v) => write!(self.buf, "{} = ", v)?, 195 | None => write!(self.buf, "return ")?, 196 | }; 197 | write!(self.buf, "self.{}(", fn_name)?; 198 | let s = args.to_btreemap().unwrap().keys().cloned().collect::>().join(", "); 199 | write!(self.buf, "{}", s)?; 200 | writeln!(self.buf, ")")?; 201 | } 202 | Item::Ident(ret, name) => { 203 | if ret && name != "self" { 204 | self.indent()?; 205 | writeln!(self.buf, "return {}", name)?; 206 | } 207 | } 208 | Item::ViewFn(var_name, ty) => { 209 | self.indent()?; 210 | match var_name { 211 | Some(name) => { 212 | writeln!(self.buf, "{} = {}.view({})", name, name, ty.as_string())?; 213 | } 214 | None => { 215 | writeln!(self.buf, "return x.view({})", ty.as_string())?; 216 | } 217 | } 218 | } 219 | } 220 | } 221 | Ok(()) 222 | } 223 | 224 | fn collect_fn_app(&mut self, fn_app: &TyFnApp, var_name: Option, is_stmt: bool) -> Result<(), Diag> { 225 | 226 | if fn_app.mod_name == Some("view".to_owned()) { 227 | self.codegen_stack.push_back(Item::ViewFn( 228 | var_name, 229 | fn_app.ret_ty.clone(), 230 | )) 231 | } else if fn_app.name == Alias::Function("forward".to_owned()) { 232 | self.codegen_stack.push_back(Item::FnApp( 233 | var_name, 234 | fn_app.orig_name.clone().unwrap().to_owned(), 235 | fn_app.args.clone(), 236 | is_stmt, 237 | None, 238 | )); 239 | } else { 240 | if fn_app.orig_name == Some("self".to_owned()) { 241 | self.codegen_stack.push_back(Item::SelfFnApp( 242 | var_name, 243 | fn_app.name.as_str().to_owned(), 244 | fn_app.args.clone() 245 | )); 246 | } else { // init_normal 247 | let mod_ty = self.tenv.borrow().resolve_type( 248 | &ModName::Named(self.name.to_owned()), 249 | &Alias::Variable(fn_app.orig_name.clone().unwrap().as_str().to_owned()), 
250 | ).unwrap(); 251 | self.codegen_stack.push_back(Item::FnApp( 252 | fn_app.mod_name.clone(), 253 | fn_app.name.as_str().to_owned(), 254 | fn_app.args.clone(), 255 | is_stmt, 256 | Some(mod_ty.as_mod_name()), 257 | )); 258 | } 259 | } 260 | 261 | for arg in fn_app.args.iter() { 262 | self.collect_term(&arg.arg, arg.name.clone(), is_stmt)?; 263 | } 264 | Ok(()) 265 | } 266 | 267 | pub fn generate_class_head(&mut self) -> Result<(), Diag> { 268 | writeln!(self.buf, "class {}(nn.Module):", self.name)?; 269 | self.tab(); 270 | self.indent()?; 271 | writeln!(self.buf, "'''{:?}'''", self.ty)?; 272 | self.shift_tab(); 273 | Ok(()) 274 | } 275 | 276 | fn generate_fn_decl(&mut self, func: &TyFnDecl) -> Result<(), Diag> { 277 | self.generate_fn_decl_head(func.name.as_str(), func)?; 278 | self.tab(); 279 | // self.indent()?; 280 | // writeln!(self.buf, "'''{:?}'''", func.fn_ty)?; 281 | self.collect_term(&func.func_block, None, false)?; 282 | self.generate_fn()?; 283 | self.shift_tab(); 284 | Ok(()) 285 | } 286 | 287 | fn generate_fn_decl_head(&mut self, name: &str, func: &TyFnDecl) -> Result<(), Diag> { 288 | let params = func.fn_params 289 | .iter() 290 | .map(|p| format!("{}", p.name)) 291 | .collect::>(); 292 | self.indent()?; 293 | if !params.is_empty() { 294 | writeln!(self.buf, "def {}(self, {}):", name, params.join(", "))?; 295 | } else { 296 | writeln!(self.buf, "def {}(self):", name)?; 297 | } 298 | Ok(()) 299 | } 300 | 301 | fn generate_init_fn(&mut self, init_fn: &TyFnDecl) -> Result<(), Diag> { 302 | self.generate_fn_decl_head("__init__", init_fn)?; 303 | self.tab(); 304 | self.indent()?; 305 | writeln!(self.buf, "super({}, self).__init__()", self.name)?; 306 | let inits = self.inits.clone(); 307 | for init in inits.borrow().iter() { 308 | self.indent()?; 309 | write!(self.buf, "self.{} = ", init.name)?; 310 | 311 | let module_name = self.tenv.borrow() 312 | .resolve_type( 313 | &ModName::Global, 314 | &Alias::Variable(init.mod_name.as_str().to_owned()) 315 | ) 
316 | .unwrap() 317 | .as_string(); 318 | let core_cloned = self.core.clone(); 319 | let core = core_cloned.borrow(); 320 | let op = core.find_mod(&module_name).unwrap(); 321 | write!(self.buf, "{}", op.gen_fn_app(&init.fn_name, init.fn_args.as_slice())?)?; 322 | writeln!(self.buf, "")?; 323 | } 324 | 325 | let fn_new = self.fns.borrow().get("new").unwrap().clone(); 326 | self.collect_term(&fn_new.func_block, None, true)?; 327 | self.generate_fn()?; 328 | 329 | self.shift_tab(); 330 | Ok(()) 331 | } 332 | } 333 | 334 | pub struct Generator { 335 | pub emitter: Rc>, 336 | pub tenv: Rc>, 337 | pub buf: String, 338 | pub imports: BTreeSet<(String, String)>, 339 | pub modules: BTreeMap, 340 | core: Rc>, 341 | pub indent: usize, 342 | } 343 | 344 | impl Generator { 345 | pub fn new(emitter: Rc>, tenv: Rc>, core: Rc>) -> Self { 346 | Self { 347 | emitter, 348 | tenv, 349 | buf: String::new(), 350 | imports: BTreeSet::new(), 351 | modules: BTreeMap::new(), 352 | core, 353 | indent: 0, 354 | } 355 | } 356 | 357 | pub fn generate(&mut self, term: &TyTerm) -> Result<(), Diag> { 358 | self.collect(term)?; 359 | self.generate_imports()?; 360 | self.generate_modules()?; 361 | Ok(()) 362 | } 363 | 364 | fn generate_modules(&mut self) -> Result<(), Diag> { 365 | for (_name, module) in self.modules.iter_mut() { 366 | writeln!(self.buf, "")?; 367 | module.generate()?; 368 | writeln!(self.buf, "{}", module.buf)?; 369 | } 370 | Ok(()) 371 | } 372 | 373 | fn generate_imports(&mut self) -> Result<(), Diag> { 374 | writeln!(self.buf, "import torch")?; 375 | writeln!(self.buf, "from torch.autograd import Variable")?; 376 | writeln!(self.buf, "import torch.nn as nn")?; 377 | writeln!(self.buf, "import torch.nn.functional as F")?; 378 | writeln!(self.buf, "import torch.optim as optim")?; 379 | writeln!(self.buf, "")?; 380 | 381 | // writeln!(self.buf, "# import ops")?; 382 | // for (path_name, mod_name) in self.imports.iter() { 383 | // let import_stmt = 
self.core.borrow().pytorch_name(path_name, mod_name); 384 | // writeln!(self.buf, "import {} as {}", import_stmt.unwrap(), mod_name)?; 385 | // } 386 | 387 | Ok(()) 388 | } 389 | 390 | fn collect(&mut self, term: &TyTerm) -> Result<(), Diag> { 391 | use self::TyTerm::*; 392 | match term { 393 | TyProgram(decls) => decls 394 | .iter() 395 | .map(|d| self.collect_decl(&d)) 396 | .collect::>()?, 397 | _ => unimplemented!(), 398 | } 399 | Ok(()) 400 | } 401 | 402 | fn collect_decl(&mut self, decl: &TyDecl) -> Result<(), Diag> { 403 | use self::TyDecl::*; 404 | match decl { 405 | TyUseStmt(stmt) => { 406 | for name in stmt.imported_names.iter() { 407 | self.imports.insert( 408 | (stmt.mod_name.to_owned(), 409 | name.to_owned() 410 | ) 411 | ); 412 | } 413 | }, 414 | TyNodeDecl(decl) => { 415 | let m = Module::new(self.tenv.clone(), decl.ty_sig.clone(), &decl.name, self.core.clone()); 416 | self.modules.insert(decl.name.to_owned(), m); 417 | } 418 | TyWeightsDecl(decl) => { 419 | let mut m = self.modules.get_mut(&decl.name).unwrap(); 420 | m.set_inits(&decl.inits)?; 421 | } 422 | TyGraphDecl(decl) => { 423 | let mut m = self.modules.get_mut(&decl.name).unwrap(); 424 | m.set_fns(&decl.fns)?; 425 | } 426 | _ => (), 427 | } 428 | Ok(()) 429 | } 430 | 431 | } 432 | 433 | impl From<::std::fmt::Error> for Diag { 434 | fn from(_error: ::std::fmt::Error) -> Diag { 435 | Diag::UnknownError 436 | } 437 | } -------------------------------------------------------------------------------- /trsc/src/core/conv.rs: -------------------------------------------------------------------------------- 1 | use core::{MethodName, Op, PyTorch, Resolve}; 2 | use errors::Diag; 3 | use span::CSpan; 4 | use typing::typed_term::{ArgsVecInto, TyFnAppArg, TyTerm}; 5 | use typing::{Type, TypeEnv}; 6 | 7 | use std::fmt::Write; 8 | 9 | use self::TyTerm::*; 10 | 11 | 12 | macro_rules! 
read_2_tuple { 13 | ($var:expr) => { 14 | if let box TyExpr(box TyTuple(_,vs,_),_,_) = $var { // TyExpr 15 | (vs[0].clone().as_num()?, vs[1].clone().as_num()?) 16 | } else { 17 | panic!("{:#?}", $var); 18 | } 19 | }; 20 | } 21 | 22 | macro_rules! read_from_init { 23 | ($var:expr, $default:expr) => { 24 | $var 25 | .map(|t| (t, t.ty()) ) 26 | .and_then(|(t, ty)| 27 | if let Type::Tuple(..) = ty { 28 | Some(read_2_tuple!(t)) 29 | } else { 30 | let p0 = t.as_num()?; 31 | Some((p0, p0)) 32 | } 33 | ).unwrap_or($default) 34 | }; 35 | } 36 | 37 | #[derive(Debug, Op)] 38 | #[path = "conv"] 39 | #[forward = "?() -> unit"] 40 | #[new = "?() -> unit"] 41 | #[stateful] 42 | pub struct Conv2d; 43 | 44 | impl Resolve for Conv2d { 45 | fn resolve( &self, 46 | tenv: &mut TypeEnv, 47 | fn_name: &str, 48 | arg_ty: Type, 49 | _ret_ty: Type, 50 | _args: Vec, 51 | inits: Option> 52 | ) -> Option> { 53 | match fn_name { 54 | "forward" => { 55 | let forward_args = arg_ty.as_args_map()?; 56 | let x_ty = &forward_args["x"]; 57 | if !x_ty.is_resolved() { 58 | None 59 | } else { 60 | let init_map = inits?.to_btreemap()?; 61 | let (k0, k1) = read_from_init!(init_map.get("kernel_size"), (0, 0)); 62 | let (p0, p1) = read_from_init!(init_map.get("padding"), (0, 0)); 63 | let (d0, d1) = read_from_init!(init_map.get("dilation"), (1, 1)); 64 | let (s0, s1) = read_from_init!(init_map.get("stride"), (1, 1)); 65 | 66 | 67 | let in_ch = init_map.get("in_ch").map(|t|t.as_num().unwrap()).expect("does not have in_ch"); 68 | let out_ch = init_map.get("out_ch").map(|t|t.as_num().unwrap()).expect("does not have in_ch"); 69 | 70 | let dims = x_ty.as_vec()?; 71 | let (n, c_in, h_in, w_in) = ( 72 | dims[0].to_owned(), 73 | dims[1].to_owned().as_num().unwrap(), 74 | dims[2].to_owned().as_num().unwrap(), 75 | dims[3].to_owned().as_num().unwrap() 76 | ); 77 | 78 | assert_eq!(c_in, in_ch); 79 | // println!("BLAH: {:?}", x_ty); 80 | let h_out = (h_in + 2 * p0 - d0 * (k0 -1) - 1) / s0 + 1; 81 | let w_out = (w_in + 2 
* p1 - d1 * (k1 -1) - 1) / s1 + 1; 82 | 83 | let span = x_ty.span(); 84 | 85 | Some(Ok( // returns a function 86 | fun!( 87 | "Conv2d", 88 | "forward", 89 | arg_ty, 90 | Type::TSR(vec![ 91 | n, 92 | Type::ResolvedDim(out_ch, span), 93 | Type::ResolvedDim(h_out, span), 94 | Type::ResolvedDim(w_out, span), 95 | ], span) 96 | ) 97 | )) 98 | } 99 | }, 100 | "new" => { 101 | Some(Ok(fun!( 102 | "Conv2d", 103 | "new", 104 | args!( 105 | arg!("in_ch", int!()), 106 | arg!("out_ch", int!()), 107 | arg!("kernel_size", tenv.fresh_var(CSpan::fresh_span())) 108 | ), 109 | module!("Conv2d") 110 | ))) 111 | } 112 | _ => unimplemented!(), 113 | } 114 | } 115 | 116 | } 117 | 118 | impl PyTorch for Conv2d { // "new" -> nn.Conv2d(in_channels=.., out_channels=.., kernel_size=..); "forward" -> comma-joined argument names 119 | 120 | fn pytorch_name(&self) -> &'static str { 121 | "nn.Conv2d" 122 | } 123 | 124 | fn gen_fn_app(&self, name: &str, args: &[TyFnAppArg]) -> Result { 125 | let mut buf = String::new(); 126 | match name { 127 | "new" => { 128 | write!(buf, "{}(", self.pytorch_name()).unwrap(); 129 | let map = args.to_btreemap().unwrap(); 130 | write!(buf, "in_channels={}, ", map["in_ch"].as_str().unwrap()).unwrap(); 131 | write!(buf, "out_channels={}, ", map["out_ch"].as_str().unwrap()).unwrap(); 132 | write!(buf, "kernel_size={})", map["kernel_size"].as_str().unwrap()).unwrap(); 133 | Ok(buf) 134 | } 135 | "forward" => { 136 | let args: Vec<_> = args.iter().map(|i| i.name.clone().unwrap()).collect(); 137 | write!(buf, "{}", args.join(", ")).unwrap(); 138 | Ok(buf) 139 | } 140 | _ => panic!("{} is not implemented", name), 141 | } 142 | } 143 | 144 | } 145 | 146 | #[allow(non_camel_case_types)] 147 | #[derive(Debug, Op)] 148 | #[path = "conv"] 149 | #[forward = "?() -> unit"] 150 | pub struct maxpool2d; // 2-D max pooling: declares only a forward signature and is generated as F.max_pool2d 151 | 152 | 153 | impl Resolve for maxpool2d { // forward: once x's shape is resolved, returns [n, c_in, h_out, w_out] computed from the call arguments 154 | fn resolve( 155 | &self, 156 | _tenv: &mut TypeEnv, 157 | fn_name: &str, 158 | arg_ty: Type, 159 | _ret_ty: Type, 160 | args: Vec, 161 | _inits: Option>, 162 | ) -> Option> { 163 | match fn_name { 164 | "forward" => { 165 | let args_ty_map
= arg_ty.as_args_map()?; 166 | let x_ty = args_ty_map.get("x").expect("No x argument"); 167 | let args_map = args.to_btreemap()?; 168 | 169 | if !x_ty.is_resolved() { 170 | None 171 | } else { 172 | let (k0, k1) = read_from_init!(args_map.get("kernel_size"), (0, 0)); // NOTE(review): if kernel_size is absent and this falls back to (0, 0), the default stride below is also 0 and h_out/w_out divide by zero -- confirm read_from_init! semantics 173 | let (p0, p1) = read_from_init!(args_map.get("padding"), (0, 0)); 174 | let (d0, d1) = read_from_init!(args_map.get("dilation"), (1, 1)); 175 | let (s0, s1) = read_from_init!(args_map.get("stride"), (k0, k1)); 176 | 177 | let dims = x_ty.as_vec()?; 178 | let (n, c_in, h_in, w_in) = ( 179 | dims[0].to_owned(), 180 | dims[1].to_owned(), 181 | dims[2].to_owned().as_num().unwrap(), 182 | dims[3].to_owned().as_num().unwrap() 183 | ); 184 | // println!("BLAH: {:?}", x_ty); 185 | let h_out = (h_in + 2 * p0 - d0 * (k0 -1) - 1) / s0 + 1; 186 | let w_out = (w_in + 2 * p1 - d1 * (k1 -1) - 1) / s1 + 1; 187 | 188 | let span = x_ty.span(); 189 | 190 | Some(Ok( // returns a function 191 | fun!( 192 | "maxpool2d", 193 | "forward", 194 | arg_ty, 195 | Type::TSR(vec![ 196 | n, 197 | c_in.clone(), 198 | Type::ResolvedDim(h_out, span), 199 | Type::ResolvedDim(w_out, span), 200 | ], span) 201 | ) 202 | )) 203 | } 204 | }, 205 | _ => None, 206 | } 207 | } 208 | 209 | } 210 | 211 | impl PyTorch for maxpool2d { 212 | 213 | fn pytorch_name(&self) -> &'static str { 214 | "F.max_pool2d" 215 | } 216 | 217 | fn gen_fn_app(&self, name: &str, args: &[TyFnAppArg]) -> Result { 218 | let mut buf = String::new(); 219 | match name { // no "new" arm: maxpool2d has no constructor signature, and F.max_pool2d takes no in_channels/out_channels 220 | "forward" => { 221 | let args: Vec<_> = args 222 | .iter() 223 | .map(|i| i.name.clone().unwrap()) 224 | .collect();
225 | write!(buf, "{}", args.join(", ")).unwrap(); 226 | Ok(buf) 227 | } 228 | _ => panic!("{} is not implemented", name), 229 | } 230 | } 231 | } -------------------------------------------------------------------------------- /trsc/src/core/lin.rs: -------------------------------------------------------------------------------- 1 | use core::{MethodName, Op, PyTorch, Resolve}; 2 | use errors::Diag; 3 | use span::CSpan; 4 | use typing::typed_term::{ArgsVecInto, TyFnAppArg, TyTerm}; 5 | use typing::{Type, TypeEnv}; 6 | use std::fmt::Write; 7 | 8 | #[derive(Debug, Op)] 9 | #[path = "lin"] 10 | #[new = "(in: int, out: int) -> self"] 11 | #[forward = "?(x: tsr0) -> tsr0"] 12 | #[init_normal = "(std: float) -> unit"] 13 | #[stateful] 14 | pub struct Linear; // fully connected layer, generated as nn.Linear 15 | 16 | impl Resolve for Linear { 17 | /// forward: output keeps the input's shape except the last dim, which becomes `out` (the input's last dim is pinned to `in`) 18 | fn resolve( 19 | &self, 20 | _tenv: &mut TypeEnv, 21 | fn_name: &str, 22 | arg_ty: Type, 23 | ret_ty: Type, 24 | _args: Vec, 25 | inits: Option>, // ... refactor into span error 26 | ) -> Option> { 27 | match fn_name { 28 | "forward" => { 29 | if let Some(inits) = inits { // "in"/"out" come from the module's init arguments; without them nothing can be resolved 30 | let hm = inits.to_btreemap().unwrap(); 31 | if !hm.contains_key("in") { 32 | panic!("Initatialize Linear with parameter in="); 33 | } else if !hm.contains_key("out") { 34 | panic!("Initatialize Linear with parameter out="); 35 | } 36 | 37 | let in_dim = hm.get("in").and_then(|t| unwrap_dim(t))?; 38 | let out_dim = hm.get("out").and_then(|t| unwrap_dim(t))?; 39 | 40 | let span = arg_ty.span(); 41 | 42 | let (a, b) = match (arg_ty.first_arg_ty()?.as_vec(), ret_ty.as_vec()) { 43 | (None, None) => return None, 44 | (Some(ref mut a), None) | 45 | (None, Some(ref mut a)) => { 46 | // modify the last dimension 47 | let mut b = a.clone(); 48 | { 49 | let mut last_arg_dim = a.last_mut().unwrap(); 50 | let mut last_ret_dim = b.last_mut().unwrap(); 51 | *last_arg_dim = Type::ResolvedDim(in_dim, CSpan::fresh_span()); 52 | *last_ret_dim = Type::ResolvedDim(out_dim,
CSpan::fresh_span()); 53 | }; 54 | 55 | (a.clone(), b) 56 | } 57 | (Some(ref mut a), Some(ref mut b)) => { 58 | if a.len() != b.len() { 59 | // return dimension mismatch 60 | return Some(Err(Diag::TypeError(arg_ty, ret_ty))); 61 | } 62 | // modify the last dimension 63 | { 64 | let mut last_arg_dim = a.last_mut().unwrap(); 65 | let mut last_ret_dim = b.last_mut().unwrap(); 66 | *last_arg_dim = Type::ResolvedDim(in_dim, CSpan::fresh_span()); 67 | *last_ret_dim = Type::ResolvedDim(out_dim, CSpan::fresh_span()); 68 | }; 69 | 70 | (a.clone(), b.clone()) 71 | } 72 | }; 73 | 74 | Some(Ok(fun!( 75 | self.get_name(), 76 | "forward", 77 | args!(arg!("x",Type::TSR(a, span))), 78 | Type::TSR(b, span) 79 | ))) 80 | } else { 81 | None 82 | } 83 | } 84 | _ => unimplemented!(), 85 | } 86 | } 87 | 88 | } 89 | impl PyTorch for Linear { 90 | 91 | fn pytorch_name(&self) -> &'static str { 92 | "nn.Linear" 93 | } 94 | 95 | fn gen_fn_app(&self, name: &str, args: &[TyFnAppArg]) -> Result { 96 | let mut buf = String::new(); 97 | match name { 98 | "new" => { 99 | let map = args.to_btreemap().unwrap(); 100 | write!(buf, "{}(", self.pytorch_name()).unwrap(); 101 | write!(buf, "in_features={:?}, ", map["in"].as_num().unwrap()).unwrap(); 102 | write!(buf, "out_features={:?})", map["out"].as_num().unwrap()).unwrap(); 103 | Ok(buf) 104 | } 105 | "forward" => { 106 | let args: Vec<_> = args.iter().map(|i| i.name.clone().unwrap()).collect(); 107 | write!(buf, "{}", args.join(", ")).unwrap(); 108 | Ok(buf) 109 | } 110 | "init_normal" => { 111 | let map = args.to_btreemap().unwrap(); 112 | write!(buf, "nn.init.normal_(").unwrap(); 113 | write!(buf, "std={}", map["std"].as_float().unwrap()).unwrap(); 114 | write!(buf, ")").unwrap(); 115 | Ok(buf) 116 | } 117 | _ => panic!("{} is not implemented", name), 118 | } 119 | } 120 | } 121 | 122 | fn unwrap_dim(in_dim: &TyTerm) -> Option { // reads an integer dimension from an INT literal or a ResolvedDim term; panics on any other type 123 | match in_dim.ty() { 124 | Type::INT(_) => in_dim.as_num(), 125 | Type::ResolvedDim(num, _) => Some(num), 126 | _ =>
panic!("{:?} is not a numeric value!", in_dim), 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /trsc/src/core/mod.rs: -------------------------------------------------------------------------------- 1 | use typing::typed_term::TyFnAppArg; 2 | use errors::Diag; 3 | use typing::{Type, TypeEnv}; 4 | use std::collections::HashMap; 5 | use std::fmt::Debug; 6 | 7 | mod prelude; 8 | mod conv; 9 | mod lin; 10 | mod reg; 11 | mod nonlin; 12 | 13 | pub trait Op: PyTorch + Resolve { // a built-in op: its name, per-method type signatures, statefulness, plus type resolution and PyTorch codegen 14 | fn get_name(&self) -> &'static str; 15 | 16 | fn ty_sigs(&self, tenv: &mut TypeEnv) -> Vec<(MethodName, Type)>; 17 | 18 | fn is_stateful(&self) -> bool; 19 | } 20 | 21 | pub trait Resolve { // refines a method's type for a concrete call; the ops above return None while there is not yet enough information 22 | fn resolve( 23 | &self, 24 | _tenv: &mut TypeEnv, 25 | fn_name: &str, 26 | _arg_ty: Type, 27 | _ret_ty: Type, 28 | _args: Vec, 29 | _inits: Option>, 30 | ) -> Option> { 31 | panic!("{} is not yet implemented", fn_name); 32 | } 33 | } 34 | 35 | pub trait PyTorch: Debug { // PyTorch code generation: pytorch_name is the emitted callee, gen_fn_app renders one method call 36 | fn pytorch_name(&self) -> &'static str; 37 | fn gen_fn_app(&self, name: &str, _args: &[TyFnAppArg]) -> Result { 38 | panic!("{:?}::{} function call is not yet implemented", self, name); 39 | // unimplemented!() 40 | } 41 | } 42 | 43 | #[derive(Debug)] 44 | pub struct Core { // registry of built-in ops: path (e.g. "conv") -> op name (e.g. "Conv2d") -> boxed op 45 | maps: HashMap<&'static str, HashMap<&'static str, Box>>, 46 | } 47 | 48 | pub type MethodName = &'static str; 49 | 50 | impl Core { 51 | pub fn new() -> Self { 52 | let maps = hashmap! { 53 | "conv" => hashmap! { 54 | "Conv2d" => box self::conv::Conv2d as Box, 55 | "maxpool2d" => box self::conv::maxpool2d as Box, 56 | }, 57 | "nonlin" => hashmap! { 58 | "relu" => box self::nonlin::relu as Box, 59 | "tanh" => box self::nonlin::tanh as Box, 60 | "leaky_relu" => box self::nonlin::leaky_relu as Box, 61 | "log_softmax" => box self::nonlin::log_softmax as Box, 62 | "sigmoid" => box self::nonlin::sigmoid as Box, 63 | }, 64 | "lin" => hashmap!
{ 65 | "Linear" => box self::lin::Linear as Box, 66 | }, 67 | "prelude" => hashmap! { 68 | "view" => box self::prelude::view as Box, 69 | }, 70 | "reg" => hashmap! { 71 | "Dropout2d" => box self::reg::Dropout2d as Box, 72 | "BatchNorm1d" => box self::reg::BatchNorm1d as Box, 73 | } 74 | }; 75 | Self { 76 | maps, 77 | } 78 | } 79 | pub fn import(&self, path_name: &str, mod_name: &str, tenv: &mut TypeEnv) -> Option> { // type signatures of `mod_name` under `path_name`, or None if no such op is registered 80 | let op = self.find(path_name, mod_name)?; 81 | Some(op.ty_sigs(tenv)) 82 | } 83 | 84 | pub fn find(&self, path_name: &str, mod_name: &str) -> Option<&Box> { // exact lookup by path and op name 85 | let ret = self.maps.get(path_name)?.get(mod_name)?; 86 | Some(ret) 87 | } 88 | 89 | pub fn find_mod(&self, mod_name:&str) -> Option<&Box> { 90 | // Looks `mod_name` up under every path and returns the first hit; HashMap iteration 91 | // order is unspecified, so if two paths register the same name the winner is arbitrary. 92 | self.maps 93 | .values() 94 | .filter_map(|m| m.get(mod_name)) 95 | .next() 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /trsc/src/core/nonlin.rs: -------------------------------------------------------------------------------- 1 | use self::Type::*; 2 | use core::{MethodName, Op, PyTorch, Resolve}; 3 | use std::fmt::Write; 4 | use span::CSpan; 5 | use typing::typed_term::TyFnAppArg; 6 | use typing::{Type, TypeEnv}; 7 | use errors::Diag; 8 | 9 | #[allow(non_camel_case_types)] 10 | #[derive(Debug, Op)] 11 | #[path = "nonlin"] 12 | #[forward = "?()"] 13 | pub struct sigmoid; // shape-preserving: forward's result type is x's type 14 | 15 | impl Resolve for sigmoid { 16 | fn resolve( 17 | &self, 18 | tenv: &mut TypeEnv, 19 | fn_name: &str, 20 | _arg_ty: Type, 21 | _ret_ty: Type, 22 | _args: Vec, 23 | _inits: Option>, 24 | ) -> Option> { 25 | match fn_name { 26 | "forward" => { 27 | let ty = tenv.fresh_var(CSpan::fresh_span()); 28 | Some(Ok(fun!(self.get_name(), "forward", args!(arg!("x", ty.clone())), ty))) 29 | } 30 | _ => unimplemented!(), 31 | } 32 | } 33 | } 34 | 35 | impl PyTorch for sigmoid { 36 | fn pytorch_name(&self) -> &'static str { 37 | "F.sigmoid" 38 | } 39 | fn gen_fn_app(&self,
name: &str, _args: &[TyFnAppArg]) -> Result { 40 | let mut buf = String::new(); 41 | match name { 42 | "forward" => { 43 | write!(buf, "x").unwrap(); 44 | Ok(buf) 45 | } 46 | _ => panic!("{} is not implemented", name), 47 | } 48 | } 49 | } 50 | 51 | #[allow(non_camel_case_types)] 52 | #[derive(Debug, Op)] 53 | #[path = "nonlin"] 54 | #[forward = "?()"] 55 | pub struct tanh; // shape-preserving: forward's result type is x's type 56 | 57 | impl Resolve for tanh { 58 | fn resolve( 59 | &self, 60 | tenv: &mut TypeEnv, 61 | fn_name: &str, 62 | _arg_ty: Type, 63 | _ret_ty: Type, 64 | _args: Vec, 65 | _inits: Option>, 66 | ) -> Option> { 67 | match fn_name { 68 | "forward" => { 69 | let ty = tenv.fresh_var(CSpan::fresh_span()); 70 | Some(Ok(fun!(self.get_name(), "forward", args!(arg!("x", ty.clone())), ty))) 71 | } 72 | _ => unimplemented!(), 73 | } 74 | } 75 | } 76 | 77 | impl PyTorch for tanh { 78 | fn pytorch_name(&self) -> &'static str { 79 | "F.tanh" 80 | } 81 | fn gen_fn_app(&self, name: &str, _args: &[TyFnAppArg]) -> Result { 82 | let mut buf = String::new(); 83 | match name { 84 | "forward" => { 85 | write!(buf, "x").unwrap(); 86 | Ok(buf) 87 | } 88 | _ => panic!("{} is not implemented", name), 89 | } 90 | } 91 | } 92 | 93 | #[allow(non_camel_case_types)] 94 | #[derive(Debug, Op)] 95 | #[path = "nonlin"] 96 | #[forward = "?()"] 97 | pub struct relu; // shape-preserving: forward's result type is x's type 98 | 99 | impl Resolve for relu { 100 | fn resolve( 101 | &self, 102 | tenv: &mut TypeEnv, 103 | fn_name: &str, 104 | _arg_ty: Type, 105 | _ret_ty: Type, 106 | _args: Vec, 107 | _inits: Option>, 108 | ) -> Option> { 109 | match fn_name { 110 | "forward" => { 111 | let ty = tenv.fresh_var(CSpan::fresh_span()); 112 | Some(Ok(fun!(self.get_name(), "forward", args!(arg!("x", ty.clone())), ty))) 113 | } 114 | _ => unimplemented!(), 115 | } 116 | } 117 | } 118 | 119 | impl PyTorch for relu { 120 | fn pytorch_name(&self) -> &'static str { 121 | "F.relu" 122 | } 123 | fn gen_fn_app(&self, name: &str, _args: &[TyFnAppArg]) -> Result { 124 | let mut buf = String::new(); 125 | match
name { 126 | "forward" => { 127 | write!(buf, "x").unwrap(); 128 | Ok(buf) 129 | } 130 | _ => panic!("{} is not implemented", name), 131 | } 132 | } 133 | } 134 | 135 | #[allow(non_camel_case_types)] 136 | #[derive(Debug, Op)] 137 | #[path = "nonlin"] 138 | #[forward = "?()"] 139 | pub struct leaky_relu; // shape-preserving: forward's result type is x's type 140 | 141 | impl Resolve for leaky_relu { 142 | fn resolve( 143 | &self, 144 | tenv: &mut TypeEnv, 145 | fn_name: &str, 146 | _arg_ty: Type, 147 | _ret_ty: Type, 148 | _args: Vec, 149 | _inits: Option>, 150 | ) -> Option> { 151 | match fn_name { 152 | "forward" => { 153 | let ty = tenv.fresh_var(CSpan::fresh_span()); 154 | Some(Ok(fun!(self.get_name(), "forward", args!(arg!("x", ty.clone())), ty))) 155 | } 156 | _ => unimplemented!(), 157 | } 158 | } 159 | } 160 | 161 | impl PyTorch for leaky_relu { 162 | fn pytorch_name(&self) -> &'static str { 163 | "F.leaky_relu" 164 | } 165 | fn gen_fn_app(&self, name: &str, _args: &[TyFnAppArg]) -> Result { 166 | let mut buf = String::new(); 167 | match name { 168 | "forward" => { 169 | write!(buf, "x").unwrap(); 170 | Ok(buf) 171 | } 172 | _ => panic!("{} is not implemented", name), 173 | } 174 | } 175 | } 176 | 177 | #[allow(non_camel_case_types)] 178 | #[derive(Debug, Op)] 179 | #[path = "nonlin"] 180 | #[forward = "?()"] 181 | pub struct log_softmax; // forward(x, dim: int): result type is x's type 182 | 183 | impl Resolve for log_softmax { 184 | fn resolve( 185 | &self, 186 | tenv: &mut TypeEnv, 187 | fn_name: &str, 188 | _arg_ty: Type, 189 | _ret_ty: Type, 190 | _args: Vec, 191 | _inits: Option>, 192 | ) -> Option> { 193 | match fn_name { 194 | "forward" => { 195 | let ty = tenv.fresh_var(CSpan::fresh_span()); 196 | Some(Ok(fun!(self.get_name(), "forward", args!(arg!("x", ty.clone()), arg!("dim", int!())), ty))) 197 | } 198 | _ => unimplemented!(), 199 | } 200 | } 201 | } 202 | 203 | impl PyTorch for log_softmax { 204 | fn pytorch_name(&self) -> &'static str { 205 | "F.log_softmax" 206 | } 207 | fn gen_fn_app(&self, name: &str, args: &[TyFnAppArg]) -> Result { 208 |
let mut buf = String::new(); 209 | match name { 210 | "forward" => { 211 | let args: Vec<_> = args.iter().map(|i| i.name.clone().unwrap()).collect(); 212 | write!(buf, "{}", args.join(", ")).unwrap(); 213 | Ok(buf) 214 | } 215 | _ => panic!("{} is not implemented", name), 216 | } 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /trsc/src/core/prelude.rs: -------------------------------------------------------------------------------- 1 | use core::{MethodName, Op, PyTorch, Resolve}; 2 | use errors::Diag; 3 | use span::CSpan; 4 | use typing::typed_term::TyFnAppArg; 5 | use typing::{Type, TypeEnv}; 6 | 7 | #[allow(non_camel_case_types)] 8 | #[derive(Debug, Op)] 9 | #[path = "prelude"] 10 | #[forward = "?() -> unit"] 11 | pub struct view; // tensor reshape; only type resolution is implemented here (pytorch_name is unimplemented) 12 | 13 | impl Resolve for view { 14 | /// forward: reshape x; when the products of the resolved dims agree, the single elided return dim takes x's single unresolved dim 15 | fn resolve( 16 | &self, 17 | _tenv: &mut TypeEnv, 18 | fn_name: &str, 19 | arg_ty: Type, 20 | ret_ty: Type, 21 | _args: Vec, 22 | _inits: Option>, // ...
refactor into span error 23 | ) -> Option> { 24 | match fn_name { 25 | "forward" => { 26 | // println!("ret_ty: {:#?}\n, arg_ty: {:#?}", ret_ty, arg_ty); 27 | if !arg_ty.is_resolved() { return None; } 28 | let args_map = arg_ty.as_args_map()?; 29 | let arg_tsr = args_map.get("x")?.as_vec()?; 30 | let ret_tsr = ret_ty.as_vec()?; 31 | 32 | let resolved_arg_tsr: Vec = arg_tsr.iter().filter_map(|i| i.as_num()).collect(); 33 | let resolved_ret_tsr: Vec = ret_tsr.iter().filter_map(|i| i.as_num()).collect(); 34 | 35 | // suppose arg_ty = [!1, 10] 36 | // ret_ty = ['100, 2, 5] 37 | // replace '100 with !1 38 | if ret_tsr.len() - resolved_ret_tsr.len() > 1 { 39 | return Some(Err( 40 | Diag::EllisionError("Cannot elide more than 1 tensor dimension in view function".to_owned(), ret_tsr[0].span()) 41 | )); 42 | } 43 | 44 | let ret_prod: i64 = resolved_ret_tsr.iter().product(); 45 | let arg_prod: i64 = resolved_arg_tsr.iter().product(); 46 | 47 | let is_only_one_arg_dim_unresolved = (arg_tsr.len() - resolved_arg_tsr.len()) == 1; 48 | if ret_prod == arg_prod && is_only_one_arg_dim_unresolved { 49 | let unresolved_arg_dim = arg_tsr.iter().find(|i| i.as_num().is_none()).unwrap(); 50 | let unresolved_ret_dim = ret_tsr.iter().find(|i| i.as_num().is_none()).unwrap(); // NOTE(review): panics when every return dim is already resolved 51 | let modified_ret_ty = ret_tsr 52 | .iter() 53 | .map(|i| 54 | if i == unresolved_ret_dim { 55 | unresolved_arg_dim 56 | } else { 57 | i 58 | } ) 59 | .cloned() 60 | .collect(); 61 | Some(Ok( 62 | fun!("view", "forward", arg_ty, tsr!(modified_ret_ty)) 63 | )) 64 | } else { 65 | panic!("{} {}", ret_prod, arg_prod);// ... TODO: report a Diag instead of panicking on a product mismatch
66 | } 67 | } 68 | _ => unimplemented!(), 69 | } 70 | } 71 | } 72 | 73 | impl PyTorch for view { 74 | fn pytorch_name(&self) -> &'static str { 75 | unimplemented!(); 76 | } 77 | } -------------------------------------------------------------------------------- /trsc/src/core/reg.rs: -------------------------------------------------------------------------------- 1 | use core::{MethodName, Op, PyTorch, Resolve}; 2 | use errors::Diag; 3 | use span::CSpan; 4 | use typing::typed_term::TyFnAppArg; 5 | use typing::{Type, TypeEnv}; 6 | use typing::typed_term::ArgsVecInto; 7 | use std::fmt::Write; 8 | 9 | #[derive(Debug, Op)] 10 | #[path = "reg"] 11 | #[new = "(p: float) -> self"] 12 | #[forward = "?(x: tsr0) -> tsr0"] 13 | pub struct Dropout2d; // generated as nn.Dropout2d(p=..); forward is shape-preserving 14 | 15 | impl Resolve for Dropout2d { 16 | fn resolve( 17 | &self, 18 | tenv: &mut TypeEnv, 19 | fn_name: &str, 20 | _arg_ty: Type, 21 | _ret_ty: Type, 22 | _args: Vec, 23 | _inits: Option>, 24 | ) -> Option> { 25 | match fn_name { 26 | "forward" => { 27 | let ty = tenv.fresh_var(CSpan::fresh_span()); 28 | Some(Ok(fun!(self.get_name(), "forward", args!(arg!("x", ty.clone())), ty))) 29 | } 30 | _ => unimplemented!(), 31 | } 32 | } 33 | } 34 | 35 | impl PyTorch for Dropout2d { 36 | 37 | fn pytorch_name(&self) -> &'static str { 38 | "nn.Dropout2d" 39 | } 40 | 41 | fn gen_fn_app(&self, name: &str, args: &[TyFnAppArg]) -> Result { 42 | let mut buf = String::new(); 43 | match name { 44 | "new" => { 45 | write!(buf, "{}(", self.pytorch_name()).unwrap(); // emit the qualified "nn.Dropout2d(" like the other ops, not the bare struct name 46 | let map = args.to_btreemap().unwrap(); 47 | write!(buf, "p={})", map["p"].as_str().unwrap()).unwrap(); 48 | Ok(buf) 49 | } 50 | "forward" => { 51 | let args: Vec<_> = args.iter().map(|i| i.name.clone().unwrap()).collect(); 52 | write!(buf, "{}", args.join(", ")).unwrap(); 53 | Ok(buf) 54 | } 55 | _ => panic!("{} is not implemented", name), 56 | } 57 | } 58 | } 59 | 60 | #[derive(Debug, Op)] 61 | #[path = "reg"] 62 | #[new = "(num_features: int) -> self"] 63 | #[forward = "?(x: tsr0) ->
tsr0"] 64 | #[stateful] 65 | pub struct BatchNorm1d; // generated as nn.BatchNorm1d(num_features=..); forward is shape-preserving 66 | 67 | impl Resolve for BatchNorm1d { 68 | fn resolve( 69 | &self, 70 | tenv: &mut TypeEnv, 71 | fn_name: &str, 72 | _arg_ty: Type, 73 | _ret_ty: Type, 74 | _args: Vec, 75 | _inits: Option>, 76 | ) -> Option> { 77 | match fn_name { 78 | "forward" => { 79 | let ty = tenv.fresh_var(CSpan::fresh_span()); 80 | Some(Ok(fun!(self.get_name(), "forward", args!(arg!("x", ty.clone())), ty))) 81 | } 82 | _ => unimplemented!(), 83 | } 84 | } 85 | } 86 | 87 | impl PyTorch for BatchNorm1d { 88 | fn pytorch_name(&self) -> &'static str { 89 | "nn.BatchNorm1d" 90 | } 91 | fn gen_fn_app(&self, name: &str, args: &[TyFnAppArg]) -> Result { 92 | let mut buf = String::new(); 93 | match name { 94 | "new" => { 95 | let map = args.to_btreemap().unwrap(); 96 | write!(buf, "{}(", self.pytorch_name()).unwrap(); 97 | write!(buf, "num_features={})", 98 | map["num_features"].as_num().unwrap()).unwrap(); 99 | } 100 | "forward" => { 101 | // let map = args.to_btreemap().unwrap(); 102 | write!(buf, "x").unwrap(); 103 | } 104 | _ => unimplemented!(), 105 | } 106 | 107 | Ok(buf) 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /trsc/src/errors/diagnostic.rs: -------------------------------------------------------------------------------- 1 | use codespan_reporting::{Diagnostic, Label, Severity}; 2 | use typing::Type; 3 | use codespan::CodeMap; 4 | use codespan::{ByteSpan, LineIndex}; 5 | 6 | #[derive(Debug, Clone)] 7 | pub enum Diag { // compiler diagnostics; as_diagnostic renders them for codespan_reporting 8 | UnknownError, 9 | RankMismatch(Type, Type), 10 | DimensionMismatch(Type, Type), 11 | ParseError(String, ByteSpan), 12 | SymbolNotFound(String, ByteSpan), 13 | ImportError(String, ByteSpan), 14 | DuplicateVarInScope(String, Type, Type), 15 | TypeError(Type, Type), 16 | EllisionError(String, ByteSpan), 17 | } 18 | 19 | impl Diag { 20 | pub fn as_diagnostic(&self, code_map: &CodeMap) -> Diagnostic { // every variant becomes an Error-severity Diagnostic labelled with its spans 21 | use self::Diag::*; 22 | match self { 23 |
DimensionMismatch(Type::ResolvedDim(v1, s1), Type::ResolvedDim(v2,s2)) => { 24 | Diagnostic::new( 25 | Severity::Error, 26 | format!("Dimension mismatch: {} != {}", v1, v2), 27 | ) 28 | .with_label(Label::new_primary(*s1)) 29 | .with_label(Label::new_primary(*s2)) 30 | } 31 | 32 | RankMismatch(Type::TSR(dims1, s1), Type::TSR(dims2, s2)) => { 33 | Diagnostic::new( 34 | Severity::Error, 35 | format!("Tensor rank mismatch: rank({:?}) != rank({:?})", dims1, dims2), 36 | ) 37 | .with_label(Label::new_primary(*s1)) 38 | .with_label(Label::new_primary(*s2)) 39 | }, 40 | 41 | ParseError(msg, sp) => { 42 | // Since error points to the next line, also print the line before 43 | // (or the same line again when the error is on the first line). 44 | let idx = sp.start(); 45 | let file = code_map.find_file(idx).unwrap(); 46 | let line = file.find_line(idx).unwrap(); 47 | let prev_line = line.to_usize().saturating_sub(1); 48 | let prev_line_span = file.line_span(LineIndex(prev_line as u32)).unwrap(); 49 | Diagnostic::new( 50 | Severity::Error, 51 | format!("{} on line {}:", msg, prev_line + 1), 52 | ) 53 | .with_label(Label::new_primary(prev_line_span)) 54 | .with_label(Label::new_primary(*sp)) 55 | }, 56 | 57 | SymbolNotFound(msg, sp) => { 58 | Diagnostic::new( 59 | Severity::Error, 60 | format!("Symbol `{}` not in scope", msg), 61 | ) 62 | .with_label(Label::new_primary(*sp)) 63 | } 64 | 65 | ImportError(msg, sp) => { 66 | Diagnostic::new( 67 | Severity::Error, 68 | format!("Cannot import symbol `{}`", msg), 69 | ) 70 | .with_label(Label::new_primary(*sp)) 71 | } 72 | 73 | DuplicateVarInScope(name, ty1, ty2) => { 74 | Diagnostic::new( 75 | Severity::Error, 76 | format!("Duplicate symbol in scope: {}: {:?}, {:?}", name, ty1, ty2), 77 | ) 78 | .with_label(Label::new_primary(ty1.span())) 79 | .with_label(Label::new_primary(ty2.span())) 80 | } 81 | 82 | TypeError(ty1, ty2) => { 83 | Diagnostic::new( 84 | Severity::Error, 85 | format!("Type mismatch: {:?}, {:?}", ty1, ty2), 86 | ) 87 | .with_label(Label::new_primary(ty1.span())) 88 |
.with_label(Label::new_primary(ty2.span())) 89 | } 90 | 91 | EllisionError(msg, span) => { 92 | Diagnostic::new( 93 | Severity::Error, 94 | msg.to_owned(), 95 | ) 96 | .with_label(Label::new_primary(*span)) 97 | } 98 | 99 | DimensionMismatch(ty1, ty2) => { 100 | // fallback: operands are not both ResolvedDim, so report the types instead of panicking 101 | Diagnostic::new( 102 | Severity::Error, 103 | format!("Dimension mismatch: {:?} != {:?}", ty1, ty2), 104 | ) 105 | .with_label(Label::new_primary(ty1.span())) 106 | .with_label(Label::new_primary(ty2.span())) 107 | } 108 | 109 | RankMismatch(ty1, ty2) => { 110 | // fallback: operands are not both TSR, so report the types instead of panicking 111 | Diagnostic::new( 112 | Severity::Error, 113 | format!("Tensor rank mismatch: {:?} != {:?}", ty1, ty2), 114 | ) 115 | .with_label(Label::new_primary(ty1.span())) 116 | .with_label(Label::new_primary(ty2.span())) 117 | } 118 | 119 | UnknownError => Diagnostic::new(Severity::Error, "Unknown error".to_owned()), 120 | } 121 | } 122 | 123 | } 124 | -------------------------------------------------------------------------------- /trsc/src/errors/emitter.rs: -------------------------------------------------------------------------------- 1 | use std::str::FromStr; 2 | use codespan::CodeMap; 3 | use codespan_reporting::termcolor::StandardStream; 4 | use codespan_reporting::{emit, ColorArg, Diagnostic, Severity }; 5 | use super::diagnostic::Diag; 6 | use std::process::exit; 7 | 8 | #[derive(Debug, Clone)] 9 | pub struct Emitter { // accumulates Diags and prints them against the source CodeMap 10 | errs: Vec, 11 | code_map: CodeMap, 12 | print_ast: bool, 13 | } 14 | 15 | impl Emitter { 16 | pub fn new(code_map: CodeMap, print_ast: bool) -> Self { 17 | Self { 18 | errs: vec![], 19 | code_map, 20 | print_ast, 21 | } 22 | } 23 | 24 | pub fn add(&mut self, e: Diag) { 25 | self.errs.push(e); 26 | } 27 | 28 | pub fn print_errs(&self) { // prints every recorded Diag to stderr; exits with -1 if any is an error, unless print_ast is set 29 | let mut diagnostics: Vec = self.errs 30 | .iter() 31 | .map(|e|e.as_diagnostic(&self.code_map)) 32 | .collect(); 33 | let writer = StandardStream::stderr(ColorArg::from_str("auto").unwrap().into()); 34 | let mut is_err = false; 35 | while let Some(diagnostic) = &diagnostics.pop() { // pops from the back, so the most recent Diag prints first; self.errs is not cleared, so a later call prints them again 36 | if diagnostic.severity == Severity::Error { is_err = true } 37 | emit(&mut writer.lock(), &self.code_map, &diagnostic).unwrap(); 38 | } 39 | if is_err && !self.print_ast { exit(-1) } 40 | } 41 | } -------------------------------------------------------------------------------- /trsc/src/errors/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod diagnostic; 2 | pub mod emitter; 3 | 4 | pub use self::emitter::Emitter; 5 | pub use self::diagnostic::Diag;
-------------------------------------------------------------------------------- /trsc/src/main.rs: -------------------------------------------------------------------------------- 1 | #![feature(iterator_flatten)] 2 | #![feature(transpose_result)] 3 | #![feature(box_syntax)] 4 | #![feature(box_patterns)] 5 | #![feature(custom_attribute)] 6 | #![feature(attr_literals)] 7 | 8 | /// How it works: 9 | /// 1. PEG parser parses into token tree. The downside of PEG parser is that 10 | /// it is mostly magic, which means either it works or not, very difficult 11 | /// to debug or rigorously test other than trial and error. The Pest crate handles 12 | /// lexing and parsing in this compiler. 13 | /// 14 | /// 2. Parses token tree into untyped AST. This constructs a simple traversable tree 15 | /// structure for quality of life. The typing step might as well be merged into this step. 16 | /// 17 | /// 3. Annotate untyped AST into typed AST for type inference and reconstruction. The 18 | /// idea is to annotate each tree node with a dummy type variable. 19 | /// 20 | /// 4. Hindley-Milner type inference for type reconstruction. This consists 21 | /// of a few substeps. 22 | /// 23 | /// a. Collect constraints. (handled in constraint.rs) 24 | /// In this step, traverse the typed AST and collect types of adjacent nodes that should 25 | /// be equivalent. This generates a `Constraints` struct, which is just a thin wrapper 26 | /// around a BTreeSet of (Type, Type) tuples. 27 | /// 28 | /// b. Unify constraints by generating substitutions. 29 | /// This is a variant of Algorithm W in H-M type inference. Basically, the unify_one 30 | /// function tries to replace 1 type var with a concrete type. The parent function, unify, 31 | /// then uses that substitution on the rest of the constraints, thus eliminating the type 32 | /// variable from the constraint set. The process is iterated until one of these conditions is met: 33 | /// a) all type variables are exhausted.
b) equivalence that can never happen. c) circular 34 | /// type dependence (handled by occurs check). 35 | /// 36 | /// c. Generate Substitutions 37 | /// Now after the unification is complete, the function returns a list of substitutions that 38 | /// should remove all type variables from the typed AST. 39 | /// 40 | /// 5. Code generation: codegen::pytorch::Generator emits PyTorch source from the final AST. 41 | /// 42 | /// 43 | /// A note about `Span`s: Span contains the location in the source code for 44 | /// error reporting. Think of it as a lightweight tag that can be associated with 45 | /// data structures such as AST nodes, types, etc... 46 | /// 47 | 48 | #[macro_use] 49 | extern crate trsc_core_derive; 50 | extern crate pest; 51 | #[macro_use] 52 | mod typing; 53 | #[macro_use] 54 | extern crate pest_derive; 55 | #[macro_use] 56 | extern crate maplit; 57 | 58 | extern crate codespan; 59 | extern crate clap; 60 | extern crate codespan_reporting; 61 | 62 | mod core; 63 | mod parsing; 64 | mod span; 65 | mod errors; 66 | mod codegen; 67 | 68 | 69 | use typing::constraint::Constraints; 70 | use typing::unifier::Unifier; 71 | use typing::annotate::Annotator; 72 | use codegen::pytorch::Generator; 73 | use typing::type_env::TypeEnv; 74 | use typing::Type; 75 | use typing::inferred_ast::subs; 76 | use errors::{Emitter}; 77 | use parsing::ast_builder::ASTBuilder; 78 | use span::CSpan; 79 | 80 | use std::rc::Rc; 81 | use std::cell::RefCell; 82 | use std::fs::File; 83 | use std::io::Read; 84 | use std::process::exit; 85 | 86 | use codespan::CodeMap; 87 | use clap::{Arg, App, ArgMatches}; 88 | 89 | fn get_matches<'a>() -> ArgMatches<'a> { // command-line flags: -f/--in FILE (required) and --print-ast 90 | App::new("tsrc") 91 | .version("0.0") 92 | .author("Ricky Han ") 93 | .about("Compiler for Tensorscript") 94 | .arg(Arg::with_name("input") 95 | .short("f") 96 | .long("in") 97 | .value_name("FILE") 98 | .help("Sets a custom input file") 99 | .takes_value(true) 100 | .required(true)) 101 | .arg(Arg::with_name("print_ast") 102 | .long("print-ast") 103 | .help("Prints AST")) 104 |
.get_matches() 105 | } 106 | 107 | fn main() { // pipeline: parse -> annotate -> unify until the AST stops changing -> print AST or generate PyTorch code 108 | // --------------- get command line options ----------------- 109 | let matches = get_matches(); 110 | let print_ast = matches.is_present("print_ast"); 111 | let fname = matches.value_of("input").unwrap(); 112 | let mut file = File::open(fname).expect("Unable to open the file"); 113 | let mut src = String::new(); 114 | file.read_to_string(&mut src).expect("Unable to read the file"); 115 | // -------------------- create emitter -------------------- 116 | let mut code_map = CodeMap::new(); 117 | let file_map = code_map.add_filemap(fname.to_owned().into(), src.clone()); 118 | let emitter = Rc::new(RefCell::new(Emitter::new(code_map, print_ast))); 119 | // --------------- parse into untyped ast ----------------- 120 | let cspan = CSpan::new(file_map.span()); 121 | let builder = ASTBuilder::new(Rc::clone(&emitter), cspan); 122 | let parsed_terms = builder.parse_str(&src); 123 | let program = parsed_terms 124 | .unwrap_or_else(||{ emitter.borrow().print_errs(); exit(-1); }); 125 | // ------------- annotate ast with type vars -------------- 126 | let core = Rc::new(RefCell::new(core::Core::new())); 127 | let tenv = Rc::new(RefCell::new(TypeEnv::new(core.clone()))); 128 | let annotator = Annotator::new(Rc::clone(&emitter), Rc::clone(&tenv)); 129 | let ast = annotator.annotate(&program); 130 | emitter.borrow().print_errs(); 131 | // println!("{:#?}", ast); 132 | // println!("initial tenv: {:#?}", tenv); 133 | // ------------ first unification pass --------------- 134 | let mut cs = Constraints::new(Rc::clone(&emitter), Rc::clone(&tenv)); 135 | cs.collect(&ast); 136 | let mut unifier = Unifier::new(Rc::clone(&emitter), Rc::clone(&tenv)); 137 | let mut last_sub = unifier.unify(cs.clone()); 138 | emitter.borrow().print_errs(); 139 | // println!("{:#?}", last_sub); 140 | 141 | // ------------ resolve module constraints until it stabilizes ---------- 142 | let mut last_ast = subs(&ast, &mut last_sub); 143 | let em_clone =
emitter.clone(); 144 | let tenv_clone = tenv.clone(); 145 | let resolve_ast = move || { 146 | let mut i = 0; 147 | loop { 148 | // collect constraints 149 | let mut new_cs = Constraints::new(Rc::clone(&em_clone), Rc::clone(&tenv_clone)); 150 | new_cs.collect(&last_ast); 151 | em_clone.borrow().print_errs(); 152 | // unify constraints 153 | let mut new_unifier = Unifier::new(Rc::clone(&em_clone), Rc::clone(&tenv_clone)); 154 | let mut new_sub = new_unifier.unify(new_cs.clone()); 155 | em_clone.borrow().print_errs(); 156 | let temp_ast = subs(&last_ast, &mut new_sub); 157 | if temp_ast != last_ast { 158 | last_ast = temp_ast; 159 | i += 1; 160 | if i > 1_000_000 { 161 | eprintln!("Error: does not halt"); // stderr: stdout carries the generated code 162 | exit(1); 163 | } 164 | 165 | continue; 166 | } 167 | return last_ast; 168 | } 169 | }; 170 | let final_ast = resolve_ast(); 171 | if print_ast { 172 | println!("{:#?}", final_ast); 173 | exit(0); 174 | } 175 | // ---------------------------- code gen ----------------------------------- 176 | let mut generator = Generator::new(emitter.clone(), tenv.clone(), core.clone()); 177 | generator.generate(&final_ast).unwrap(); 178 | println!("{}", generator.buf); 179 | } 180 | -------------------------------------------------------------------------------- /trsc/src/parsing/grammar.rs: -------------------------------------------------------------------------------- 1 | #[derive(Parser)] 2 | #[grammar = "tensorscript.pest"] 3 | pub struct TensorScriptParser; 4 | -------------------------------------------------------------------------------- /trsc/src/parsing/macros.rs: -------------------------------------------------------------------------------- 1 | macro_rules! err { // builds a Diag::ParseError from a message and a span 2 | ($msg:expr, $span:expr) => { 3 | Diag::ParseError( 4 | $msg.to_owned(), 5 | $span, 6 | ) 7 | }; 8 | } 9 | 10 | macro_rules!
eat { // pops the next token; ParseError if there is none or (optionally) it is not the expected Rule 11 | ($tokens:expr, $err:expr, $span:expr) => { 12 | { 13 | let t = $tokens.next(); 14 | t.ok_or(err!($err, $span)) 15 | } 16 | }; 17 | 18 | 19 | ($tokens:expr, $rule:ident, $err:expr, $span:expr) => { 20 | { 21 | let t = $tokens.next(); 22 | t 23 | .ok_or(err!($err, $span)) 24 | .and_then(|val| { 25 | if Rule::$rule != val.as_rule() { 26 | Err(err!( 27 | &format!("Type is not {:?}", Rule::$rule), 28 | $span 29 | )) 30 | } else { 31 | Ok(val) 32 | } 33 | }) 34 | } 35 | }; 36 | 36 | 37 | ($tokens:expr, [$( $rule:ident ),+], $err:expr, $span:expr) => { // accepts any of the listed Rules; err! needs a span, so this arm takes one too 38 | $tokens.next() 39 | .ok_or(err!($err, $span)) 40 | .and_then(|val| { 41 | $( 42 | if Rule::$rule == val.as_rule() { 43 | return Ok(val); 44 | } 45 | )* 46 | return Err(err!("Type is wrong", $span)) 47 | }) 48 | }; 49 | } 50 | 51 | macro_rules! to_idents { // collects the text of each inner pair as a String (collection type inferred at the call site) 52 | ($ident_list:expr) => { 53 | $ident_list 54 | .into_inner() 55 | .map(|id| id.as_str()) 56 | .map(String::from) 57 | .collect() 58 | }; 59 | } -------------------------------------------------------------------------------- /trsc/src/parsing/mod.rs: -------------------------------------------------------------------------------- 1 | #[macro_use] 2 | mod macros; 3 | pub mod ast_builder; 4 | pub mod grammar; 5 | pub mod term; 6 | -------------------------------------------------------------------------------- /trsc/src/parsing/term.rs: -------------------------------------------------------------------------------- 1 | use codespan::ByteSpan; 2 | /// Data structures for untyped AST.
///
use std::fmt::{Display, Error, Formatter};

// NOTE(review): generic type arguments appear to have been stripped from this
// listing (e.g. `Box`, `Vec` below carry no parameter) — restore them from the
// real source before compiling.
type Expression = Box;
type Statements = Box;

/// A node of the untyped AST.
#[derive(Debug, PartialEq, Clone)]
pub enum Term {
    /// Empty / placeholder term (presumably e.g. a block without a trailing
    /// expression — confirm against the AST builder).
    None,
    /// a vector of decls
    Program(Vec),
    /// Integer literal with its source span.
    Integer(i64, ByteSpan),
    /// Float literal with its source span.
    Float(f64, ByteSpan),
    List(Vec),
    /// Identifier with its source span.
    Ident(String, ByteSpan),
    ViewFn(ViewFn),
    FieldAccess(FieldAccess),
    FnApp(FnApp),
    /// Block: statements followed by a result expression (presumably built
    /// from the grammar's `block` rule, `{ stmts expr? }`).
    Block {
        stmts: Statements,
        ret: Expression,
        span: ByteSpan,
    },
    Expr(Box, ByteSpan),
    Stmt(Box, ByteSpan),
    Pipes(Vec),
    Tuple(Vec, ByteSpan),
}

// Disabled draft of a `span()` accessor. Note that its variant list is out of
// date with `Term` above (e.g. it has `Integer(i64)` and struct-style
// `Expr`/`Stmt`).
// impl Term {
//     pub fn span(&self) -> ByteSpan {
//         use self::Term::*;
//         match self {
//             None,
//             /// a vector of decls
//             Program(Vec),
//             Integer(i64),
//             Float(f64),
//             List(Vec),
//             Ident(String, ByteSpan),
//             ViewFn(ViewFn),
//             FieldAccess(FieldAccess),
//             FnApp(FnApp),
//             Block {
//                 stmts: Statements,
//                 ret: Expression,
//                 span: ByteSpan,
//             },
//             Expr {
//                 items: Box,
//                 span: ByteSpan,
//             },
//             Stmt {
//                 items: Box,
//                 span: ByteSpan,
//             },
//             Pipes(Vec, ),
//             _ => unimplemented!(),
//         }
//     }
// }

/// A top-level declaration (presumably one per `item` of the grammar; the
/// `dim`/`tsr` assignments would map to `AliasAssign` — confirm).
#[derive(Debug, PartialEq, Clone)]
pub enum Decl {
    NodeDecl(NodeDecl),
    WeightsDecl(WeightsDecl),
    GraphDecl(GraphDecl),
    UseStmt(UseStmt),
    AliasAssign(AliasAssign),
}

/// `use mod_name::{a, b};` — the module and the names imported from it.
#[derive(Debug, PartialEq, Clone)]
pub struct UseStmt {
    pub mod_name: String,
    pub imported_names: Vec,
    pub span: ByteSpan,
}

/// A `node` declaration: its name, `<from -> to>` signature and body
/// assignments.
#[derive(Debug, PartialEq, Clone)]
pub struct NodeDecl {
    pub name: String,
    pub ty_sig: FnTySig,
    pub defs: Vec,
    pub span: ByteSpan,
}

/// A `graph` declaration: its name, signature and function declarations.
#[derive(Debug, PartialEq, Clone)]
pub struct GraphDecl {
    pub name: String,
    pub ty_sig: FnTySig,
    pub fns: Vec,
    pub span: ByteSpan,
}
97 | #[derive(Debug, PartialEq, Clone)] 98 | pub struct WeightsDecl { 99 | pub name: String, 100 | pub ty_sig: FnTySig, 101 | pub inits: Vec, 102 | pub span: ByteSpan, 103 | } 104 | 105 | #[derive(Debug, PartialEq, Clone)] 106 | pub struct FnDeclParam { 107 | pub name: String, 108 | pub ty_sig: TensorTy, 109 | pub span: ByteSpan, 110 | } 111 | 112 | #[derive(Debug, PartialEq, Clone)] 113 | pub struct FieldAccess { 114 | pub mod_name: String, 115 | pub field_name: String, 116 | pub func_call: Option>, 117 | pub span: ByteSpan, 118 | } 119 | 120 | #[derive(Debug, PartialEq, Clone)] 121 | pub struct FnApp { 122 | pub name: String, 123 | pub args: Vec, 124 | pub span: ByteSpan, 125 | } 126 | 127 | #[derive(Debug, PartialEq, Clone)] 128 | pub struct FnAppArg { 129 | pub name: String, 130 | pub arg: Box, 131 | pub span: ByteSpan, 132 | } 133 | 134 | #[derive(Debug, PartialEq, Clone)] 135 | pub struct WeightsAssign { 136 | pub name: String, 137 | pub mod_name: String, 138 | pub fn_name: String, 139 | pub mod_sig: Option, 140 | pub fn_args: Vec, 141 | pub span: ByteSpan, 142 | } 143 | 144 | #[derive(Debug, PartialEq, Clone)] 145 | pub struct FnTySig { 146 | pub from: TensorTy, 147 | pub to: TensorTy, 148 | } 149 | 150 | #[derive(Debug, PartialEq, Clone)] 151 | pub struct FnDecl { 152 | pub name: String, 153 | pub fn_params: Option>, 154 | pub return_ty: Option, 155 | pub func_block: Box, 156 | pub span: ByteSpan, 157 | } 158 | 159 | #[derive(Debug, PartialEq, Clone)] 160 | pub enum AliasAssign { 161 | Dimension { 162 | ident: String, 163 | rhs: Term, 164 | span: ByteSpan, 165 | }, 166 | Tensor { 167 | ident: String, 168 | rhs: TensorTy, 169 | span: ByteSpan, 170 | }, 171 | } 172 | 173 | #[derive(Debug, PartialEq, Clone)] 174 | pub enum TensorTy { 175 | Tensor(String, ByteSpan), 176 | Generic(Vec, ByteSpan), 177 | } 178 | 179 | #[derive(Debug, PartialEq, Clone)] 180 | pub struct ViewFn { 181 | pub dims: Vec, 182 | pub span: ByteSpan, 183 | } 184 | 185 | impl Term { 186 | // 
pub fn is(&self, var: &Self) -> bool { 187 | // ::std::mem::discriminant(self) == ::std::mem::discriminant(var) 188 | // } 189 | 190 | // pub fn is_UseStmt(&self) -> bool { 191 | // self.is(&Term::UseStmt { 192 | // mod_name: format!(""), 193 | // imported_names: vec![], 194 | // }) 195 | // } 196 | 197 | // pub fn to_list(&self) -> Option> { 198 | // if let &Term::List(ref vs) = self { 199 | // Some(vs.to_vec()) 200 | // } else { 201 | // None 202 | // } 203 | // } 204 | 205 | // /// args is List(Arg) 206 | // pub fn extend_arg_list(func: FnApp, init: Term) -> Vec { 207 | // let mut new_arg_vec = vec![ 208 | // FnAppArg { 209 | // name: format!("x"), 210 | // arg: Box::new(init), 211 | // }, 212 | // ]; 213 | // new_arg_vec.extend(func.args); 214 | // new_arg_vec 215 | // } 216 | } 217 | 218 | impl Display for Term { 219 | fn fmt(&self, f: &mut Formatter) -> Result<(), Error> { 220 | write!(f, "{:#?}", self) 221 | } 222 | } 223 | 224 | // #[derive(Debug, PartialEq, Clone)] 225 | // pub enum Op { 226 | // Expo, 227 | // Mult, 228 | // Div, 229 | // Mod, 230 | // Add, 231 | // Sub, 232 | // ShL, 233 | // ShR, 234 | // BAnd, 235 | // BOr, 236 | // BXor, 237 | // Lt, 238 | // LtE, 239 | // Gt, 240 | // GtE, 241 | // Eq, 242 | // NotEq, 243 | // And, 244 | // Or, 245 | // Assign, 246 | // } 247 | -------------------------------------------------------------------------------- /trsc/src/span.rs: -------------------------------------------------------------------------------- 1 | use codespan::{ByteIndex, Span, ByteOffset}; 2 | use pest::Span as PestSpan; 3 | use codespan::ByteSpan; 4 | 5 | /// a ByteSpan is an index into source code 6 | pub struct CSpan { 7 | sp: ByteSpan, 8 | } 9 | 10 | impl CSpan { 11 | pub fn new(c: ByteSpan) -> CSpan { 12 | CSpan { 13 | sp: c, 14 | } 15 | } 16 | 17 | pub fn fresh_span() -> ByteSpan { 18 | // span can be any because it's taken into account for hashing 19 | Span::new(ByteIndex(0), ByteIndex(0)) 20 | } 21 | 22 | pub fn 
convert_span(&self, sp: &PestSpan) -> ByteSpan { 23 | // Span::new(ByteIndex(sp.start() as u32 + 1), ByteIndex(sp.end() as u32 + 1)) 24 | self.sp.subspan(ByteOffset(sp.start() as i64), ByteOffset(sp.end() as i64)) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /trsc/src/tensorscript.pest: -------------------------------------------------------------------------------- 1 | whitespace = _{ space | newline } 2 | comment = _{ line_comment } 3 | line_comment = _{ "//" ~ (!newline ~ any)* } 4 | 5 | semicolon = { ";" } 6 | 7 | newline = _{ "\n" | "\r\n" } 8 | space = _{ " " | "\t" } 9 | // keyword 10 | unspecified_dim_lit = @{ "_" } 11 | batch_lit = @{ "?" } 12 | dim_lit = _{ "dim" } 13 | tsr_lit = _{ "tsr" } 14 | node_lit = _{ "node" } 15 | view_lit = _{ "view" } 16 | weights_lit = _{ "weights" } 17 | graph_lit = _{ "graph" } 18 | fn_lit = _{ "def" } 19 | break_lit = { "break" } 20 | const_lit = { "const" } 21 | continue_lit = { "continue" } 22 | crate_lit = { "crate" } 23 | else_lit = { "else" } 24 | enum_lit = { "enum" } 25 | false_lit = { "false" } 26 | for_lit = { "for" } 27 | if_lit = { "if" } 28 | let_lit = { "let" } 29 | match_lit = { "match" } 30 | mod_lit = { "mod" } 31 | move_lit = { "move" } 32 | return_lit = { "return" } 33 | self_lit = { "self" } 34 | true_lit = { "true" } 35 | ty_lit = { "type" } 36 | use_lit = { "use" } 37 | where_lit = { "where" } 38 | while_lit = { "while" } 39 | print_lit = { "print" } 40 | keyword = { unspecified_dim_lit| batch_lit | dim_lit | tsr_lit | node_lit | weights_lit | graph_lit | view_lit | break_lit | const_lit | 41 | continue_lit | crate_lit | else_lit | enum_lit | true_lit | false_lit | 42 | fn_lit | for_lit | if_lit | let_lit | match_lit | mod_lit | move_lit | 43 | return_lit | self_lit | ty_lit | use_lit | 44 | where_lit | while_lit | print_lit } 45 | 46 | 47 | binary_op = _{ 48 | op_expo | 49 | op_mult | 50 | op_div | 51 | op_mod | 52 | op_add | 53 | op_sub | 54 | 
op_bsl | 55 | op_bsr | 56 | op_and | 57 | op_band | 58 | op_or | 59 | op_bor | 60 | op_bxor | 61 | op_lte | 62 | op_lt | 63 | op_gte | 64 | op_gt | 65 | op_eq | 66 | op_ne | 67 | op_assign 68 | } 69 | op_expo = { "**" } 70 | op_mult = { "*" } 71 | op_not = { "!" } 72 | op_div = { "/" } 73 | op_mod = { "%" } 74 | op_add = { "+" } 75 | op_sub = { "-" } 76 | op_bsl = { "<<" } 77 | op_bsr = { ">>" } 78 | op_band = { "&" } 79 | op_bor = { "|" } 80 | op_bxor = { "^" } 81 | op_lt = { "<" } 82 | op_lte = { "<=" } 83 | op_gt = { ">" } 84 | op_gte = { ">=" } 85 | op_eq = { "==" } 86 | op_ne = { "!=" } 87 | op_and = { "&&" } 88 | op_or = { "||" } 89 | op_assign = { "=" } 90 | 91 | 92 | literal = _{ 93 | num_lit | 94 | bool_lit 95 | } 96 | 97 | // bool 98 | bool_lit = { true_lit | false_lit } 99 | // int 100 | digit = _{ '0'..'9' } 101 | int_lit = @{ digit ~ (digit | "_")* } 102 | plus = _{ "+" } 103 | // float 104 | minus = _{ "-" } 105 | exp = _{ ^"e" ~ (plus | minus)? ~ int_lit } 106 | float_lit = @{ 107 | int_lit ~ "." ~ int_lit? ~ exp? | 108 | int_lit ~ exp 109 | } 110 | num_lit = _{ float_lit | int_lit } 111 | 112 | // ident 113 | lower = _{ 'a'..'z' } 114 | upper = _{ 'A'..'Z' } 115 | alpha = _{ lower | upper } 116 | ident = @{ (!digit ~ (alpha | digit | "_")+ ) | "?" } 117 | ident_list = { ident ~ ("," ~ ident)* ~ ","? } 118 | cap_ident = @{ upper ~ (alpha|digit| "_")* } 119 | upper_ident = @{ (upper|digit|"_")* } 120 | 121 | use_stmt = { use_lit ~ ident ~ "::" ~ ( "{" ~ ident_list ~ "}" | ident ) ~ semicolon} 122 | 123 | // type signature 124 | 125 | ty_ident = @{ (alpha | digit | "?" | "_")+ } 126 | ty_ident_list = { ty_ident ~ ("," ~ ty_ident)* } 127 | fn_ty_sig = { "<" ~ tensor_ty ~ "->" ~ tensor_ty ~ ">" } 128 | ty_sig = { "<"? ~ tensor_ty ~ ">"? 
} 129 | tensor_ty_sig = _{ "[" ~ ty_ident_list ~ "]" } 130 | tensor_ty = _{ tensor_alias_ty | tensor_ty_sig } 131 | tensor_alias_ty = _{ ident } 132 | 133 | dim_assign = { dim_lit ~ ( ident | batch_lit ) ~ op_assign ~ int_lit ~ semicolon } 134 | tsr_assign = { tsr_lit ~ ident ~ op_assign ~ tensor_ty ~ semicolon } 135 | node_assign = { dim_assign | tsr_assign } 136 | node_decl_body = { "{" ~ node_assign* ~ "}" } 137 | node_decl_head = { node_lit ~ cap_ident ~ fn_ty_sig } 138 | node_decl = { node_decl_head ~ node_decl_body } 139 | 140 | 141 | weights_assign = { ident ~ op_assign ~ 142 | cap_ident ~ ("::" ~ fn_ty_sig)? ~ "::" ~ fn_app ~ semicolon 143 | } 144 | weights_decl_body = { "{" ~ weights_assign* ~ "}" } 145 | weights_decl_head = { weights_lit ~ cap_ident ~ fn_ty_sig } 146 | weights_decl = { weights_decl_head ~ weights_decl_body } 147 | 148 | graph_decl_body = { "{" ~ fn_decls ~ "}" } 149 | graph_decl_head = { graph_lit ~ cap_ident ~ fn_ty_sig } 150 | graph_decl = { graph_decl_head ~ graph_decl_body } 151 | 152 | 153 | 154 | while_loop = { while_lit ~ expr ~ block } 155 | 156 | conditional = { "if" ~ expr ~ block ~ (op_else_if ~ expr ~ block)* ~ (op_else ~ block)? } 157 | op_else_if = { "else if" } 158 | op_else = { "else" } 159 | 160 | 161 | fn_decls = { fn_decl* } 162 | fn_decl_param = { ("(" ~ ")") | ("(" ~ fn_decl_params ~ ")") } 163 | fn_decl_sig = { fn_decl_param ~ ("->" ~ ty_sig)? } 164 | fn_decl_params = { fn_decl_arg ~ ("," ~ fn_decl_arg)* } 165 | fn_decl_arg = { ident ~ (":" ~ ty_sig)? } 166 | fn_decl_head = { fn_lit ~ ident ~ fn_decl_sig? } 167 | fn_decl = { fn_decl_head ~ block } 168 | 169 | fn_app_param = { ("(" ~ ")") | ("(" ~ fn_app_args ~ ")") } 170 | fn_app_arg_pair = { ident ~ "=" ~ expr } 171 | fn_app_arg = { ident ~ "=" ~ expr } 172 | fn_app_args = { (fn_app_arg ~ ",")* ~ fn_app_arg? ~ ","? } 173 | fn_app = { ident ~ "(" ~ fn_app_args? ~ ")" } 174 | 175 | 176 | pipes = { expr_item ~ ("|>" ~ expr)+ } 177 | 178 | field_access = { ident ~ "." 
~ ident ~ fn_app_param? } 179 | 180 | view_fn = { view_lit ~ "(" ~ view_fn_args ~ ")" } 181 | view_fn_args = _{ ( unspecified_dim_lit | num_lit | ident)? ~ ("," ~ ( unspecified_dim_lit | num_lit |ident))* ~ ","? } 182 | 183 | tuple = { "(" ~ (expr ~ ",")* ~ expr? ~ ","? ~ ")" } 184 | expr_item = _{ view_fn | field_access | literal | bool_not | fn_app | ident | conditional | tuple } 185 | expr = { expr_item ~ !"|>" | pipes } 186 | 187 | bool_not = _{ op_not ~ expr } 188 | 189 | // This allows {} and {statement; statement; statement;} and {statement; expr} and {expr} 190 | block = { "{" ~ stmts ~ expr? ~ "}" } 191 | stmts = { stmt* } 192 | 193 | stmt = { assignment | while_loop | conditional | (expr ~ semicolon) | comment } 194 | 195 | assignment = { ident ~ op_assign ~ expr ~ semicolon } 196 | 197 | 198 | 199 | input = _{ soi ~ items ~ eoi } 200 | items = _{ item* } 201 | item = _{ use_stmt | graph_decl | weights_decl | node_decl | dim_assign | tsr_assign } 202 | 203 | -------------------------------------------------------------------------------- /trsc/src/typing/constraint.rs: -------------------------------------------------------------------------------- 1 | use std::collections::BTreeSet; 2 | 3 | use typing::type_env::{Alias, ModName, TypeEnv}; 4 | use typing::typed_term::*; 5 | use typing::Type; 6 | use std::rc::Rc; 7 | use std::process::exit; 8 | use std::cell::RefCell; 9 | use errors::{ Emitter, Diag }; 10 | 11 | use span::CSpan; 12 | 13 | #[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord)] 14 | pub struct Equals(pub Type, pub Type); 15 | 16 | #[derive(Debug, Clone)] 17 | pub struct Constraints { 18 | pub set: BTreeSet, 19 | pub emitter: Rc>, 20 | pub tenv: Rc>, 21 | } 22 | 23 | impl Constraints { 24 | 25 | pub fn new(emitter: Rc>, tenv: Rc>) -> Self { 26 | Constraints { 27 | set: BTreeSet::new(), 28 | emitter, 29 | tenv, 30 | } 31 | } 32 | 33 | pub fn is_empty(&self) -> bool { 34 | self.set.is_empty() 35 | } 36 | 37 | fn add(&mut self, a: Type, 
b: Type) { 38 | // println!("{:?} {:?}", a, b); 39 | self.set.insert(Equals(a, b)); 40 | } 41 | 42 | pub fn collect(&mut self, typed_term: &TyTerm) { 43 | use self::TyTerm::*; 44 | let module = { self.tenv.borrow().module().clone() }; 45 | // println!("{}", typed_term); 46 | match typed_term { 47 | TyProgram(ref decls) => decls 48 | .iter() 49 | .map(|decl| self.collect_decl(&decl)) 50 | .collect(), 51 | TyInteger(_, _, _) => (), 52 | TyFloat(_, _, _) => (), 53 | TyList(ref terms) => terms.iter().map(|t| self.collect(&t)).collect(), 54 | TyTuple(_, ref terms, _) => terms.iter().map(|t| self.collect(&t)).collect(), 55 | TyIdent(ref t, ref name, ref sp) => { 56 | let ty = self.tenv.borrow() 57 | .resolve_type(&module, &name) 58 | .or_else(|| self.tenv.borrow().resolve_type(&ModName::Global, &name)) 59 | .unwrap() // ... todo: 60 | .clone() 61 | .with_span(&sp); 62 | self.add(t.clone(), ty); 63 | } 64 | // &TyFieldAccess(TyFieldAccess), 65 | TyFnApp(ref fn_app) => self.collect_fn_app(&fn_app), 66 | TyBlock { ref stmts, ref ret, .. 
} => { 67 | self.tenv.borrow_mut().push_scope_collection(&module); 68 | self.collect(&stmts); 69 | self.collect(&ret); 70 | self.tenv.borrow_mut().pop_scope(&module); 71 | } 72 | TyExpr(ref items, ref ty, _) => { 73 | self.collect(&items); 74 | self.add(ty.clone(), items.ty()); 75 | } 76 | TyStmt(ref items, _) => self.collect(&items), 77 | TyNone => (), 78 | _ => { 79 | panic!("{:#?}", typed_term); 80 | } 81 | } 82 | } 83 | fn collect_decl(&mut self, decl: &TyDecl) { 84 | use self::TyDecl::*; 85 | match decl { 86 | TyGraphDecl(d) => self.collect_graph_decl(d), 87 | TyNodeDecl(d) => self.collect_node_decl(d), 88 | TyUseStmt(d) => self.collect_use_stmt(d), 89 | TyWeightsDecl(d) => self.collect_weights_decl(d), 90 | TyAliasAssign(_) => (), 91 | } 92 | self.tenv.borrow_mut().set_module(ModName::Global); 93 | } 94 | 95 | fn collect_graph_decl(&mut self, decl: &TyGraphDecl) { 96 | // type decl should be the same 97 | self.tenv.borrow_mut().set_module(ModName::Named(decl.name.clone())); 98 | let graph_ty_sig = self.tenv.borrow() 99 | .resolve_type(&ModName::Global, &Alias::Variable(decl.name.clone())) 100 | .unwrap() 101 | .clone(); 102 | 103 | self.add( 104 | Type::Module( 105 | decl.name.to_owned(), 106 | Some(box decl.ty_sig.clone()), 107 | decl.span, 108 | ), 109 | graph_ty_sig, 110 | ); 111 | 112 | // collect fn_decls 113 | for f in &decl.fns { 114 | self.collect_fn_decl(&f); 115 | } 116 | } 117 | 118 | fn collect_fn_decl(&mut self, decl: &TyFnDecl) { 119 | let module = self.tenv.borrow().module(); 120 | self.tenv.borrow_mut().push_scope_collection(&module); 121 | 122 | self.collect(&decl.func_block); 123 | self.add(decl.func_block.ty(), decl.ret_ty.clone()); 124 | 125 | 126 | // if decl.name == Alias::Function("forward".to_owned()) { 127 | // panic!("{:?}, {:?}", decl.fn_ty, func); 128 | // } 129 | 130 | self.tenv.borrow_mut().pop_scope(&module); 131 | } 132 | 133 | fn collect_node_decl(&mut self, decl: &TyNodeDecl) { 134 | 
self.tenv.borrow_mut().set_module(ModName::Named(decl.name.clone())); 135 | // type decl should be the same 136 | let graph_ty_sig = self.tenv.borrow().resolve_type(&ModName::Global, &Alias::Variable(decl.name.clone())) 137 | .unwrap() 138 | .clone(); 139 | self.add( 140 | Type::Module( 141 | decl.name.to_owned(), 142 | Some(box decl.ty_sig.clone()), 143 | decl.span, 144 | ), 145 | graph_ty_sig, 146 | ); 147 | } 148 | 149 | fn collect_weights_decl(&mut self, decl: &TyWeightsDecl) { 150 | self.tenv.borrow_mut().set_module(ModName::Named(decl.name.clone())); 151 | // type decl should be the same 152 | let graph_ty_sig = self.tenv.borrow().resolve_type(&ModName::Global, &Alias::Variable(decl.name.clone())) 153 | .unwrap() 154 | .clone(); 155 | self.add( 156 | Type::Module( 157 | decl.name.to_owned(), 158 | Some(box decl.ty_sig.clone()), 159 | decl.span, 160 | ), 161 | graph_ty_sig, 162 | ); 163 | 164 | // collect weight assigns 165 | for w in &decl.inits { 166 | self.collect_weights_assign(&w); 167 | } 168 | } 169 | 170 | fn collect_use_stmt(&mut self, _decl: &TyUseStmt) { 171 | () 172 | } 173 | 174 | fn collect_weights_assign(&mut self, w_a: &TyWeightsAssign) { 175 | let mod_name = &w_a.mod_name; 176 | // convert into a fn_app and collect on `self.new` method 177 | let ret_ty = self.tenv.borrow_mut().fresh_var(w_a.span); 178 | self.collect_fn_app( 179 | &TyFnApp { 180 | mod_name: Some(mod_name.to_string()), 181 | orig_name: None, 182 | name: Alias::Function("new".to_owned()), 183 | arg_ty: w_a.arg_ty.clone(), 184 | ret_ty, 185 | args: w_a.fn_args.clone(), 186 | span: w_a.span, 187 | }, 188 | ); 189 | } 190 | 191 | fn collect_fn_app(&mut self, fn_app: &TyFnApp) { 192 | let current_mod = self.tenv.borrow().module(); 193 | // println!("{:#?}", fn_app); 194 | // println!("{}", fn_app.name); 195 | // println!("{:#?}", cs); 196 | 197 | let symbol_name = fn_app.mod_name.clone().unwrap(); 198 | let symbol_mod_ty = match &fn_app.orig_name { 199 | Some(ref orig_name) => 
self.tenv.borrow() 200 | .resolve_type( 201 | ¤t_mod, 202 | &Alias::Variable(orig_name.clone()) 203 | ) 204 | .or_else(|| self.tenv.borrow() 205 | .resolve_type(&ModName::Global, &Alias::Variable(orig_name.clone())) 206 | ) 207 | .unwrap() 208 | .clone(), 209 | None => { 210 | let resolved_ty = self.tenv.borrow() 211 | .resolve_type( 212 | ¤t_mod, 213 | &Alias::Variable(symbol_name.clone()) 214 | ) 215 | .or_else(|| self.tenv.borrow() 216 | .resolve_type(&ModName::Global, &Alias::Variable(symbol_name.clone())) 217 | ); 218 | match resolved_ty { 219 | Some(ty) => ty, 220 | None => { 221 | let e = Diag::SymbolNotFound(symbol_name.to_owned(), fn_app.span()); 222 | self.emitter.borrow_mut().add(e); 223 | self.emitter.borrow().print_errs(); 224 | exit(-1); 225 | } 226 | } 227 | } 228 | }; 229 | 230 | let symbol_modname = ModName::Named(symbol_mod_ty.as_string()); // Linear 231 | let fn_name = &fn_app.name; // F(forward) 232 | let resolved_ty = self.tenv.borrow().resolve_type(&symbol_modname, &fn_name) // function / Unresolved 233 | .or_else(|| self.tenv.borrow().resolve_type(&ModName::Global, &fn_name)); 234 | let ty = match resolved_ty { 235 | Some(ty) => ty, 236 | None => { 237 | let e = Diag::SymbolNotFound(fn_name.as_str().to_owned(), fn_app.span); 238 | self.emitter.borrow_mut().add(e); // ... 239 | return; 240 | } 241 | }; 242 | 243 | // println!( 244 | // "{:?} | {:?} | {} | {:?} | {:?} | {:?} ", 245 | // ty, fn_app.orig_name, symbol_name, symbol_mod_ty, symbol_modname, fn_name 246 | // ); 247 | 248 | if let Type::UnresolvedModuleFun(..) 
= ty { 249 | let resolution = if fn_app.orig_name.is_none() { // this is a weight assign fn 250 | // println!("{:?}, {:?}", &fn_app.mod_name.clone().unwrap().as_str(), fn_app.name); 251 | self.tenv.borrow_mut().resolve_unresolved( 252 | &ty, 253 | &fn_app.name.as_str(), 254 | fn_app.arg_ty.clone(), 255 | fn_app.ret_ty.clone(), 256 | fn_app.args.clone(), 257 | None 258 | ) 259 | } else { 260 | let inits = self.tenv.borrow().resolve_init(¤t_mod, &fn_app.orig_name.clone().unwrap()); 261 | self.tenv.borrow_mut().resolve_unresolved( 262 | &ty, 263 | fn_name.as_str(), 264 | fn_app.arg_ty.clone(), 265 | fn_app.ret_ty.clone(), 266 | fn_app.args.clone(), 267 | inits 268 | ) 269 | }; 270 | 271 | match resolution { 272 | Ok(Some((resolved_fn_ty, is_stateful))) => { 273 | self.add( 274 | resolved_fn_ty.clone(), 275 | fun!( 276 | symbol_name, 277 | fn_app.name.as_str(), 278 | fn_app.arg_ty.clone(), 279 | fn_app.ret_ty.clone() 280 | ) 281 | ); 282 | // set alias for symbol if stateful 283 | if is_stateful { 284 | unsafe { 285 | // println!("{:#?}", self.tenv); 286 | let ty = match resolved_fn_ty { 287 | Type::FUN(m,n,a,r,s) => Type::FUN(m,n, box a.first_arg_ty().unwrap(),r,s), 288 | _ => unimplemented!(), 289 | }; 290 | let sp = ty.span(); 291 | 292 | if fn_app.orig_name.is_some() { 293 | // only replace forward calls 294 | self.tenv.borrow_mut().replace_type( 295 | ¤t_mod, 296 | &Alias::Variable(fn_app.orig_name.clone().unwrap().to_owned()), 297 | Type::Module(symbol_name.to_owned(), Some(box ty), sp), 298 | ); 299 | } 300 | // println!("{:#?}", self.tenv); 301 | // panic!(); 302 | } 303 | } 304 | } 305 | Ok(None) => 306 | (), 307 | Err(e) => { 308 | self.emitter.borrow_mut().add(e); 309 | } 310 | } 311 | } 312 | 313 | self.add( 314 | ty.clone(), 315 | fun!(symbol_name, fn_app.name.as_str(), fn_app.arg_ty.clone(), fn_app.ret_ty.clone()), 316 | ); 317 | 318 | self.add(fn_app.arg_ty.clone(), fn_app.args.to_ty(&fn_app.span)); 319 | 320 | if let "forward" = fn_name.as_str() { 321 | 
if let Type::Module(_, Some(box supplied_ty), _) = symbol_mod_ty { 322 | if let Type::FUN(_,_,box p,box r, _) = supplied_ty { 323 | self.add(fn_app.arg_ty.clone().clone(), 324 | args!(arg!("x",p.clone()))); 325 | self.add(fn_app.ret_ty.clone(), r.clone()); 326 | } 327 | } 328 | } 329 | 330 | for a in &fn_app.args { 331 | self.collect(&a.arg); 332 | } 333 | 334 | } 335 | } 336 | 337 | -------------------------------------------------------------------------------- /trsc/src/typing/inferred_ast.rs: -------------------------------------------------------------------------------- 1 | /// Substitute inferred types back into AST 2 | use self::TyTerm::*; 3 | use typing::typed_term; 4 | use typing::typed_term::*; 5 | use typing::unifier::Substitution; 6 | 7 | pub fn subs(typed_term: &TyTerm, s: &mut Substitution) -> TyTerm { 8 | // println!("{}", typed_term); 9 | match typed_term { 10 | TyProgram(ref decls) => TyProgram(decls.iter().map(|decl| subs_decl(&decl, s)).collect()), 11 | TyInteger(ref ty, ref a, ref sp) => TyInteger(s.apply_ty(&ty), *a, *sp), 12 | TyFloat(ref ty, ref a, ref sp) => TyFloat(s.apply_ty(&ty), *a, *sp), 13 | TyList(ref terms) => TyList(terms.iter().map(|t| subs(&t, s)).collect()), 14 | TyIdent(ref t, ref name, ref span) => TyIdent(s.apply_ty(t), name.clone(), *span), 15 | // // &TyFieldAccess(TyFieldAccess), 16 | TyFnApp(ref fn_app) => TyFnApp(box subs_fn_app(&fn_app, s)), 17 | TyBlock { 18 | ref stmts, 19 | ref ret, 20 | ref span, 21 | } => TyBlock { 22 | stmts: box subs(&stmts, s), 23 | ret: box subs(&ret, s), 24 | span: *span, 25 | }, 26 | TyExpr(ref items, ref ty, ref span) => TyExpr( 27 | box subs(&items, s), 28 | s.apply_ty(ty), 29 | *span, 30 | ), 31 | TyStmt(ref items, ref span) => TyStmt( 32 | box subs(&items, s), 33 | *span, 34 | ), 35 | TyNone => TyNone, 36 | TyTuple(ref ty, ref vs, ref span) => TyTuple( 37 | s.apply_ty(ty), 38 | vs.iter().map(|i|subs(i,s)).collect(), 39 | *span 40 | ), 41 | _ => { 42 | panic!("{:#?}", typed_term); 43 | } 
44 | } 45 | } 46 | 47 | fn subs_decl(decl: &TyDecl, s: &mut Substitution) -> TyDecl { 48 | use self::TyDecl::*; 49 | match decl { 50 | TyGraphDecl(d) => TyGraphDecl(subs_graph_decl(d, s)), 51 | TyNodeDecl(d) => TyNodeDecl(subs_node_decl(d, s)), 52 | TyUseStmt(d) => TyUseStmt(subs_use_stmt(d, s)), 53 | TyWeightsDecl(d) => TyWeightsDecl(subs_weights_decl(d, s)), 54 | TyAliasAssign(d) => TyAliasAssign(d.clone()), 55 | } 56 | } 57 | 58 | fn subs_graph_decl(decl: &TyGraphDecl, s: &mut Substitution) -> TyGraphDecl { 59 | TyGraphDecl { 60 | name: decl.name.clone(), 61 | ty_sig: s.apply_ty(&decl.ty_sig), 62 | fns: decl.fns.iter().map(|f| subs_fn_decl(f, s)).collect(), 63 | span: decl.span, 64 | } 65 | } 66 | 67 | fn subs_fn_decl(decl: &TyFnDecl, s: &mut Substitution) -> TyFnDecl { 68 | let mut c = decl.clone(); 69 | c.arg_ty = s.apply_ty(&c.arg_ty); 70 | c.ret_ty = s.apply_ty(&c.ret_ty); 71 | c.func_block = box subs(&c.func_block, s); 72 | c 73 | } 74 | 75 | fn subs_node_decl(decl: &TyNodeDecl, s: &mut Substitution) -> TyNodeDecl { 76 | TyNodeDecl { 77 | name: decl.name.clone(), 78 | ty_sig: s.apply_ty(&decl.ty_sig), 79 | span: decl.span, 80 | } 81 | } 82 | 83 | fn subs_weights_decl(decl: &TyWeightsDecl, s: &mut Substitution) -> TyWeightsDecl { 84 | let mut c = decl.clone(); 85 | c.ty_sig = s.apply_ty(&c.ty_sig); 86 | c.inits = c.inits 87 | .iter() 88 | .map(|w_a| subs_weights_assign(w_a, s)) 89 | .collect::>(); 90 | c 91 | } 92 | 93 | fn subs_use_stmt(decl: &TyUseStmt, _tenv: &mut Substitution) -> TyUseStmt { 94 | decl.clone() 95 | } 96 | 97 | fn subs_weights_assign(w_a: &TyWeightsAssign, s: &mut Substitution) -> TyWeightsAssign { 98 | let mut c = w_a.clone(); 99 | c.arg_ty = s.apply_ty(&c.arg_ty); 100 | c.fn_args = c.fn_args.iter().map(|a| subs_fn_app_arg(a, s)).collect(); 101 | c 102 | } 103 | 104 | fn subs_fn_app(fn_app: &typed_term::TyFnApp, s: &mut Substitution) -> typed_term::TyFnApp { 105 | let mut c = fn_app.clone(); 106 | c.arg_ty = s.apply_ty(&c.arg_ty); 107 | 
c.ret_ty = s.apply_ty(&c.ret_ty); 108 | c.args = c.args.iter().map(|a| subs_fn_app_arg(&a, s)).collect(); 109 | c 110 | } 111 | 112 | fn subs_fn_app_arg(a: &TyFnAppArg, s: &mut Substitution) -> TyFnAppArg { 113 | TyFnAppArg { 114 | name: a.name.clone(), 115 | arg: box subs(&a.arg, s), 116 | span: a.span, 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /trsc/src/typing/mod.rs: -------------------------------------------------------------------------------- 1 | #[macro_use] 2 | pub mod types; 3 | pub mod annotate; 4 | pub mod type_env; 5 | pub mod typed_term; 6 | 7 | pub use self::type_env::TypeEnv; 8 | pub use self::types::Type; 9 | pub mod constraint; 10 | pub mod inferred_ast; 11 | pub mod unifier; 12 | -------------------------------------------------------------------------------- /trsc/src/typing/type_env.rs: -------------------------------------------------------------------------------- 1 | use codespan::ByteSpan; 2 | use core::Core; 3 | use span::CSpan; 4 | use std::rc::Rc; 5 | use std::cell::RefCell; 6 | /// Type Environment holds the state during type reconstruction 7 | /// which is really just a few tree traversals. 8 | /// 9 | /// It handles, in broad strokes, 3 things: 10 | /// 1. Type Aliasing during the first pass (annotate) 11 | /// 2. pushing and popping scopes (during `annotate` and `collect`) 12 | /// 3. 
module type and method type reconstruction 13 | use parsing::term::{AliasAssign, TensorTy, Term}; 14 | use std::collections::{BTreeMap, VecDeque}; 15 | use std::fmt::{Debug, Error, Formatter}; 16 | use typing::typed_term::TyFnAppArg; 17 | use typing::Type; 18 | use errors::Diag; 19 | use self::ModName::*; 20 | 21 | pub type TypeId = usize; 22 | 23 | #[derive(Clone, Hash, Eq, PartialEq, PartialOrd, Ord)] 24 | pub enum ModName { 25 | Global, 26 | Named(String), 27 | } 28 | 29 | impl ModName { 30 | pub fn as_str(&self) -> &str { 31 | match self { 32 | Global => unimplemented!(), 33 | Named(ref s) => s, 34 | } 35 | } 36 | } 37 | 38 | impl Debug for ModName { 39 | fn fmt(&self, f: &mut Formatter) -> Result<(), Error> { 40 | match self { 41 | Named(ref s) => write!(f, "MOD({})", s), 42 | Global => write!(f, "MOD(Global)"), 43 | } 44 | } 45 | } 46 | 47 | /// Represents a single level of scope 48 | #[derive(Debug)] 49 | pub struct Scope { 50 | /// type information of aliases 51 | types: BTreeMap, 52 | } 53 | 54 | impl Scope { 55 | pub fn new() -> Scope { 56 | Scope { 57 | types: BTreeMap::new(), 58 | } 59 | } 60 | } 61 | // Per-module state is a 3-tuple: .0 = stack of active scopes (innermost at the back), .1 = queue of popped scopes, replayed in pop order by `push_scope_collection`, .2 = stateful initializations. 62 | type ScopeStack = VecDeque; 63 | type ScopeQueue = VecDeque; 64 | type InitMap = BTreeMap>; 65 | /// Type environment: per-module scope state, the current module, the shared `Core` handle, and counters for fresh type/dimension variables. 66 | #[derive(Debug)] 67 | pub struct TypeEnv { 68 | core: Rc>, 69 | dim_counter: TypeId, 70 | var_counter: TypeId, 71 | current_mod: ModName, 72 | modules: BTreeMap, 73 | } 74 | 75 | #[derive(PartialEq, Eq, Hash, Clone, PartialOrd, Ord)] 76 | pub enum Alias { 77 | Variable(String), 78 | Function(String), 79 | } 80 | 81 | impl Debug for Alias { 82 | fn fmt(&self, f: &mut Formatter) -> Result<(), Error> { 83 | match self { 84 | Alias::Function(a) => write!(f, "F({})", a), 85 | Alias::Variable(a) => write!(f, "V({})", a), 86 | } 87 | } 88 | } 89 | 90 | impl Alias { 91 | pub fn as_str(&self) -> &str { 92 | match self { 93 | Alias::Function(s) => s, 94 | Alias::Variable(s) => s, 95 | } 96 | } 97 | } 98 | 99 | impl TypeEnv { 100 | pub fn new(core: Rc>) -> Self
{ 101 | let mut ret = Self { 102 | core, 103 | dim_counter: 0, 104 | var_counter: 0, 105 | current_mod: Global, 106 | modules: BTreeMap::new(), 107 | }; 108 | 109 | // import basic functions such as view 110 | ret.import_prelude().unwrap(); 111 | 112 | ret 113 | } 114 | 115 | /// create new dimension type variable 116 | pub fn fresh_dim(&mut self, span: ByteSpan) -> Type { 117 | self.dim_counter += 1; 118 | Type::DIM(self.dim_counter, span) 119 | } 120 | 121 | /// create new type variable 122 | pub fn fresh_var(&mut self, span: ByteSpan) -> Type { 123 | self.var_counter += 1; 124 | Type::VAR(self.var_counter, span) 125 | } 126 | 127 | /// push scope onto stack during tree traversal 128 | pub fn push_scope(&mut self, mod_name: &ModName) { 129 | let stack = self.modules.get_mut(mod_name).unwrap(); 130 | stack.0.push_back(Scope::new()); 131 | } 132 | 133 | /// during constraint collection, push the popped scopes back 134 | pub fn push_scope_collection(&mut self, mod_name: &ModName) { 135 | let stack = self.modules.get_mut(mod_name).unwrap(); 136 | let scp = stack.1.pop_front().unwrap(); 137 | stack.0.push_back(scp); 138 | } 139 | 140 | /// exit a block: move the innermost scope onto the replay queue 141 | pub fn pop_scope(&mut self, mod_name: &ModName) { 142 | let stack = self.modules.get_mut(mod_name).unwrap(); 143 | let popped = stack.0.pop_back().unwrap(); 144 | stack.1.push_back(popped); 145 | } 146 | /// stateful initialization recorded for `alias` in `mod_name` (panics if the module is unknown) 147 | pub fn resolve_init(&self, mod_name: &ModName, alias: &str) -> Option> { 148 | let stack = &self.modules[mod_name]; 149 | stack.2.get(alias).cloned() 150 | } 151 | 152 | /// resolve the type of an identifier within `mod_name` only; 153 | /// callers that want a fallback must query `Global` themselves 154 | /// (e.g. `.or_else(|| self.resolve_type(&Global, ..))`) 155 | pub fn resolve_type(&self, mod_name: &ModName, alias: &Alias) -> Option { 156 | self.resolve_type_inner(mod_name, alias) 157 | } 158 | 159 | /// within the module, return the innermost (most recently pushed) 160 | /// definition of the alias, so that inner-scope shadowing wins
161 | fn resolve_type_inner(&self, mod_name: &ModName, alias: &Alias) -> Option { 162 | let types = self.get_scoped_types(mod_name, alias); 163 | types.into_iter().next() 164 | } 165 | 166 | /// every definition of `alias` in `mod_name`, innermost scope first 167 | fn get_scoped_types(&self, mod_name: &ModName, alias: &Alias) -> Vec { 168 | let stack = self.modules.get(mod_name) 169 | .unwrap_or_else(|| panic!("BUG: Unable to find {:?} in {:?}.", alias, mod_name)); 170 | stack 171 | .0 172 | .iter() 173 | .rev() // back of the deque = innermost scope 174 | // skip scopes that do not define `alias` 175 | .filter_map(|sc| sc.types.get(alias)) 176 | // clone the borrowed types out of the scopes 177 | .cloned() 178 | .collect() 179 | } 180 | 181 | /// if current module does not exist, create and insert it, nop otherwise 182 | pub fn upsert_module(&mut self, mod_name: &ModName) { 183 | if !self.modules.contains_key(mod_name) { 184 | self.modules.insert(mod_name.clone(), { 185 | // if the module does not yet exist, add with an empty scope 186 | let mut q = VecDeque::new(); 187 | q.push_back(Scope::new()); 188 | (q, VecDeque::new(), BTreeMap::new()) 189 | }); 190 | } 191 | } 192 | 193 | /// add type alias in the innermost scope; errors with DuplicateVarInScope if it is already defined there 194 | pub fn add_type(&mut self, mod_name: &ModName, alias: &Alias, ty: Type) -> Result<(), Diag> { 195 | let stack = self.modules.entry(mod_name.clone()).or_insert_with(|| { 196 | // if the module does not yet exist, add with an empty scope 197 | let mut q = VecDeque::new(); 198 | q.push_back(Scope::new()); 199 | (q, VecDeque::new(), BTreeMap::new()) 200 | }); 201 | 202 | let top = stack.0.len() - 1; 203 | let scope = &mut stack.0[top]; 204 | if scope.types.contains_key(alias) { 205 | let orig_ty = &scope.types[alias]; 206 | return Err( 207 | Diag:: 208 | DuplicateVarInScope( 209 | alias.as_str().to_string(), 210 | orig_ty.clone(), 211 | ty, 212 | ) 213 | ) 214 | } 215 | let _ = scope.types.insert(alias.clone(), ty); 216 | 217 | Ok(()) 218 | } 219 | 220 | /// add type alias in the innermost scope, silently overwriting any existing entry there 221 | pub unsafe fn add_type_allow_replace(&mut self,
mod_name: &ModName, alias: &Alias, ty: Type) { 222 | let stack = self.modules.entry(mod_name.clone()).or_insert_with(|| { 223 | // if the module does not yet exist, add with an empty scope 224 | let mut q = VecDeque::new(); 225 | q.push_back(Scope::new()); 226 | (q, VecDeque::new(), BTreeMap::new()) 227 | }); 228 | 229 | let top = stack.0.len() - 1; 230 | let scope = &mut stack.0[top]; 231 | // if scope.types.contains_key(alias) { 232 | // panic!("duplicate item"); 233 | // } 234 | let _ = scope.types.insert(alias.clone(), ty); 235 | } 236 | 237 | /// replace `alias` in every scope of `mod_name` that already defines it (no-op if none does) 238 | pub unsafe fn replace_type(&mut self, mod_name: &ModName, alias: &Alias, ty: Type) { 239 | let stack = self.modules.entry(mod_name.clone()).or_insert_with(|| { 240 | // if the module does not yet exist, add with an empty scope 241 | let mut q = VecDeque::new(); 242 | q.push_back(Scope::new()); 243 | (q, VecDeque::new(), BTreeMap::new()) 244 | }); 245 | 246 | for scope in &mut stack.0 { 247 | if scope.types.contains_key(alias) { 248 | let _ = scope.types.insert(alias.clone(), ty.clone()); 249 | } 250 | } 251 | } 252 | 253 | /// record a stateful initialization for `alias`; stored per module (not per scope), panics on a duplicate or an unknown module 254 | pub fn add_init(&mut self, mod_name: &ModName, alias: &str, ty: Vec) { 255 | let stack = self.modules.get_mut(mod_name).unwrap(); 256 | 257 | if stack.2.contains_key(alias) { 258 | panic!("duplicate item"); 259 | } 260 | let _ = stack.2.insert(alias.to_owned(), ty); 261 | } 262 | 263 | /// tie an alias with a type variable dimension 264 | pub fn add_dim_alias(&mut self, mod_name: &ModName, alias: &Alias, span: ByteSpan) -> Result<(), Diag> { 265 | let tyvar = self.fresh_dim(span); 266 | self.add_type(mod_name, alias, tyvar) 267 | } 268 | 269 | /// tie an alias with a resolved dimension 270 | pub fn add_resolved_dim_alias( 271 | &mut self, 272 | mod_name: &ModName, 273 | alias: &Alias, 274 | num: i64, 275 | span: &ByteSpan, 276 | ) -> Result<(), Diag> { 277 | let tyvar = Type::ResolvedDim(num, *span); 278 |
self.add_type(mod_name, alias, tyvar) 279 | } 280 | 281 | /// tie an alias with a tensor 282 | pub fn add_tsr_alias( 283 | &mut self, 284 | mod_name: &ModName, 285 | alias: &Alias, 286 | tsr: &[String], 287 | span: &ByteSpan, 288 | ) -> Result<(), Diag> { 289 | // first insert all the dims 290 | for t in tsr.iter() { 291 | let alias = Alias::Variable(t.to_string()); 292 | if !self.exists(mod_name, &alias) { 293 | self.add_dim_alias(mod_name, &alias, *span)?; 294 | } 295 | } 296 | 297 | // then insert the tensor itself 298 | let tsr = self.create_tensor(mod_name, tsr, span); 299 | self.add_type(mod_name, alias, tsr) 300 | } 301 | 302 | /// make a tensor type from a signature: integer literals become resolved dims; other names are looked up in `mod_name`, then `Global`, else get a fresh dim; tensor-valued aliases are spliced in dim by dim 303 | pub fn create_tensor( 304 | &mut self, 305 | mod_name: &ModName, 306 | dims: &[String], 307 | span: &ByteSpan, 308 | ) -> Type { 309 | // dimension aliases that are not defined anywhere get a fresh dimension variable 310 | let dims_ty = dims.iter() 311 | .map(|t| { 312 | match t.parse::() { 313 | Ok(i) => vec![Type::ResolvedDim(i, *span)], 314 | Err(_e) => { 315 | let alias = Alias::Variable(t.to_string()); 316 | let ty = self.resolve_type(mod_name, &alias) 317 | .or_else(|| self.resolve_type(&Global, &alias)) 318 | .unwrap_or_else(|| self.fresh_dim(*span)) 319 | .clone(); 320 | if let Type::TSR(vs, _) = ty { 321 | vs 322 | } else { 323 | vec![ty] 324 | } 325 | } 326 | } 327 | }) 328 | .flatten() 329 | .collect(); 330 | // create the tensor type 331 | Type::TSR(dims_ty, *span) 332 | } 333 | 334 | /// build the type of an untyped tensor signature: generic dim lists go through `create_tensor`; a named tensor alias is looked up in `mod_name`, then `Global` (panics if missing) 335 | pub fn resolve_tensor(&mut self, mod_name: &ModName, t: &TensorTy) -> Type { 336 | match t { 337 | TensorTy::Generic(ref dims, ref sp) => { 338 | self.create_tensor(mod_name, &dims, sp) 339 | } 340 | TensorTy::Tensor(ref alias, ref sp) => { 341 | self.resolve_type(mod_name, &Alias::Variable(alias.to_string())) 342 | .or_else(|| self.resolve_type(&Global, &Alias::Variable(alias.to_string()))) 343 | .unwrap() 344 | .with_span(sp) 345 | } 346 | } 347 |
} 348 | 349 | /// check if `alias` is defined in any active scope of `mod_name` 350 | pub fn exists(&self, mod_name: &ModName, alias: &Alias) -> bool { 351 | let types = self.get_scoped_types(mod_name, alias); 352 | !types.is_empty() 353 | } 354 | 355 | /// create aliases for an untyped AST node assign 356 | pub fn import_node_assign(&mut self, mod_name: &ModName, a: &AliasAssign) -> Result<(), Diag> { 357 | match a { 358 | AliasAssign::Tensor { 359 | ident: ref id, 360 | rhs: TensorTy::Generic(ref tys, ref sp), 361 | .. 362 | } => { 363 | self.add_tsr_alias(mod_name, &Alias::Variable(id.to_string()), tys, sp) 364 | } 365 | AliasAssign::Dimension { 366 | ident: ref id, 367 | rhs: Term::Integer(num, _), 368 | ref span, 369 | } => { 370 | self.add_resolved_dim_alias(mod_name, &Alias::Variable(id.to_string()), *num, span) 371 | } 372 | _ => unimplemented!(), 373 | } 374 | } 375 | /// declare every non-numeric dimension name of a top-level generic tensor signature that is not yet defined, as a fresh dimension variable 376 | pub fn import_top_level_ty_sig(&mut self, mod_name: &ModName, ty_sig: &TensorTy) -> Result<(), Diag> { 377 | if let TensorTy::Generic(dims, span) = ty_sig { 378 | // first insert all the dims 379 | for t in dims.iter().filter(|t| t.parse::().is_err()) { 380 | let alias = Alias::Variable(t.to_string()); 381 | if !self.exists(mod_name, &alias) { 382 | self.add_dim_alias(mod_name, &alias, *span)?; 383 | } 384 | } 385 | } 386 | 387 | Ok(()) 388 | } 389 | 390 | /// get current module name 391 | pub fn module(&self) -> ModName { 392 | self.current_mod.clone() 393 | } 394 | 395 | /// set current module name 396 | pub fn set_module(&mut self, scp: ModName) { 397 | self.current_mod = scp; 398 | } 399 | 400 | /// import a module's methods from `Core` as `Function` aliases of `Named(mod_name)`; None if the core import yields nothing, otherwise one `add_type` result per method 401 | pub fn import_module(&mut self, path_name: &str, mod_name: &str) -> Option> { 402 | let core = self.core.clone(); 403 | let methods = core.borrow().import(path_name, mod_name, self)?; 404 | 405 | Some(methods.iter().map(|(name, ty)| { 406 | self.add_type( 407 | &Named(mod_name.to_owned()), 408 | &Alias::Function(name.to_string()), 409 | ty.clone(), 410 |
) 411 | }) 412 | .collect()) 413 | } 414 | /// register each prelude function (currently only `view`) as a global module alias and import its methods 415 | pub fn import_prelude(&mut self) -> Result<(), Diag> { 416 | for fun in &["view"] { 417 | self.add_type(&Global, 418 | &Alias::Variable(fun.to_string()), 419 | module!(fun.to_string()) 420 | )?; 421 | self.import_module("prelude", fun); // NOTE(review): the per-method add_type results are discarded here 422 | } 423 | Ok(()) 424 | } 425 | /// resolve an `UnresolvedModuleFun` through `Core`, pairing the resolved result with the op's `is_stateful()` flag; errors with SymbolNotFound for an unknown op, panics for any other type 426 | pub fn resolve_unresolved( 427 | &mut self, 428 | ty: &Type, 429 | fn_name: &str, 430 | arg_ty: Type, 431 | ret_ty: Type, 432 | args: Vec, 433 | inits: Option>, 434 | ) -> Result, Diag> { 435 | // let (mod_name, mod_ty) = { 436 | // if let Type::Module(name, opty, _) = module { 437 | // (name, opty.clone().map(|i| *i)) 438 | // } else { 439 | // panic!(); 440 | // } 441 | // }; 442 | 443 | if let Type::UnresolvedModuleFun(ref p0, ref p1, ref p2, ref span) = ty { 444 | assert_eq!(fn_name, *p2); 445 | let core_clone = self.core.clone(); 446 | let core = core_clone.borrow(); 447 | let find_result = core.find(p0, p1); 448 | match find_result { 449 | Some(op) => { 450 | let is_stateful = op.is_stateful(); 451 | op.resolve(self, fn_name, arg_ty, ret_ty, args, inits) 452 | .transpose() 453 | .map(|t| 454 | t.map(|i| (i, is_stateful)) 455 | ) 456 | } 457 | None => 458 | Err(Diag::SymbolNotFound(p1.to_string(), *span)), 459 | } 460 | } else { 461 | unimplemented!(); 462 | } 463 | } 464 | } 465 | -------------------------------------------------------------------------------- /trsc/src/typing/typed_term.rs: -------------------------------------------------------------------------------- 1 | // Data structures for Typed AST 2 | // 3 | use codespan::ByteSpan; 4 | use span::CSpan; 5 | use std::collections::BTreeMap; 6 | use typing::type_env::Alias; 7 | use typing::Type; 8 | use std::fmt::Write; 9 | 10 | /// typed AST nodes 11 | #[derive(Debug, PartialEq, Clone)] 12 | pub enum TyTerm { 13 | TyNone, 14 | TyProgram(Vec), 15 | TyInteger(Type, i64, ByteSpan), 16 | TyFloat(Type, f64, ByteSpan), 17 | TyList(Vec), 18 | TyIdent(Type,
Alias, ByteSpan), 19 | TyFieldAccess(TyFieldAccess), 20 | TyFnApp(Box), 21 | TyTuple(Type, Vec, ByteSpan), 22 | TyBlock { 23 | stmts: Box, 24 | ret: Box, 25 | span: ByteSpan, 26 | }, 27 | TyExpr(Box, Type, ByteSpan), 28 | TyStmt(Box, ByteSpan), 29 | } 30 | 31 | impl TyTerm { 32 | pub fn as_float(&self) -> Option { 33 | match self { 34 | TyTerm::TyFloat(_, i, _) => Some(*i), 35 | TyTerm::TyExpr(ref items, ..) => items.as_float(), 36 | _ => None, 37 | } 38 | } 39 | /// integer value of an integer literal, of an expression wrapping one, or of an identifier typed as a resolved dimension 40 | pub fn as_num(&self) -> Option { 41 | use self::TyTerm::*; 42 | match self { 43 | TyInteger(_, i, _) => Some(*i), 44 | TyExpr(ref items, ..) => items.as_num(), 45 | TyIdent(ref t, ..) => t.as_num(), 46 | _ => None, 47 | } 48 | } 49 | /// type of the term; program, list and statement nodes report `Unit` 50 | pub fn ty(&self) -> Type { 51 | use self::TyTerm::*; 52 | use self::Type::*; 53 | match self { 54 | TyNone => Unit(CSpan::fresh_span()), 55 | TyProgram(_) => Unit(CSpan::fresh_span()), 56 | TyInteger(ref t, _, _) => t.clone(), 57 | TyFloat(ref t, _, _) => t.clone(), 58 | TyList(_) => Unit(CSpan::fresh_span()), 59 | TyIdent(ref t, _, _) => t.clone(), 60 | TyFieldAccess(ref f_a) => f_a.ty(), 61 | TyFnApp(ref f_a) => f_a.ty(), 62 | TyBlock {ref ret, ..} => ret.ty(), 63 | TyExpr(_,ref ty, _) => ty.clone(), 64 | TyStmt(..) => Unit(CSpan::fresh_span()), 65 | TyTuple(ref t, ..) => t.clone(), 66 | } 67 | } 68 | pub fn span(&self) -> ByteSpan { 69 | use self::TyTerm::*; 70 | match self { 71 | TyNone | TyProgram(_) => CSpan::fresh_span(), 72 | TyTuple(_, _, ref s) => *s, 73 | TyInteger(_, _, ref s) => *s, 74 | TyFloat(_, _, ref s) => *s, 75 | TyIdent(_, _, ref s) => *s, 76 | TyFieldAccess(ref f_a) => f_a.span(), 77 | TyFnApp(ref f_a) => f_a.span(), 78 | TyBlock {ref span, ..} => *span, 79 | TyExpr(_, _, ref span) => *span, 80 | TyStmt(_, ref span) => *span, 81 | _ => panic!("{:?}", self), 82 | } 83 | } 84 | /// render the term as source text (None if a nested term cannot be rendered; panics on unsupported terms) 85 | pub fn as_str(&self) -> Option { 86 | use self::TyTerm::*; 87 | let mut s = String::new(); 88 | match self { 89 | TyInteger(..)
=> write!(s, "{}", self.as_num()?).unwrap(), 90 | TyExpr(ref items, ..) => write!(s, "{}", items.as_str()?).unwrap(), 91 | TyIdent(ref t, ..) => write!(s, "{}", t.as_string()).unwrap(), 92 | TyFloat(_, f, ..) => write!(s, "{}", f).unwrap(), 93 | TyTuple(_, ref ts, _) => { 94 | write!(s, "(").unwrap(); 95 | write!(s, "{}", ts 96 | .iter() 97 | .map(|t| t.as_str()) 98 | .collect::>>()?.join(", ") 99 | ).unwrap(); 100 | write!(s, ")").unwrap() 101 | } 102 | _ => panic!("{:?}", self), 103 | }; 104 | Some(s) 105 | } 106 | } 107 | 108 | impl TyFieldAccess { 109 | pub fn span(&self) -> ByteSpan { 110 | self.span 111 | } 112 | 113 | pub fn ty(&self) -> Type { 114 | self.ty.clone() 115 | } 116 | } 117 | /// span and result type of a function application (the result type is `ret_ty`) 118 | impl TyFnApp { 119 | pub fn span(&self) -> ByteSpan { 120 | self.span 121 | } 122 | pub fn ty(&self) -> Type { 123 | self.ret_ty.clone() 124 | } 125 | } 126 | 127 | #[derive(Debug, PartialEq, Clone)] 128 | pub enum TyDecl { 129 | TyNodeDecl(TyNodeDecl), 130 | TyWeightsDecl(TyWeightsDecl), 131 | TyGraphDecl(TyGraphDecl), 132 | TyUseStmt(TyUseStmt), 133 | TyAliasAssign(TyAliasAssign), 134 | } 135 | 136 | #[derive(Debug, PartialEq, Clone)] 137 | pub enum TyAliasAssign { 138 | Placeholder, 139 | // Dimension { 140 | // ident: String, 141 | // span: ByteSpan, 142 | // }, 143 | // Tensor { 144 | // ident: String, 145 | // span: ByteSpan, 146 | // }, 147 | } 148 | 149 | #[derive(Debug, PartialEq, Clone)] 150 | pub struct TyUseStmt { 151 | pub mod_name: String, 152 | pub imported_names: Vec, 153 | pub span: ByteSpan, 154 | } 155 | 156 | #[derive(Debug, PartialEq, Clone)] 157 | pub struct TyNodeDecl { 158 | pub name: String, 159 | pub ty_sig: Type, 160 | pub span: ByteSpan, 161 | } 162 | 163 | #[derive(Debug, PartialEq, Clone)] 164 | pub struct TyGraphDecl { 165 | pub name: String, 166 | pub ty_sig: Type, 167 | pub fns: Vec, 168 | pub span: ByteSpan, 169 | } 170 | 171 | #[derive(Debug, PartialEq, Clone)] 172 | pub struct TyWeightsDecl { 173 | pub name: String, 174 | pub ty_sig:
Type, 175 | pub inits: Vec, 176 | pub span: ByteSpan, 177 | } 178 | 179 | #[derive(Debug, PartialEq, Clone)] 180 | pub struct TyWeightsAssign { 181 | pub name: String, 182 | pub mod_name: String, 183 | pub fn_name: String, 184 | pub arg_ty: Type, 185 | pub fn_args: Vec, 186 | pub span: ByteSpan, 187 | } 188 | 189 | #[derive(Debug, PartialEq, Clone)] 190 | pub struct TyFnApp { 191 | pub mod_name: Option, 192 | pub orig_name: Option, 193 | pub name: Alias, 194 | pub arg_ty: Type, 195 | pub ret_ty: Type, 196 | pub args: Vec, 197 | pub span: ByteSpan, 198 | } 199 | 200 | impl TyFnApp { 201 | pub fn extend_arg(&mut self, arg: &TyFnAppArg) { 202 | self.args.insert(0, arg.clone()); 203 | let new_args_ty = self.args.to_ty(&self.span); 204 | // self.fn_ty = match &self.fn_ty { 205 | // Type::FUN(_, box r, span) => Type::FUN(box new_args_ty, box r.clone(), span), 206 | // _ => unimplemented!(), 207 | // }; 208 | self.arg_ty = new_args_ty; 209 | } 210 | } 211 | 212 | #[derive(Debug, PartialEq, Clone)] 213 | pub struct TyFnAppArg { 214 | pub name: Option, 215 | pub arg: Box, 216 | pub span: ByteSpan, 217 | } 218 | /// conversions of an argument/parameter list into an `FnArgs` type and into a name -> term map 219 | pub trait ArgsVecInto { 220 | fn to_ty(&self, span: &ByteSpan) -> Type; 221 | fn to_btreemap(&self) -> Option>>; 222 | } 223 | 224 | impl ArgsVecInto for [TyFnAppArg] { 225 | fn to_ty(&self, span: &ByteSpan) -> Type { 226 | Type::FnArgs( 227 | self.iter() 228 | .map(|t_arg| { 229 | Type::FnArg( 230 | t_arg.name.clone(), 231 | box t_arg.arg.ty(), 232 | t_arg.span, 233 | ) 234 | }) 235 | .collect(), 236 | *span, 237 | ) 238 | } 239 | fn to_btreemap(&self) -> Option>> { 240 | Some( 241 | self.iter() 242 | .filter_map(|a| { 243 | if a.name.is_some() { 244 | Some((a.name.clone().unwrap(), a.arg.clone())) 245 | } else { 246 | None 247 | } 248 | }) 249 | .collect(), 250 | ) 251 | } 252 | } 253 | 254 | impl ArgsVecInto for [TyFnDeclParam] { 255 | fn to_ty(&self, span: &ByteSpan) -> Type { 256 | Type::FnArgs( 257 | self.iter() 258 | .map(|t_arg| { 259 | Type::FnArg(
Some(t_arg.name.clone()), 261 | box t_arg.ty.clone(), 262 | t_arg.span, 263 | ) 264 | }) 265 | .collect(), 266 | *span, 267 | ) 268 | } 269 | fn to_btreemap(&self) -> Option>> { 270 | None 271 | // self.iter().filter_map(|a| 272 | // Some((a.name.clone(), box a.clone())) 273 | // ).collect() 274 | } 275 | } 276 | 277 | #[derive(Debug, PartialEq, Clone)] 278 | pub struct TyFnDecl { 279 | pub name: Alias, 280 | pub fn_params: Vec, 281 | pub arg_ty: Type, // args!() 282 | pub ret_ty: Type, // any type 283 | pub func_block: Box, 284 | pub span: ByteSpan, 285 | } 286 | 287 | #[derive(Debug, PartialEq, Clone)] 288 | pub struct TyFnDeclParam { 289 | pub name: String, 290 | pub ty: Type, 291 | pub span: ByteSpan, 292 | } 293 | 294 | #[derive(Debug, PartialEq, Clone)] 295 | pub struct TyFieldAccess { 296 | pub mod_name: String, 297 | pub field_name: String, 298 | pub ty: Type, 299 | pub span: ByteSpan, 300 | } 301 | -------------------------------------------------------------------------------- /trsc/src/typing/types.rs: -------------------------------------------------------------------------------- 1 | use codespan::ByteSpan; 2 | use std::fmt::{Debug, Error, Formatter}; 3 | // Types for typed AST 4 | use std::hash::{Hash, Hasher}; 5 | use typing::type_env::TypeId; 6 | use std::collections::BTreeMap; 7 | use typing::type_env::ModName; 8 | /// A type in the typed AST. Every variant carries a `ByteSpan`; the manual `PartialEq` and `Hash` impls ignore spans. NOTE(review): the derived `PartialOrd`/`Ord` do compare spans, so ordering is not consistent with `==`. 9 | #[derive(Clone, Eq, PartialOrd, Ord)] 10 | pub enum Type { 11 | // literals 12 | Unit(ByteSpan), 13 | INT(ByteSpan), 14 | FLOAT(ByteSpan), 15 | BOOL(ByteSpan), 16 | UnresolvedModuleFun(&'static str, &'static str, &'static str, ByteSpan), 17 | // type variables that need to be resolved 18 | VAR(TypeId, ByteSpan), 19 | DIM(TypeId, ByteSpan), 20 | Tuple(Vec, ByteSpan), 21 | 22 | // recursive types 23 | Module(String, Option>, ByteSpan), 24 | FnArgs(Vec, ByteSpan), 25 | FnArg(Option, Box, ByteSpan), 26 | ResolvedDim(i64, ByteSpan), 27 | FUN(String, String, Box, Box, ByteSpan), 28 | TSR(Vec, ByteSpan), 29 | } 30 | 31 |
impl PartialEq for Type { 32 | fn eq(&self, other: &Type) -> bool { 33 | use self::Type::*; 34 | match (self, other) { 35 | (Unit(_), Unit(_)) => true, 36 | (INT(_), INT(_)) => true, 37 | (FLOAT(_), FLOAT(_)) => true, 38 | (BOOL(_), BOOL(_)) => true, 39 | // // UnresolvedModuleFun(_,_,_) => false, 40 | (VAR(a, _), VAR(b, _)) => a == b, 41 | (DIM(b, _), DIM(a, _)) => a == b, 42 | (Module(a1, b1, _), Module(a2, b2, _)) => (a1 == a2) && (b1 == b2), 43 | (FnArgs(ta, _), FnArgs(tb, _)) => ta == tb, 44 | (Tuple(ta, _), Tuple(tb, _)) => ta == tb, 45 | (FnArg(n1, t1, _), FnArg(n2, t2, _)) => (n1 == n2) && (t1 == t2), 46 | (ResolvedDim(a, _), ResolvedDim(b, _)) => a == b, 47 | (FUN(m1, n1, p1, r1, _), FUN(m2, n2, p2, r2, _)) => 48 | (p1 == p2) && (r1 == r2) && (m1 == m2) && (n1 == n2), 49 | (TSR(ts1, _), TSR(ts2, _)) => ts1 == ts2, 50 | (UnresolvedModuleFun(a1, b1, c1, _), UnresolvedModuleFun(a2, b2, c2, _)) => 51 | (a1 == a2) && (b1 == b2) && (c1 == c2), 52 | (VAR(..), _) => false, 53 | (_, VAR(..)) => false, 54 | (ResolvedDim(..), DIM(..)) => false, 55 | (DIM(..), ResolvedDim(..)) => false, 56 | _ => { 57 | println!("Undefined comparison:"); 58 | println!("(1) {:?}", self); 59 | println!("(2) {:?}", other); 60 | false 61 | } 62 | } 63 | } 64 | } 65 | 66 | impl Hash for Type { 67 | fn hash(&self, state: &mut H) { 68 | use self::Type::*; 69 | match self { 70 | Unit(_) => ().hash(state), 71 | INT(_) => 0.hash(state), 72 | FLOAT(_) => 1.hash(state), 73 | BOOL(_) => 2.hash(state), 74 | // UnresolvedModuleFun(_,_,_) => false, 75 | VAR(a, _) => { 76 | 3.hash(state); 77 | a.hash(state) 78 | } 79 | DIM(b, _) => { 80 | 4.hash(state); 81 | b.hash(state) 82 | } 83 | 84 | Module(a, b, _) => { 85 | 5.hash(state); 86 | a.hash(state); 87 | b.hash(state); 88 | } 89 | FnArgs(ts, _) => { 90 | 6.hash(state); 91 | ts.hash(state) 92 | } 93 | FnArg(n, t, _) => { 94 | 7.hash(state); 95 | n.hash(state); 96 | t.hash(state); 97 | } 98 | ResolvedDim(a, _) => { 99 | 8.hash(state); 100 | a.hash(state) 
101 | } 102 | FUN(m,n,p, r, _) => { 103 | 9.hash(state); 104 | m.hash(state); 105 | n.hash(state); 106 | p.hash(state); 107 | r.hash(state); 108 | } 109 | TSR(ts, _) => { 110 | 10.hash(state); 111 | ts.hash(state); 112 | } 113 | UnresolvedModuleFun(a, b, c, _) => { 114 | 11.hash(state); 115 | a.hash(state); 116 | b.hash(state); 117 | c.hash(state); 118 | } 119 | // MismatchedDim(_,_) => true, 120 | _ => { 121 | panic!("{:?}", self); 122 | } 123 | } 124 | } 125 | } 126 | 127 | impl Type { 128 | 129 | pub fn span(&self) -> ByteSpan { 130 | use self::Type::*; 131 | match self { 132 | // literals 133 | Unit(s) => *s, 134 | INT(s) => *s, 135 | FLOAT(s) => *s, 136 | BOOL(s) => *s, 137 | UnresolvedModuleFun(_, _, _, s) => *s, 138 | // type variables that need to be resolved 139 | VAR(_, s) => *s, 140 | DIM(_, s) => *s, 141 | Tuple(_, s) => *s, 142 | 143 | // recursive types 144 | Module(_, _, s) => *s, 145 | FnArgs(_, s) => *s, 146 | FnArg(_, _, s) => *s, 147 | ResolvedDim(_, s) => *s, 148 | FUN(_, _, _, _, s) => *s, 149 | TSR(_, s) => *s, 150 | } 151 | } 152 | 153 | pub fn as_vec(&self) -> Option> { 154 | use self::Type::TSR; 155 | match self { 156 | TSR(ts, _) => Some(ts.to_owned()), 157 | _ => None, 158 | } 159 | } 160 | 161 | pub fn as_args_map(&self) -> Option> { 162 | use self::Type::{FnArg, FnArgs}; 163 | match self { 164 | FnArgs(vs, _) => { 165 | Some( 166 | vs.iter() 167 | .filter_map(|ty| 168 | if let FnArg(ref name,box ref ty, _) = ty { 169 | if name.is_some() { 170 | Some((name.clone().unwrap(), ty.clone())) 171 | } else { 172 | None 173 | } 174 | } else { None } 175 | ) 176 | .collect() 177 | ) 178 | } 179 | _ => None 180 | } 181 | } 182 | 183 | // pub fn last_dim(&self) -> Option { 184 | // match self { 185 | // Type::TSR(vs, _) => { 186 | // Some(vs[vs.len()-1].clone()) 187 | // } 188 | // _ => None 189 | // } 190 | // } 191 | 192 | /// returns the first argument type of a function argument 193 | pub fn first_arg_ty(&self) -> Option { 194 | match self { 
195 | Type::FnArgs(vs, _) => { 196 | if let Type::FnArg(_,box ref ty, _) = vs[0] { 197 | Some(ty.clone()) 198 | } else { None } 199 | } 200 | Type::FUN(_,_,arg,_,_) => arg.first_arg_ty(), 201 | _ => None 202 | } 203 | } 204 | 205 | /// modifies the span parameter in type to the most relevant 206 | pub fn with_span(&self, sp: &ByteSpan) -> Type { 207 | use self::Type::*; 208 | match self { 209 | Unit(_) => Unit(*sp), 210 | VAR(ref a, _) => VAR(*a, *sp), 211 | DIM(ref a, _) => DIM(*a, *sp), 212 | INT(_) => INT(*sp), 213 | FLOAT(_) => FLOAT(*sp), 214 | BOOL(_) => BOOL(*sp), 215 | UnresolvedModuleFun(ref a, ref b, ref c, _) => UnresolvedModuleFun(a, b, c, *sp), 216 | FnArgs(ref args, _) => FnArgs(args.clone(), *sp), 217 | FnArg(ref name, ref ty, _) => FnArg(name.clone(), ty.clone(), *sp), 218 | ResolvedDim(ref d, _) => ResolvedDim(*d, *sp), 219 | Module(ref s, ref ty, _) => Module(s.clone(), ty.clone(), *sp), 220 | FUN(ref m,ref n,ref p, ref r, _) => FUN(m.clone(),n.clone(),p.clone(), r.clone(), *sp), 221 | TSR(ref dims, _) => TSR(dims.clone(), *sp), 222 | Tuple(ref vs, _) => Tuple(vs.clone(), *sp), 223 | } 224 | } 225 | 226 | pub fn as_mod_name(&self) -> ModName { 227 | match self { 228 | Type::Module(s,..) 
=> ModName::Named(s.to_owned()), 229 | _ => unimplemented!(), 230 | } 231 | } 232 | 233 | pub fn as_string(&self) -> String { 234 | use self::Type::*; 235 | match self { 236 | Module(ref n, _, _) => n.to_owned(), 237 | TSR(tys, _) => tys.iter().map(|t| t.as_string()).collect::>().join(", "), 238 | DIM(_, _) => "-1".to_owned(), 239 | ResolvedDim(i, _) => format!("{}", i), 240 | _ => panic!("{:?}", self), 241 | } 242 | } 243 | 244 | pub fn as_num(&self) -> Option { 245 | use self::Type::*; 246 | match self { 247 | ResolvedDim(ref i, _) => Some(*i), 248 | _ => None, 249 | } 250 | } 251 | 252 | pub fn as_rank(&self) -> usize { 253 | use self::Type::*; 254 | match self { 255 | TSR(ref i, _) => i.len(), 256 | _ => unimplemented!(), 257 | } 258 | } 259 | 260 | pub fn is_resolved(&self) -> bool { 261 | use self::Type::*; 262 | match self { 263 | Unit(..) => true, 264 | INT(..) => true, 265 | FLOAT(..) => true, 266 | BOOL(..) => true, 267 | UnresolvedModuleFun(..) => false, 268 | 269 | VAR(..) => false, 270 | DIM(..) 
=> false, 271 | 272 | Module(_, Some(i), _) => i.is_resolved(), 273 | Module(_, None, _) => false, 274 | FnArgs(ts, _) => ts.iter().map(|t| t.is_resolved()).all(|t| t), 275 | FnArg(_, t, _) => t.is_resolved(), 276 | ResolvedDim(_, _) => true, 277 | FUN(_,_, p, r, _) => Type::is_resolved(p) && r.is_resolved(), 278 | TSR(_ts, _) => true, //ts.iter().map(|t| t.is_resolved()).all(|t|t), 279 | _ => unimplemented!(), 280 | } 281 | } 282 | } 283 | 284 | impl Debug for Type { 285 | fn fmt(&self, f: &mut Formatter) -> Result<(), Error> { 286 | use self::Type::*; 287 | match self { 288 | Unit(_) => write!(f, "()"), 289 | INT(_) => write!(f, "int"), 290 | FLOAT(_) => write!(f, "float"), 291 | BOOL(_) => write!(f, "bool"), 292 | UnresolvedModuleFun(ref a, ref b, ref c, _) => { 293 | write!(f, "UNRESOLVED({}::{}::{})", a, b, c) 294 | } 295 | Tuple(ref tys, _) => write!(f, "({:?})", tys), 296 | VAR(ref t_id, _) => write!(f, "'{:?}", t_id), 297 | DIM(ref t_id, _) => write!(f, "!{:?}", t_id), 298 | FnArgs(ref args, _) => write!(f, "FnArgs({:?})", args), 299 | FnArg(ref name, ref ty, _) => write!(f, "ARG({:?}={:?})", name, ty), 300 | ResolvedDim(ref d, _) => write!(f, "<{}>", d), 301 | Module(ref s, ref ty, _) => write!(f, "MODULE({}, {:?})", s, ty), 302 | FUN(ref module, ref name,ref p, ref r, _) => write!(f, "{}::{}({:?} -> {:?})", module,name,p, r), 303 | TSR(ref dims, _) => { 304 | if !dims.is_empty() { 305 | write!(f, "[")?; 306 | for i in dims[0..dims.len() - 1].iter() { 307 | write!(f, "{:?}, ", i)?; 308 | } 309 | write!(f, "{:?}]", dims[dims.len() - 1]) 310 | } else { 311 | write!(f, "[]") 312 | } 313 | } 314 | } 315 | } 316 | } 317 | 318 | macro_rules! args { 319 | ( $( $x:expr ),* ) => { 320 | { 321 | Type::FnArgs(vec![$($x),*], CSpan::fresh_span()) 322 | } 323 | }; 324 | } 325 | 326 | macro_rules! arg { 327 | ($name:expr, $ty:expr) => { 328 | Type::FnArg(Some($name.to_owned()), box $ty, CSpan::fresh_span()) 329 | }; 330 | } 331 | 332 | macro_rules! 
fun { 333 | ($m:expr, $n: expr, $e1:expr, $e2:expr) => { 334 | Type::FUN($m.to_owned(),$n.to_owned(), box $e1, box $e2, CSpan::fresh_span()) 335 | }; 336 | } 337 | 338 | macro_rules! float { 339 | () => { 340 | Type::FLOAT(CSpan::fresh_span()) 341 | }; 342 | } 343 | 344 | macro_rules! tsr { 345 | ($tsr:expr) => { 346 | Type::TSR($tsr, CSpan::fresh_span()) 347 | }; 348 | } 349 | 350 | macro_rules! unit { 351 | () => { 352 | Type::Unit(CSpan::fresh_span()) 353 | }; 354 | } 355 | 356 | macro_rules! int { 357 | () => { 358 | Type::INT(CSpan::fresh_span()) 359 | }; 360 | } 361 | 362 | macro_rules! tuple { 363 | (int 2) => { 364 | Type::Tuple(vec![int!(), int!()], CSpan::fresh_span()) 365 | }; 366 | } 367 | 368 | macro_rules! module { 369 | ($e1:expr) => { 370 | Type::Module($e1.to_owned(), None, CSpan::fresh_span()) 371 | }; 372 | } 373 | 374 | 375 | #[cfg(test)] 376 | mod tests { 377 | use super::*; 378 | use codespan::{Span, ByteIndex}; 379 | #[test] 380 | fn should_not_take_span_into_hash() { 381 | let h = hashset!( 382 | Type::VAR(1, Span::new(ByteIndex(1), ByteIndex(1))), 383 | Type::VAR(1, Span::new(ByteIndex(2), ByteIndex(2))), 384 | 385 | Type::VAR(2, Span::new(ByteIndex(1), ByteIndex(1))), 386 | Type::VAR(2, Span::new(ByteIndex(2), ByteIndex(2))), 387 | ); 388 | assert_eq!(h.len(), 2); 389 | } 390 | } 391 | -------------------------------------------------------------------------------- /trsc/src/typing/unifier.rs: -------------------------------------------------------------------------------- 1 | use typing::{Type, TypeEnv}; 2 | use typing::type_env::TypeId; 3 | use span::CSpan; 4 | use errors::{Emitter, Diag }; 5 | use std::rc::Rc; 6 | use std::cell::RefCell; 7 | use std::process::exit; 8 | use std::collections::BTreeMap; 9 | 10 | use typing::constraint::{Constraints, Equals}; 11 | 12 | pub struct Unifier { 13 | pub emitter: Rc>, 14 | pub tenv: Rc>, 15 | } 16 | 17 | impl Unifier { 18 | 19 | pub fn new(emitter: Rc>, tenv: Rc>) -> Unifier { 20 | Unifier { 21 | 
emitter, 22 | tenv, 23 | } 24 | } 25 | 26 | pub fn unify(&mut self, cs: Constraints) -> Substitution { 27 | if cs.is_empty() { 28 | Substitution::empty() 29 | } else { 30 | let emitter = cs.emitter.clone(); 31 | let tenv = cs.tenv.clone(); 32 | let mut it = cs.set.into_iter(); 33 | let mut subst = self.unify_one(it.next().unwrap()); 34 | let subst_tail = subst.apply(&Constraints {set: it.collect(), emitter, tenv}); 35 | let subst_tail: Substitution = self.unify(subst_tail); 36 | subst.compose(subst_tail) 37 | } 38 | } 39 | 40 | fn unify_one(&mut self, eq: Equals) -> Substitution { 41 | use self::Type::*; 42 | // println!("{:?}", eq); 43 | let emitter = Rc::clone(&self.emitter); 44 | let tenv = Rc::clone(&self.tenv); 45 | match eq { 46 | Equals(Unit(_), Unit(_)) => Substitution::empty(), 47 | Equals(INT(_), INT(_)) => Substitution::empty(), 48 | Equals(FLOAT(_), FLOAT(_)) => Substitution::empty(), 49 | Equals(BOOL(_), BOOL(_)) => Substitution::empty(), 50 | 51 | Equals(INT(_), ResolvedDim(_, _)) => Substitution::empty(), 52 | Equals(ResolvedDim(_, _), INT(_)) => Substitution::empty(), 53 | 54 | Equals(a @ ResolvedDim(_, _), b @ ResolvedDim(_, _)) => { 55 | if a.as_num() == b.as_num() { 56 | Substitution::empty() 57 | } else { 58 | self.emitter.borrow_mut().add(Diag::DimensionMismatch(a.clone(), b.clone())); 59 | Substitution::empty() 60 | } 61 | } 62 | 63 | Equals(VAR(tvar, _), ty) => self.unify_var(tvar, ty), 64 | Equals(ty, VAR(tvar, _)) => self.unify_var(tvar, ty), 65 | 66 | Equals(DIM(tvar, _), ty) => self.unify_var(tvar, ty), 67 | Equals(ty, DIM(tvar, _)) => self.unify_var(tvar, ty), 68 | 69 | Equals(FnArgs(v1, _), FnArgs(v2, _)) => self.unify( 70 | Constraints { 71 | set: v1.into_iter().zip(v2).map(|(i, j)| Equals(i, j)).collect(), 72 | emitter, 73 | tenv, 74 | }, 75 | ), 76 | 77 | Equals(FnArg(Some(a), ty1, _), FnArg(Some(b), ty2, _)) => { 78 | if a == b { 79 | self.unify( 80 | Constraints { 81 | set: btreeset!{ Equals(*ty1, *ty2)}, 82 | emitter, 83 | tenv, 
84 | }, 85 | ) 86 | } else { 87 | panic!("supplied parameter is incorrect! {} != {}", a, b); 88 | } 89 | } 90 | 91 | Equals(FUN(m1,n1,p1, r1, _), FUN(m2,n2,p2, r2, _)) => { 92 | if n1 == n2 { 93 | self.unify( 94 | Constraints{ 95 | set: btreeset!{ 96 | Equals(*p1, *p2), 97 | Equals(*r1, *r2), 98 | }, 99 | emitter, 100 | tenv, 101 | }, 102 | ) 103 | } else { 104 | println!("{} {} {} {}", m1, m2, n1, n2); 105 | panic!() 106 | } 107 | }, 108 | 109 | Equals(Tuple(vs1, _), Tuple(vs2, _)) => self.unify( 110 | Constraints { 111 | set: vs1.into_iter().zip(vs2).map(|(i,j)| Equals(i,j)).collect(), 112 | emitter, 113 | tenv, 114 | }, 115 | ), 116 | 117 | Equals(ts1 @ TSR(_, _), ts2 @ TSR(_, _)) => { 118 | if ts1.as_rank() == ts2.as_rank() { 119 | if let (TSR(dims1, s1), TSR(dims2, s2)) = (ts1.clone(), ts2.clone()) { 120 | let cons = Constraints { 121 | set: dims1 122 | .into_iter() 123 | .zip(dims2) 124 | .filter_map(|(i, j)| { 125 | if let (Type::ResolvedDim(a,_), Type::ResolvedDim(b,_)) = (i.clone(),j.clone()) { 126 | if a != b { self.emitter.borrow_mut().add(Diag::TypeError(ts1.clone(),ts2.clone())) } 127 | None 128 | } else { 129 | Some(Equals(i.with_span(&s1), j.with_span(&s2))) 130 | } 131 | }) 132 | .collect(), 133 | emitter, 134 | tenv, 135 | }; 136 | self.unify(cons) 137 | } else { 138 | unimplemented!(); 139 | } 140 | } else { 141 | self.emitter.borrow_mut().add(Diag::RankMismatch(ts1, ts2)); 142 | Substitution::empty() 143 | } 144 | } 145 | 146 | Equals(Module(n1, Some(box ty1), _), Module(n2, Some(box ty2), _)) => self.unify( 147 | Constraints { 148 | set: btreeset!{ 149 | if n1 == n2 { 150 | Equals(ty1, ty2) 151 | } else { 152 | panic!(); 153 | } 154 | }, 155 | emitter, 156 | tenv, 157 | }, 158 | ), 159 | 160 | Equals(u @ UnresolvedModuleFun(_, _, _, _), ty) => { 161 | Substitution::empty() 162 | } 163 | 164 | _ => { 165 | let Equals(a, b) = eq; 166 | let mut em = self.emitter.borrow_mut(); 167 | em.add(Diag::TypeError(a, b)); 168 | em.print_errs(); 169 | 
exit(-1); 170 | } 171 | } 172 | } 173 | 174 | fn unify_var(&mut self, tvar: TypeId, ty: Type) -> Substitution { 175 | use self::Type::*; 176 | 177 | let span = CSpan::fresh_span(); 178 | match ty.clone() { 179 | VAR(tvar2, _) => { 180 | if tvar == tvar2 { 181 | Substitution::empty() 182 | } else { 183 | Substitution(btreemap!{ VAR(tvar, span) => ty }) 184 | } 185 | } 186 | DIM(tvar2, _) => { 187 | if tvar == tvar2 { 188 | Substitution::empty() 189 | } else { 190 | Substitution(btreemap!{ VAR(tvar, span) => ty }) 191 | } 192 | } 193 | _ => if occurs(tvar, &ty) { 194 | panic!("circular type") 195 | } else { 196 | Substitution(btreemap!{ VAR(tvar, span) => ty }) 197 | }, 198 | } 199 | } 200 | } 201 | 202 | fn occurs(tvar: TypeId, ty: &Type) -> bool { 203 | use self::Type::*; 204 | match ty { 205 | FUN(_,_, ref p, ref r, _) => occurs(tvar, &p) | occurs(tvar, &r), 206 | VAR(ref tvar2, _) => tvar == *tvar2, 207 | _ => false, 208 | } 209 | } 210 | 211 | #[derive(Debug, PartialEq)] 212 | pub struct Substitution(pub BTreeMap); 213 | 214 | impl Substitution { 215 | 216 | /// apply substitution to a set of constraints 217 | pub fn apply(&mut self, cs: &Constraints) -> Constraints { 218 | Constraints { 219 | set: cs.set 220 | .iter() 221 | .map(|Equals(a, b)| Equals(self.apply_ty(a), self.apply_ty(b))) 222 | .collect(), 223 | tenv: cs.tenv.clone(), 224 | emitter: cs.emitter.clone(), 225 | } 226 | } 227 | 228 | pub fn apply_ty(&mut self, ty: &Type) -> Type { 229 | self.0.iter().fold(ty.clone(), |result, solution| { 230 | let (ty, solution_type) = solution; 231 | if let Type::VAR(ref tvar, ref span) = ty { 232 | substitute_tvar(result, tvar, &solution_type.with_span(span)) 233 | } else { 234 | panic!("Impossible!"); 235 | } 236 | }) 237 | } 238 | 239 | pub fn compose(&mut self, mut other: Substitution) -> Substitution { 240 | let mut self_substituded: BTreeMap = self.0 241 | .clone() 242 | .into_iter() 243 | .map(|(k, s)| (k, other.apply_ty(&s))) 244 | .collect(); 245 | 
self_substituded.extend(other.0); 246 | Substitution(self_substituded) 247 | } 248 | 249 | pub fn empty() -> Substitution { 250 | Substitution(BTreeMap::new()) 251 | } 252 | } 253 | 254 | /// replace tvar with replacement in ty 255 | fn substitute_tvar(ty: Type, tvar: &TypeId, replacement: &Type) -> Type { 256 | use self::Type::*; 257 | // println!("\nTVAR:::\n{:?}, \n'{:?}, \n{:?}\n", ty, tvar, replacement); 258 | match ty { 259 | UnresolvedModuleFun(_, _, _, _) => { 260 | println!("{:?}, replacement: {:?}", ty, replacement); 261 | ty 262 | }, 263 | Unit(_) => ty, 264 | INT(_) => ty, 265 | BOOL(_) => ty, 266 | FLOAT(_) => ty, 267 | ResolvedDim(_, _) => ty, 268 | VAR(tvar2, span) => { 269 | if *tvar == tvar2 { 270 | replacement.with_span(&span) 271 | } else { 272 | ty 273 | } 274 | } 275 | DIM(tvar2, span) => { 276 | if *tvar == tvar2 { 277 | replacement.with_span(&span) 278 | } else { 279 | ty 280 | } 281 | } 282 | FnArgs(args, span) => FnArgs( 283 | args.into_iter() 284 | .map(|ty| match ty { 285 | FnArg(name, a, s) => FnArg(name, box substitute_tvar(*a, tvar, replacement), s), 286 | _ => panic!(ty), 287 | }) 288 | .collect(), 289 | span, 290 | ), 291 | Tuple(tys, s) => Tuple(tys.into_iter().map(|t| substitute_tvar(t, tvar, replacement)).collect(), s), 292 | FUN(module,name,p, r, s) => FUN( 293 | module, 294 | name, 295 | box substitute_tvar(*p, tvar, &replacement), 296 | box substitute_tvar(*r, tvar, &replacement), 297 | s, 298 | ), 299 | TSR(_, _) => ty, 300 | 301 | Module(n, Some(box ty), s) => { 302 | Module(n, Some(box substitute_tvar(ty, tvar, replacement)), s) 303 | } 304 | 305 | Module(_, None, _) => ty, 306 | FnArg(name, box ty, s) => FnArg(name, box substitute_tvar(ty, tvar, replacement), s), 307 | } 308 | } -------------------------------------------------------------------------------- /trsc/tests/input/gan.trs: -------------------------------------------------------------------------------- 1 | use lin::Linear; 2 | use reg::{BatchNorm1d}; 3 | use 
nonlin::{leaky_relu, tanh, sigmoid}; 4 | 5 | dim noise_dim = 100; 6 | dim image_dim = 28; 7 | dim flattened_image_dim = 784; 8 | tsr noise = [?, noise_dim]; 9 | tsr flattened_image = [?, flattened_image_dim]; 10 | tsr image = [?, 1, image_dim, image_dim]; 11 | 12 | node Generator image> {} 13 | weights Generator image> { 14 | lin1 = Linear::new(in=noise_dim, out=128); 15 | lin2 = Linear::new(in=128, out=256); 16 | bn1 = BatchNorm1d::new(num_features=256); 17 | lin3 = Linear::new(in=256, out=512); 18 | bn2 = BatchNorm1d::new(num_features=512); 19 | lin4 = Linear::new(in=512, out=1024); 20 | bn3 = BatchNorm1d::new(num_features=1024); 21 | lin5 = Linear::new(in=1024, out=flattened_image_dim); 22 | } 23 | graph Generator image> { 24 | def new() -> Self { 25 | self 26 | } 27 | def forward { 28 | x 29 | |> lin1 |> leaky_relu(p=0.2) 30 | |> lin2 |> bn1 |> leaky_relu(p=0.2) 31 | |> lin3 |> bn2 |> leaky_relu(p=0.2) 32 | |> lin4 |> bn3 |> leaky_relu(p=0.2) 33 | |> lin5 |> tanh 34 | |> view(_, 1, image_dim, image_dim) 35 | } 36 | } 37 | 38 | node Discriminator [?, 1]> {} 39 | weights Discriminator [?,1]> { 40 | lin1 = Linear::new(in=flattened_image_dim, out=512); 41 | lin2 = Linear::new(in=512, out=256); 42 | lin3 = Linear::new(in=256, out=1); 43 | } 44 | graph Discriminator [?,1]> { 45 | def new() -> Self { 46 | self 47 | } 48 | def forward { 49 | x |> view(?, flattened_image_dim) 50 | |> lin1 |> leaky_relu(p=0.2) 51 | |> lin2 |> leaky_relu(p=0.2) 52 | |> lin3 |> sigmoid 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /trsc/tests/input/mnist.trs: -------------------------------------------------------------------------------- 1 | use conv::{Conv2d, maxpool2d}; 2 | use reg::Dropout2d; 3 | use nonlin::{relu, log_softmax}; 4 | use lin::Linear; 5 | 6 | node Mnist<[?, IMAGE] -> LABELS> { 7 | // this is where you declare type level constants 8 | dim FC1 = 320; 9 | dim FC2 = 50; 10 | 11 | // Prediction 12 | dim OUT = 10; 13 | //
Channel 14 | dim C = 1; 15 | 16 | dim W = 28; // Image Width 17 | dim H = 28; // Image Height 18 | tsr IMAGE = [C,H,W]; // Tensor alias 19 | 20 | tsr LABELS = [?,OUT]; 21 | } 22 | 23 | weights Mnist<[?, IMAGE] -> LABELS> { 24 | conv1 = Conv2d::new(in_ch=1, out_ch=10, kernel_size=(5,5)); 25 | conv2 = Conv2d::new(in_ch=10, out_ch=20, kernel_size=5); 26 | dropout = Dropout2d::new(p=0.5); 27 | fc1 = Linear::<[?,FC1] -> [?,FC2]>::new(in=FC1, out=FC2); 28 | fc2 = Linear::<[?,FC2] -> [?,OUT]>::new(in=FC2, out=OUT); 29 | } 30 | 31 | graph Mnist<[?, IMAGE] -> LABELS> { 32 | 33 | def new() -> Self { 34 | fc1.init_normal(std=1.); 35 | fc2.init_normal(std=1.); 36 | self 37 | } 38 | 39 | def forward { 40 | x 41 | |> conv1 |> maxpool2d(kernel_size=2) |> relu 42 | |> conv2 |> dropout |> maxpool2d(kernel_size=2) |> relu 43 | |> view(_, FC1) 44 | |> fc1 |> relu 45 | |> self.example() 46 | |> log_softmax(dim=1) 47 | } 48 | 49 | def example(x: [?,FC2]) -> LABELS { 50 | x |> fc2 |> relu 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /trsc/tests/input/xor.trs: -------------------------------------------------------------------------------- 1 | use lin::Linear; 2 | use nonlin::{sigmoid, relu}; 3 | 4 | node Xor<[?,2] -> [?,1]> { 5 | } 6 | 7 | weights Xor<[?,2] -> [?,1]> { 8 | fc1 = Linear::new(in=2, out=3); 9 | fc2 = Linear::<[?,3]->[?,1]>::new(in=3, out=1); 10 | } 11 | 12 | graph Xor<[?,2] -> [?,1]> { 13 | def new() -> Self { 14 | self 15 | } 16 | 17 | def forward { 18 | x |> fc1 |> sigmoid 19 | |> fc2 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /trsc/tests/integration_test.rs: -------------------------------------------------------------------------------- 1 | extern crate assert_cli; 2 | // Each test runs the crate's main binary: with no arguments it must fail; otherwise it is given `--in` and a tests/input/*.trs file, and its stdout must match the corresponding tests/output/*.py snapshot. 3 | #[test] 4 | fn test_no_input() { 5 | assert_cli::Assert::main_binary() 6 | .fails() 7 | .unwrap(); 8 | } 9 | // Smallest end-to-end case: two Linear layers with a sigmoid in between. 10 | #[test] 11 | fn test_xor() { 12 | assert_cli::Assert::main_binary() 13
| .with_args(&["--in", "tests/input/xor.trs"]) 14 | .succeeds() 15 | .and() 16 | .stdout().is(include_str!("output/xor.py")) 17 | .unwrap(); 18 | } 19 | // NOTE(review): output/mnist.py pins generated code that uses undefined names (`kernel_size`, `dim`, bare `Dropout2d`) and calls `nn.init.normal_(std=1)` with no tensor; refresh this snapshot once codegen is fixed. 20 | #[test] 21 | fn test_mnist() { 22 | assert_cli::Assert::main_binary() 23 | .with_args(&["--in", "tests/input/mnist.trs"]) 24 | .succeeds() 25 | .and() 26 | .stdout().is(include_str!("output/mnist.py")) 27 | .unwrap(); 28 | } 29 | // NOTE(review): output/gan.py pins `F.leaky_relu(x)` although gan.trs passes `p=0.2`; confirm the slope is meant to be dropped. 30 | #[test] 31 | fn test_gan() { 32 | assert_cli::Assert::main_binary() 33 | .with_args(&["--in", "tests/input/gan.trs"]) 34 | .succeeds() 35 | .and() 36 | .stdout().is(include_str!("output/gan.py")) 37 | .unwrap(); 38 | } -------------------------------------------------------------------------------- /trsc/tests/output/gan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | 8 | class Discriminator(nn.Module): 9 | '''Discriminator::forward([!1, <1>, <28>, <28>] -> [!4, <1>])''' 10 | def __init__(self): 11 | super(Discriminator, self).__init__() 12 | self.lin1 = nn.Linear(in_features=784, out_features=512) 13 | self.lin2 = nn.Linear(in_features=512, out_features=256) 14 | self.lin3 = nn.Linear(in_features=256, out_features=1) 15 | def forward(self, x): 16 | x = x.view(-1, 784) 17 | x = self.lin1(x) 18 | x = F.leaky_relu(x) 19 | x = self.lin2(x) 20 | x = F.leaky_relu(x) 21 | x = self.lin3(x) 22 | return F.sigmoid(x) 23 | 24 | 25 | class Generator(nn.Module): 26 | '''Generator::forward([!1, <100>] -> [!1, <1>, <28>, <28>])''' 27 | def __init__(self): 28 | super(Generator, self).__init__() 29 | self.lin1 = nn.Linear(in_features=100, out_features=128) 30 | self.lin2 = nn.Linear(in_features=128, out_features=256) 31 | self.bn1 = nn.BatchNorm1d(num_features=256) 32 | self.lin3 = nn.Linear(in_features=256, out_features=512) 33 | self.bn2 = nn.BatchNorm1d(num_features=512) 34 | self.lin4 =
nn.Linear(in_features=512, out_features=1024) 35 | self.bn3 = nn.BatchNorm1d(num_features=1024) 36 | self.lin5 = nn.Linear(in_features=1024, out_features=784) 37 | def forward(self, x): 38 | x = self.lin1(x) 39 | x = F.leaky_relu(x) 40 | x = self.lin2(x) 41 | x = self.bn1(x) 42 | x = F.leaky_relu(x) 43 | x = self.lin3(x) 44 | x = self.bn2(x) 45 | x = F.leaky_relu(x) 46 | x = self.lin4(x) 47 | x = self.bn3(x) 48 | x = F.leaky_relu(x) 49 | x = self.lin5(x) 50 | x = F.tanh(x) 51 | return x.view(-1, 1, 28, 28) 52 | 53 | 54 | -------------------------------------------------------------------------------- /trsc/tests/output/mnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | 8 | class Mnist(nn.Module): 9 | '''Mnist::forward([!1, <1>, <28>, <28>] -> [!1, <10>])''' 10 | def __init__(self): 11 | super(Mnist, self).__init__() 12 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=(5, 5)) 13 | self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5) 14 | self.dropout = Dropout2d(p=0.5) 15 | self.fc1 = nn.Linear(in_features=320, out_features=50) 16 | self.fc2 = nn.Linear(in_features=50, out_features=10) 17 | nn.init.normal_(std=1) 18 | nn.init.normal_(std=1) 19 | def forward(self, x): 20 | x = self.conv1(x) 21 | x = F.max_pool2d(x, kernel_size) 22 | x = F.relu(x) 23 | x = self.conv2(x) 24 | x = self.dropout(x) 25 | x = F.max_pool2d(x, kernel_size) 26 | x = F.relu(x) 27 | x = x.view(-1, 320) 28 | x = self.fc1(x) 29 | x = F.relu(x) 30 | x = self.example(x) 31 | return F.log_softmax(x, dim) 32 | def example(self, x): 33 | x = self.fc2(x) 34 | return F.relu(x) 35 | 36 | 37 | -------------------------------------------------------------------------------- /trsc/tests/output/xor.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | 8 | class Xor(nn.Module): 9 | '''Xor::forward([!1, <2>] -> [!1, <1>])''' 10 | def __init__(self): 11 | super(Xor, self).__init__() 12 | self.fc1 = nn.Linear(in_features=2, out_features=3) 13 | self.fc2 = nn.Linear(in_features=3, out_features=1) 14 | def forward(self, x): 15 | x = self.fc1(x) 16 | x = F.sigmoid(x) 17 | return self.fc2(x) 18 | -------------------------------------------------------------------------------- /trsc_core_derive/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "trsc_core_derive" 3 | version = "0.1.0" 4 | authors = ["ricky han "] 5 | 6 | [lib] 7 | proc-macro = true 8 | name = "trsc_core_derive" 9 | 10 | [dependencies] 11 | syn = "0.11.11" 12 | quote = "0.3.15" 13 | 14 | -------------------------------------------------------------------------------- /trsc_core_derive/src/attrs.rs: -------------------------------------------------------------------------------- 1 | use syn; 2 | use std::collections::BTreeMap; 3 | 4 | pub fn get_is_stateful(attrs: &[syn::Attribute]) -> bool { 5 | for attr in attrs.iter() { 6 | if let syn::MetaItem::Word(id) = &attr.value { 7 | if "stateful" == id.as_ref() { 8 | return true; 9 | } 10 | } 11 | } 12 | return false; 13 | } 14 | 15 | pub fn get_fns(attrs: &[syn::Attribute]) -> BTreeMap<&str, String> { 16 | let mut map = BTreeMap::new(); 17 | for attr in attrs.iter() { 18 | if let syn::MetaItem::NameValue(key, val) = &attr.value { 19 | let key = key.as_ref(); 20 | if "path" == key { continue; } 21 | let val = if let syn::Lit::Str(s,..) 
= val { s } 22 | else { continue }; 23 | 24 | map.insert(key, val.clone()); 25 | } 26 | } 27 | map 28 | } 29 | 30 | pub fn get_op_name(attrs: &[syn::Attribute]) -> Option<&str> { 31 | get_str_attr("name", attrs) 32 | } 33 | 34 | pub fn get_path(attrs: &[syn::Attribute]) -> Option<&str> { 35 | get_str_attr("path", attrs) 36 | } 37 | 38 | pub fn get_str_attr<'a>(keyname: &'static str, attrs: &'a [syn::Attribute]) -> Option<&'a str> { 39 | for attr in attrs.iter() { 40 | if let syn::MetaItem::NameValue(key, val) = &attr.value { 41 | let key = key.as_ref(); 42 | if keyname == key { 43 | if let syn::Lit::Str(s,..) = val { return Some(s) } 44 | else { return None }; 45 | } 46 | } 47 | } 48 | None 49 | } 50 | -------------------------------------------------------------------------------- /trsc_core_derive/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![recursion_limit="128"] 2 | extern crate proc_macro; 3 | extern crate syn; 4 | #[macro_use] 5 | extern crate quote; 6 | 7 | mod attrs; 8 | mod parser; 9 | 10 | use attrs::*; 11 | use std::collections::BTreeMap; 12 | use proc_macro::TokenStream; 13 | use parser::{parse_decl, FnDecl}; 14 | 15 | #[proc_macro_derive(Op, attributes(stateful, path, init_normal, new, forward))] 16 | pub fn derive(input: TokenStream) -> TokenStream { 17 | // Construct a string representation of the type definition 18 | let s = input.to_string(); 19 | 20 | // Parse the string representation 21 | let ast = syn::parse_derive_input(&s).unwrap(); 22 | 23 | // Build the impl 24 | let gen = impl_op(&ast); 25 | 26 | // Return the generated impl 27 | gen.parse().unwrap() 28 | } 29 | 30 | fn impl_op(ast: &syn::DeriveInput) -> quote::Tokens { 31 | if let syn::Body::Enum(_) = ast.body { 32 | panic!("Cannot derive Op for `enum`"); 33 | } 34 | 35 | let name = &ast.ident; 36 | 37 | let stateful = get_is_stateful(&ast.attrs); 38 | let op_name = name.to_string(); 39 | let fns = get_fns(&ast.attrs); 40 | let path = 
get_path(&ast.attrs).unwrap_or_else(|| panic!("no path supplied")); 41 | let fn_decls = get_fn_decls(path, &fns); 42 | let ty_sigs = gen_ty_sigs(&fn_decls); 43 | 44 | quote! { 45 | impl Op for #name { 46 | fn get_name(&self) -> &'static str { 47 | #op_name 48 | } 49 | fn ty_sigs(&self, _tenv: &mut TypeEnv) -> Vec<(MethodName, Type)> { 50 | #ty_sigs 51 | } 52 | fn is_stateful(&self) -> bool { 53 | #stateful 54 | } 55 | } 56 | } 57 | } 58 | 59 | 60 | fn get_fn_decls(path: &str, ty_sigs: &BTreeMap<&str, String>) -> Vec { 61 | ty_sigs 62 | .iter() 63 | .map(|(k,v)| { 64 | parse_decl(path, k, v) 65 | }) 66 | .collect() 67 | } 68 | 69 | fn gen_ty_sigs(decls: &[FnDecl]) -> quote::Tokens { 70 | let ty_sigs: Vec = decls.iter().map(|i|gen_decl(i)).collect(); 71 | quote! { 72 | vec![ 73 | #(#ty_sigs),* 74 | ] 75 | } 76 | } 77 | 78 | fn gen_decl(fn_decl: &FnDecl) -> quote::Tokens { 79 | let name = &fn_decl.name; 80 | let path = &fn_decl.path; 81 | 82 | if fn_decl.resolved { 83 | let params = &fn_decl.params; 84 | let tys = &fn_decl.tys; 85 | let ret = &fn_decl.ret; 86 | quote! { 87 | ( 88 | #name, 89 | fun!( 90 | self.get_name(), 91 | #name, 92 | args!( 93 | #( 94 | arg!( 95 | #params, 96 | #tys 97 | ) 98 | ),* 99 | ), 100 | #ret 101 | ), 102 | ) 103 | } 104 | } else { 105 | quote! { 106 | ( 107 | #name, 108 | Type::UnresolvedModuleFun(#path, self.get_name(), #name, CSpan::fresh_span()), 109 | ) 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /trsc_core_derive/src/parser.rs: -------------------------------------------------------------------------------- 1 | use quote::{ToTokens, Tokens}; 2 | 3 | #[derive(Debug, Clone)] 4 | pub enum Type { 5 | Float, 6 | Int, 7 | Tsr, 8 | SelfTy, 9 | Unit, 10 | //... 
11 | } 12 | 13 | impl Type { 14 | fn from_str(s: &str) -> Self { 15 | use self::Type::*; 16 | match s { 17 | "float" => Float, 18 | "int" => Int, 19 | "self" => SelfTy, 20 | "unit" => Unit, 21 | "tsr0" => Tsr, 22 | _ => panic!("Unknown type"), 23 | } 24 | } 25 | } 26 | 27 | impl ToTokens for Type { 28 | fn to_tokens(&self, tokens: &mut Tokens) { 29 | use self::Type::*; 30 | match self { 31 | Float => tokens.append(quote!{float!()}), 32 | Int => tokens.append(quote!{int!()}), 33 | Unit => tokens.append(quote!{unit!()}), 34 | SelfTy => tokens.append(quote!{module!(self.get_name())}), 35 | _ => unimplemented!(), 36 | } 37 | } 38 | } 39 | 40 | #[derive(Clone, Debug)] 41 | pub struct FnDecl { 42 | pub resolved: bool, 43 | pub params: Vec, 44 | pub tys: Vec, 45 | pub ret: Type, 46 | pub name: String, 47 | pub path: String, 48 | } 49 | 50 | #[derive(Clone, Debug, Eq, PartialEq)] 51 | pub enum Token { 52 | LPAREN, 53 | RPAREN, 54 | WORD(String), 55 | SEMI, 56 | ARROW, 57 | QMARK, 58 | COMMA, 59 | } 60 | 61 | pub fn parse_decl(path: &str, name: &str, decl: &str) -> FnDecl { 62 | let tokens = lex(decl); 63 | parse(path, name, &tokens) 64 | } 65 | 66 | macro_rules! 
eat { 67 | ($it:ident, $ty:ident, $msg:expr) => { 68 | if let Some($ty) = $it.next() {} else { 69 | panic!($msg) 70 | } 71 | } 72 | } 73 | 74 | fn parse(path: &str, name: &str, toks: &[Token]) -> FnDecl { 75 | use self::Token::*; 76 | 77 | let mut it = toks.iter().peekable(); 78 | let mut ret = FnDecl { 79 | name: name.to_owned(), 80 | path: path.to_owned(), 81 | resolved: true, 82 | params: vec![], 83 | tys: vec![], 84 | ret: self::Type::Float, 85 | }; 86 | 87 | if let Some(QMARK) = it.peek() { 88 | ret.resolved = false; 89 | it.next(); 90 | } else { 91 | } 92 | 93 | eat!(it, LPAREN, "lparen not found"); 94 | 95 | while let Some(tok) = it.peek().cloned() { 96 | if *tok == RPAREN { 97 | it.next(); 98 | } else if let WORD(ref name) = *tok { 99 | // param name 100 | ret.params.push(name.clone()); 101 | it.next(); 102 | eat!(it, SEMI, "semi"); 103 | // param ty 104 | if let Some(WORD(ref tyword)) = it.next() { 105 | let ty: Type = Type::from_str(tyword.as_str()); 106 | ret.tys.push(ty); 107 | } else { 108 | panic!("No param type specified"); 109 | } 110 | } else if ARROW == *tok { 111 | // return type 112 | it.next(); 113 | if let Some(WORD(ref tyword)) = it.next() { 114 | let ty: Type = Type::from_str(tyword.as_str()); 115 | ret.ret = ty; 116 | return ret; 117 | } else { 118 | panic!("No ret ty specified"); 119 | } 120 | } else { 121 | it.next(); 122 | } 123 | } 124 | 125 | ret 126 | } 127 | 128 | 129 | fn lex(decl: &str) -> Vec { 130 | use self::Token::*; 131 | 132 | let mut it = decl.chars().peekable(); 133 | let mut toks = vec![]; 134 | while let Some(c) = it.peek().cloned() { 135 | match c { 136 | '(' => { 137 | toks.push(LPAREN); 138 | it.next(); 139 | } 140 | ')' => { 141 | toks.push(RPAREN); 142 | it.next(); 143 | } 144 | '?' 
=> { 145 | toks.push(QMARK); 146 | it.next(); 147 | } 148 | ' ' | '\n' => { 149 | it.next(); 150 | } 151 | 'A'...'z' => { 152 | let mut buf = String::new(); 153 | while let Some(ch) = it.peek().cloned() { 154 | if ch.is_alphanumeric() || ch == '_' { 155 | buf.push(ch); 156 | it.next(); 157 | } else { 158 | break; 159 | } 160 | } 161 | toks.push(WORD(buf)); 162 | } 163 | ':' => { 164 | toks.push(SEMI); 165 | it.next(); 166 | } 167 | ',' => { 168 | toks.push(COMMA); 169 | it.next(); 170 | } 171 | '-' => { 172 | it.next(); 173 | if let Some('>') = it.next() { 174 | toks.push(ARROW); 175 | it.next(); 176 | } else { 177 | panic!("malformed"); 178 | } 179 | } 180 | _ => { 181 | panic!("{}", c); 182 | } 183 | } 184 | } 185 | toks 186 | } -------------------------------------------------------------------------------- /vscode-syntax-ext/.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | *.vsix -------------------------------------------------------------------------------- /vscode-syntax-ext/.vscode/launch.json: -------------------------------------------------------------------------------- 1 | // A launch configuration that launches the extension inside a new window 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | { 6 | "version": "0.2.0", 7 | "configurations": [ 8 | { 9 | "name": "Extension", 10 | "type": "extensionHost", 11 | "request": "launch", 12 | "runtimeExecutable": "${execPath}", 13 | "args": [ 14 | "--extensionDevelopmentPath=${workspaceFolder}" 15 | ] 16 | } 17 | ] 18 | } -------------------------------------------------------------------------------- /vscode-syntax-ext/.vscodeignore: -------------------------------------------------------------------------------- 1 | .vscode/** 2 | .vscode-test/** 3 | .gitignore 4 | vsc-extension-quickstart.md 5 | -------------------------------------------------------------------------------- /vscode-syntax-ext/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Change Log 2 | All notable changes to the "tensorscript" extension will be documented in this file. 3 | 4 | Check [Keep a Changelog](http://keepachangelog.com/) for recommendations on how to structure this file. 5 | 6 | ## [Unreleased] 7 | - Initial release -------------------------------------------------------------------------------- /vscode-syntax-ext/README.md: -------------------------------------------------------------------------------- 1 | # tensorscript README 2 | 3 | This is the README for your extension "tensorscript". After writing up a brief description, we recommend including the following sections. 4 | 5 | ## Features 6 | 7 | Describe specific features of your extension including screenshots of your extension in action. Image paths are relative to this README file. 8 | 9 | For example if there is an image subfolder under your extension project workspace: 10 | 11 | \!\[feature X\]\(images/feature-x.png\) 12 | 13 | > Tip: Many popular extensions utilize animations. This is an excellent way to show off your extension! We recommend short, focused animations that are easy to follow. 
14 | 15 | ## Requirements 16 | 17 | If you have any requirements or dependencies, add a section describing those and how to install and configure them. 18 | 19 | ## Extension Settings 20 | 21 | Include if your extension adds any VS Code settings through the `contributes.configuration` extension point. 22 | 23 | For example: 24 | 25 | This extension contributes the following settings: 26 | 27 | * `myExtension.enable`: enable/disable this extension 28 | * `myExtension.thing`: set to `blah` to do something 29 | 30 | ## Known Issues 31 | 32 | Calling out known issues can help limit users opening duplicate issues against your extension. 33 | 34 | ## Release Notes 35 | 36 | Users appreciate release notes as you update your extension. 37 | 38 | ### 1.0.0 39 | 40 | Initial release of ... 41 | 42 | ### 1.0.1 43 | 44 | Fixed issue #. 45 | 46 | ### 1.1.0 47 | 48 | Added features X, Y, and Z. 49 | 50 | ----------------------------------------------------------------------------------------------------------- 51 | 52 | ## Working with Markdown 53 | 54 | **Note:** You can author your README using Visual Studio Code. Here are some useful editor keyboard shortcuts: 55 | 56 | * Split the editor (`Cmd+\` on macOS or `Ctrl+\` on Windows and Linux) 57 | * Toggle preview (`Shift+CMD+V` on macOS or `Shift+Ctrl+V` on Windows and Linux) 58 | * Press `Ctrl+Space` (Windows, Linux) or `Cmd+Space` (macOS) to see a list of Markdown snippets 59 | 60 | ### For more information 61 | 62 | * [Visual Studio Code's Markdown Support](http://code.visualstudio.com/docs/languages/markdown) 63 | * [Markdown Syntax Reference](https://help.github.com/articles/markdown-basics/) 64 | 65 | **Enjoy!** 66 | -------------------------------------------------------------------------------- /vscode-syntax-ext/language-configuration.json: -------------------------------------------------------------------------------- 1 | { 2 | "comments": { 3 | // symbol used for single line comment. 
Remove this entry if your language does not support line comments 4 | "lineComment": "//", 5 | }, 6 | // symbols used as brackets 7 | "brackets": [ 8 | ["{", "}"], 9 | ["[", "]"], 10 | ["<", ">"], 11 | ["(", ")"] 12 | ], 13 | // symbols that are auto closed when typing 14 | "autoClosingPairs": [ 15 | ["{", "}"], 16 | ["[", "]"], 17 | ["(", ")"], 18 | ["\"", "\""], 19 | ["<", ">"], 20 | ["'", "'"] 21 | ], 22 | // symbols that can be used to surround a selection 23 | "surroundingPairs": [ 24 | ["{", "}"], 25 | ["[", "]"], 26 | ["(", ")"], 27 | ["\"", "\""], 28 | ["<", ">"], 29 | ["'", "'"] 30 | ] 31 | } -------------------------------------------------------------------------------- /vscode-syntax-ext/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tensorscript", 3 | "displayName": "TensorScript", 4 | "description": "Syntax Highlighting for TensorScript", 5 | "version": "0.0.1", 6 | "publisher": "rickyhan", 7 | "engines": { 8 | "vscode": "^1.22.0" 9 | }, 10 | "categories": [ 11 | "Languages" 12 | ], 13 | "contributes": { 14 | "languages": [{ 15 | "id": "tensorscript", 16 | "aliases": ["TensorScript", "tensorscript"], 17 | "extensions": [".trs"], 18 | "configuration": "./language-configuration.json" 19 | }], 20 | "grammars": [{ 21 | "language": "tensorscript", 22 | "scopeName": "source.trs", 23 | "path": "./syntaxes/tensorscript.tmLanguage.json" 24 | }] 25 | } 26 | } -------------------------------------------------------------------------------- /vscode-syntax-ext/syntaxes/tensorscript.tmLanguage.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://raw.githubusercontent.com/martinring/tmlanguage/master/tmlanguage.json", 3 | "name": "TensorScript", 4 | "patterns": [ 5 | { 6 | "include": "#keywords" 7 | }, 8 | { 9 | "include": "#comments" 10 | }, 11 | { 12 | "include": "#tensor_type" 13 | }, 14 | { 15 | "include": "#type_sig" 16 | }, 17
| { 18 | "include": "#strings" 19 | }, 20 | { 21 | "include": "#pipes" 22 | } 23 | ], 24 | "repository": { 25 | "keywords": { 26 | "patterns": [{ 27 | "name": "keyword.control.tensorscript", 28 | "match": "\\b(if|while|for|return|use|node|weights|graph|def|dim|tsr|forward)\\b" 29 | }] 30 | }, 31 | "strings": { 32 | "name": "string.quoted.double.tensorscript", 33 | "begin": "\"", 34 | "end": "\"", 35 | "patterns": [ 36 | { 37 | "name": "constant.character.escape.tensorscript", 38 | "match": "\\\\." 39 | } 40 | ] 41 | }, 42 | "comments": { 43 | "patterns": [ 44 | { 45 | "name": "comment.line.tensorscript", 46 | "match": "\/\/.*" 47 | } 48 | ] 49 | }, 50 | "type_sig": { 51 | "patterns": [ 52 | { 53 | "name": "constant.language.tensorscript", 54 | "match": "<.*>" 55 | } 56 | ] 57 | }, 58 | "tensor_type": { 59 | "patterns": [ 60 | { 61 | "name": "markup.list.tensorscript", 62 | "match": "[.*]" 63 | } 64 | ] 65 | }, 66 | "pipes": { 67 | "patterns": [ 68 | { 69 | "name": "keyword.control.tensorscript", 70 | "match": "\b|>\b" 71 | } 72 | ] 73 | } 74 | }, 75 | "scopeName": "source.trs" 76 | } -------------------------------------------------------------------------------- /vscode-syntax-ext/vsc-extension-quickstart.md: -------------------------------------------------------------------------------- 1 | # Welcome to your VS Code Extension 2 | 3 | ## What's in the folder 4 | * This folder contains all of the files necessary for your extension. 5 | * `package.json` - this is the manifest file in which you declare your language support and define 6 | the location of the grammar file that has been copied into your extension. 7 | * `syntaxes/tensorscript.tmLanguage.json` - this is the TextMate grammar file that is used for tokenization. 8 | * `language-configuration.json` - this is the language configuration, defining the tokens that are used for 9 | comments and brackets.
10 | 11 | ## Get up and running straight away 12 | * Make sure the language configuration settings in `language-configuration.json` are accurate. 13 | * Press `F5` to open a new window with your extension loaded. 14 | * Create a new file with a file name suffix matching your language (`.trs`). 15 | * Verify that syntax highlighting works and that the language configuration settings are working. 16 | 17 | ## Make changes 18 | * You can relaunch the extension from the debug toolbar after making changes to the files listed above. 19 | * You can also reload (`Ctrl+R` or `Cmd+R` on Mac) the VS Code window with your extension to load your changes. 20 | 21 | ## Add more language features 22 | * To add features such as IntelliSense, hovers, and validators, check out the VS Code extension documentation at 23 | https://code.visualstudio.com/docs 24 | 25 | ## Install your extension 26 | * To start using your extension with Visual Studio Code, copy it into the `<user home>/.vscode/extensions` folder and restart Code. 27 | * To share your extension with the world, read about publishing an extension at https://code.visualstudio.com/docs. 28 |