├── .github
│   ├── README.md
│   └── workflows
│       ├── release.yml
│       └── test.yml
├── .gitignore
├── LICENSE
├── pyproject.toml
└── src
    └── TinyLean
        ├── __init__.py
        ├── __main__.py
        ├── ast.py
        ├── grammar.py
        ├── ir.py
        └── tests
            ├── __init__.py
            ├── onboard.py
            ├── test_checker.py
            ├── test_main.py
            ├── test_parser.py
            └── test_resolver.py

/.github/README.md:
--------------------------------------------------------------------------------
# TinyLean

![Supported Python versions](https://img.shields.io/pypi/pyversions/TinyLean)
![Lines of Python](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/anqurvanillapy/5d8f9b1d4b414b7076cf84f4eae089d9/raw/cloc.json)
[![Test](https://github.com/anqurvanillapy/TinyLean/actions/workflows/test.yml/badge.svg)](https://github.com/anqurvanillapy/TinyLean/actions/workflows/test.yml)
[![codecov](https://codecov.io/gh/anqurvanillapy/TinyLean/graph/badge.svg?token=M0P3GXBQDK)](https://codecov.io/gh/anqurvanillapy/TinyLean)

A Lean 4-style theorem prover implemented in under 1K lines of Python.

```lean
def TinyLean ≔ {T: Type} → (a: T) → T
```

* In this project you can learn the basic principles of **theorem proving** and how they are implemented.
* Extensive Chinese comments that favor **terms commonly used in industry**, helping you connect industrial and academic vocabulary.
* Extensive **unit tests** with high coverage, so you can run regression tests after any change, or use them as a reference when porting the project to another language.

> [!NOTE]
> In pre-`v1` releases I used English throughout this project for convenience. I apologize to the early
> stargazers who may not have expected **full Chinese content** here (including documentation and comments); I decided
> to retarget the audience. Please reach out if you run into any kind of trouble.

## ❓ FAQ

> *"How 'strong' is this theorem prover?"*

If you know PLT (programming language theory), or are even very familiar with it, this project implements the following language features:
Click to expand (spoilers)

TinyLean may be the shortest implementation you can find that combines all of the above features. The **Explore** section at the end of this document points to more material.

> *"I'm not interested in theorem proving. What else can it do?"*

When you have a type system more expressive than those of C++, Java, TypeScript, Rust, and Haskell, implemented in under 1K lines and easy to port to other languages, think carefully about what you would like to build with it.

I used this knowledge to implement the [RowScript] programming language, a JavaScript dialect with row polymorphism.

[RowScript]: https://github.com/rowscript/rowscript

> *"Why is the project called TinyLean? Why not mini-lean, μLean, or some other name?"*

"TinyLean" is a tribute to the [TinyIdris] project, a "minimal" implementation of the Idris programming language.

[TinyIdris]: https://github.com/edwinb/SPLV20

> *"Why Chinese?"*

Even in English, the simplest terms in PLT are riddled with ambiguity and obscurity. If PLT jargon still mystifies you, go figure out the difference between a [dependent sum type] and a [sum type]: it is the 💩 every PLer has to eat while learning.

I will use language and terminology that are as simple and common as possible to help you demystify them.

[dependent sum type]: https://ncatlab.org/nlab/show/dependent+sum

[sum type]: https://ncatlab.org/nlab/show/sum+type

## ⏬ Installation

### Try it

If you just want to play with the project, install the full implementation from PyPI:

```bash
pip install TinyLean
```

Use the `tinylean` command to run any `.lean` file:

```bash
tinylean example.lean
```

You can even run a Markdown file: every code block tagged ```lean is type checked.

> [!IMPORTANT]
> The README you are reading is a valid TinyLean file!
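Conceptually, checking a Markdown file boils down to collecting the bodies of the fenced blocks tagged `lean`, in document order, and checking them as one program, which is why later blocks in this README can use earlier definitions. TinyLean's own `Parser` (in `ast.py`) does this with pyparsing's `scan_string` over the `markdown` rule from `grammar.py`. The Python sketch below only approximates that with a regular expression; `lean_blocks` and `lean_program` are hypothetical helpers, not part of TinyLean:

```python
import re

# Build the fence indirectly so this snippet never contains a literal
# lean-tagged fence, which TinyLean's Markdown scanner would try to parse.
FENCE = "`" * 3

# Hypothetical pattern: an opening fence tagged "lean", a body, a closing fence.
LEAN_BLOCK = re.compile(
    re.escape(FENCE) + r"lean\n(.*?)" + re.escape(FENCE), re.DOTALL
)


def lean_blocks(markdown: str) -> list[str]:
    """Return the bodies of all lean-tagged fenced blocks, in document order."""
    return LEAN_BLOCK.findall(markdown)


def lean_program(markdown: str) -> str:
    """Join the blocks so that later blocks can see earlier definitions."""
    return "\n".join(lean_blocks(markdown))
```

Joining raw text is a simplification: TinyLean parses each block on its own and chains the resulting declarations into a single list.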

```bash
tinylean example.md
```

### Reading the source locally

Clone the project:

```bash
git clone https://github.com/anqurvanillapy/TinyLean
cd TinyLean/
```

Check any of your own `.lean`/`.md` files locally, for example this very file:

```bash
python -m src.TinyLean .github/README.md
```

Install `pytest` and run all unit tests:

```bash
pip install pytest
pytest
```

## 🧙 Guide

Welcome to the world of theorem proving! Step by step, let's work out how to elegantly prove that legendary hard problem, `1+1=2`.

### DTLC

In the beginning, this world had only these things:

* the *type of types* (also called a universe), i.e. `Type`
* references, also called variable names, such as `x`
* functions, such as `λ x y ↦ y`
* function types, such as `(x: Type) → (y: Type) → Type`
* calls, such as `x y`

This world has a name: DTLC (dependently-typed lambda calculus).

Let's do some simple lambda calculus, such as defining an identity function `id` that takes an `a` and returns it:

```lean
def id (T: Type) (a: T): T := a

def Hello: Type := Type

example := id Type Hello
```

### ITP and ATP

Theorem provers are usually *interactive* (ITP, interactive theorem proving): when you are not sure what information the current proof needs, you can ask the prover. For example:

```lean
def myLemma
  (a: Type)
  (b: (_: Type) -> Type)
  (c: b a)
  : Type
  := Type
/-   ^~~^ try replacing this "Type" with "_" -/

def myTheorem := myLemma Type (id Type) Type
```

When you replace `:= Type` with `:= _`, you leave a "hole" (also called a goal) in the code. The prover tells you what type of value must go at the `_` and which variables are available in the context, printing something like:

```plaintext
.github/README.md:?:?: unsolved placeholder:
 ?u.? : Type

context:
 (a: Type)
 (b: (_: Type) → Type)
 (c: (b a))
```

*Automated theorem proving* (ATP), by contrast, fills in values that satisfy the type constraints automatically, using the variables available in the context.

TinyLean implements part of ATP and can fill some "obvious" holes. And as you can see, full ATP is a problem well suited for AI to take over.

### Implicit arguments

TinyLean supports *implicit arguments*. Rewriting our `id` function with an implicit `T` spares us from passing `T`; the type checker infers it:

```lean
def id1 {T: Type} (a: T): T := a

example := id1 Hello
```

Under the hood, an implicit argument just means the prover inserts a `_` for you and then tries to fill it in from the context. The example above is equivalent to leaving a `_` in a call to `id`:

```lean
example := id _ Hello
```

If you want to pass `T` to `id1` explicitly instead of letting the prover fill it in, use this syntax:

```lean
example := id1 (T := Type) Hello
```

### Church numerals

With DTLC alone we can still express natural numbers, for example with [Church numerals].

Define the `CN` type:

```lean
def CN: Type :=
  (T: Type) -> (S: (n: T) -> T) -> (Z: T) -> T
```

Define the number `3`, shaped like "the successor of the successor of the successor of zero":

```lean
def _3: CN := fun T S Z => S (S (S Z))
```

Define addition and multiplication:

```lean
def addCN (a: CN) (b: CN): CN :=
  fun T S Z => (a T S) (b T S Z)

def mulCN (a: CN) (b: CN): CN :=
  fun T S Z => (a T) (b T S) Z
```

Do some simple arithmetic:

```lean
def _6: CN := addCN _3 _3
def _9: CN := mulCN _3 _3
```

[Church numerals]: https://en.wikipedia.org/wiki/Church_encoding

### Equality

The most important tool for writing proofs is *equality*: having `1+1` without being able to prove `1+1=2` would be absurd. With DTLC alone we can still express equations, for example with [Leibniz equality].

Define the `LEq` type, `lRefl` (reflexivity), and `lSym` (symmetry):

```lean
def LEq {T: Type} (a: T) (b: T): Type :=
  (p: (v: T) -> Type) -> (pa: p a) -> p b

def lRefl {T: Type} (a: T): LEq a a :=
  fun p pa => pa

def lSym {T: Type} (a: T) (b: T) (p: LEq a b): LEq b a :=
  (p (fun b => LEq b a)) (lRefl a)
```

Let's prove the earlier `_9 = _3 + _6`:

```lean
example: LEq _9 (addCN _3 _6) := lRefl _9
```

[Leibniz equality]: https://en.wikipedia.org/wiki/Equality_(mathematics)

### Inductive data types

We can define new types with inductive data types; for instance, we can finally have a more intuitive natural number:

```lean
inductive N where
  | Z
  | S (n: N)
open N
```

This definition is already very close to the natural numbers defined by the [Peano axioms]:

1. 0 (`Z`) is a natural number (`N`)
2. For every natural number `n`, the successor of `n` (`S n`) is also a natural number

Addition, defined by recursion, also reads more naturally:

```lean
def addN (n: N) (m: N): N :=
  match n with
  | Z => m
  | S pred => S (addN pred m)

example := addN (S Z) (S Z)
```

If an inductive data type has no constructors at all, it is an empty type (the bottom type, ⊥):

```lean
inductive Bot where
open Bot
```

The [principle of explosion] (ex falso) says that anything follows from a contradiction. We can write this theorem with `nomatch`:

```lean
def exFalso (T: Type) (x: Bot): T := nomatch x
```

Here we have pulled a value of type `T` out of thin air.

[Peano axioms]: https://en.wikipedia.org/wiki/Peano_axioms

[principle of explosion]: https://en.wikipedia.org/wiki/Principle_of_explosion

### Indexed types

Inductive data types can carry parameters. A type that carries them is called an indexed type, because it is "indexed by a value". Such types are also known as "inductive families".

In C++, for example, a [non-type template parameter] lets `std::array` record its length in its type: the length argument, such as `3`, is just an ordinary numeric value.

Likewise, we can define a vector type that records its length in its type:

```lean
inductive Vec (A: Type) (n: N) where
  | Nil (n := Z)
  | Cons {m: N} (a: A) (v: Vec A m) (n := S m)
open Vec
```

Here `(n := Z)` means that when we construct an empty vector with `Nil`, its type argument `n` is set to `Z`, i.e. its length is 0.

A few vectors of different lengths:

```lean
def v0: Vec Type Z := Nil
def v1: Vec Type (S Z) := Cons N v0
def v2: Vec Type (S (S Z)) := Cons CN v1
```

[non-type template parameter]: https://en.cppreference.com/w/cpp/language/template_parameters#Non-type_template_parameter

### Dependent pattern matching

Indexed types help us rule out impossible patterns. For example, when we build an empty vector with `Nil` and `match` on it, we obviously don't need to consider the `Cons` case. This feature is called "dependent pattern matching".

```lean
example :=
  match v0 with
  | Nil => Z
```

If we add a `Cons` case anyway, the prover reports an error like this:

```plaintext
.github/README.md:?:?: type mismatch:
want:
 (Vec Type N.Z)

got:
 (Vec ?m.? (N.S ?m.?))
```

So an empty type is not necessarily a type without constructors; it can also be a type whose values cannot be constructed at all, for example:

```lean
inductive Weird (n: N) where
  | MkWeird (n := Z)
open Weird

example (A: Type) (x: Weird (S Z)): A := nomatch x
```

Here `Weird (S Z)` is also an empty type, because there is no way to construct a value of that type.

### A new equality type

With indexed types, we can define an equality type that is easier to understand:

```lean
inductive Eq {T: Type} (a: T) (b: T) where
  | Refl (a := b)
open Eq
```

Let's test our `1+1=2` with `addN` and `Eq`:

```lean
example: Eq (addN (S Z) (S Z)) (S (S Z)) := Refl (T := N)
```

### Classes

In the type system we have seen so far, every type lives under `Type`, and we have no way to further "classify" types, so `Type` suddenly becomes "the new
`any`". The downside: I want the default value of `int` to be `0` and that of `string` to be `""`, and I want a single generic function
`default` to produce the default value of any such type. How can we do that?
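For intuition, here is how the same need is typically met at runtime in a dynamically typed language: a table from (class, type) pairs to implementations, consulted by a generic function. This is a minimal Python sketch under that assumption; `_INSTANCES`, `instance`, `resolve` and `default` are hypothetical names. TinyLean instead resolves instances statically while type checking (its checker records each instance on its class and reports `no such instance for class ...` when none fits):

```python
# Hypothetical registry: (class name, type) -> implementations of its fields.
_INSTANCES: dict[tuple[str, type], dict[str, object]] = {}


def instance(cls: str, ty: type, **fields: object) -> None:
    """Register an instance of class `cls` for type `ty`."""
    _INSTANCES[(cls, ty)] = fields


def resolve(cls: str, ty: type) -> dict[str, object]:
    """Look up the instance, failing with a message modeled on TinyLean's."""
    try:
        return _INSTANCES[(cls, ty)]
    except KeyError:
        msg = f"no such instance for class '({cls} {ty.__name__})'"
        raise LookupError(msg) from None


def default(ty: type) -> object:
    # A "class method": resolve the instance for `ty`, then project the field.
    return resolve("Default", ty)["default"]


instance("Default", int, default=0)
instance("Default", str, default="")

print(default(int), repr(default(str)))  # 0 ''
```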

Typeclasses (also called traits) solve this problem nicely:

```lean
class Default (T: Type) where
  default: T
open Default
```

With the `Default` class, we can define `Default` instances for different types.

### Instances

Define the default value `Z` for the `N` type:

```lean
instance: Default N
where
  default := Z
```

> [!CAUTION]
> Note that the `where` keyword must go on a new line here. Lean 4's syntax is very flexible, and to keep TinyLean's grammar definition file concise, many syntactic ambiguities are not handled yet.

Let's write a proof that `(default N) = Z`:

```lean
example: Eq Z (default N) := Refl (T := N)
```

### Class parameters

We can use class parameters to check whether a type meets the requirements of a class, for example:

```lean
def mustBeDefault (T: Type) [p: Default T] := Type
```

When calling `mustBeDefault`, we require the argument `T` to satisfy the `Default` class constraint:

```lean
example := mustBeDefault N
```

Clearly, the `N` type satisfies this constraint. When we pass another type, such as `Bot`, the prover tells us it cannot find a matching instance declaration:

```plaintext
.github/README.md:?:?: no such instance for class '(Default Bot)'
```

### Operator overloading

With classes, operator overloading is easy to implement too. In TinyLean, the infix operators `+`, `-`, `*`, and `/` are simply translated into calls to
`add`, `sub`, `mul`, and `div`, so we first define the corresponding class and class method:

```lean
class Add {T: Type} where
  add: (a: T) -> (b: T) -> T
open Add
```

> [!NOTE]
> Note that this `add` operation is *homogeneous*: its inputs and output share the same type. A better definition would be *heterogeneous*, i.e. something like
> `T → U → V`; we skip heterogeneous addition here.

Define the corresponding instance for `N`:

```lean
instance: Add (T := N)
where
  add := addN
```

Now we can use the infix operator in our `1+1=2` proof:

```lean
example
  : Eq (S (S Z)) ((S Z) + (S Z))
  := Refl (T := N)
```

## 🔍 Explore

From here, you can keep exploring the following worlds:

### Source code

Start reading the source code from [`tests/onboard.py`].

[`tests/onboard.py`]: ../src/TinyLean/tests/onboard.py

### The unknown

If you still feel confused after the "Guide", or can't follow what happened at all, that's normal. The "Guide" is really more of a showcase of TinyLean's
features than a proper theorem-proving tutorial, since plenty of high-quality tutorials of that kind already exist, for example:

* [Theorem Proving in Lean 4](https://leanprover.github.io/theorem_proving_in_lean4/)
* [Programming Language Foundations in Agda](https://plfa.github.io/)

I did not fully understand these tutorials and books on the first read; it took me three to four years of repeatedly rereading parts of them, piece by piece, before they made sense.

And I have to confess: what truly made me understand type theory was implementing one type theory after another with my own hands.

### Leap

TODO

## 🫡 Acknowledgments

TODO

---

MIT License Copyright © Anqur

--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
name: Release
on:
  push:
    tags:
      - v*
jobs:
  build-and-publish:
    runs-on: ubuntu-latest
    permissions:
      id-token: write
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.x'
      - name: Install build dependencies
        run: pip install --upgrade build check-manifest twine
      - name: Build
        run: |
          check-manifest
          python -m build
          python -m twine check dist/*
      - name: Publish
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          skip-existing: true

--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
name: Test
on:
  - push
  - pull_request
jobs:
  test:
    strategy:
      matrix:
        python:
          - '3.12'
          - '3.13'
        platform:
          - ubuntu-latest
          - macos-latest
          - windows-latest
    runs-on: ${{ matrix.platform }}
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}
      - name: Install test dependencies
        run: |
          pip install --upgrade build check-manifest twine black pytest pytest-cov
          pip install .
      - name: Lint
        run: black . --check
      - name: Build
        run: |
          check-manifest
          python -m build
          python -m twine check dist/*
      - name: Test and report coverage
        run: pytest --cov --cov-branch --cov-report=xml
      - name: Upload coverage reports to Codecov
        uses: codecov/codecov-action@v5
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
  cloc:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Install dependencies
        run: |
          sudo apt-get update
          sudo apt-get install cloc jq
      - name: Run CLOC
        run: cloc --exclude-dir=dist,venv,build,tests --json . | jq '"CLOC="+(.Python.code|tostring)' -r >> $GITHUB_ENV
      - name: Create CLOC badge
        uses: schneegans/dynamic-badges-action@v1.7.0
        with:
          auth: ${{ secrets.CLOC_GIST_SECRET }}
          gistID: 5d8f9b1d4b414b7076cf84f4eae089d9
          filename: cloc.json
          label: Lines of Python
          message: ${{ env.CLOC }}
          color: orange

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
.idea/
.vscode/

build/
dist/
*.egg-info/
*.egg
*.py[cod]
__pycache__/
*.so
*~
venv/

.nox
.cache
.coverage
coverage.*

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 Anqur

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "TinyLean"
version = "1.0.0"
description = "Tiny theorem prover with syntax like Lean 4"
readme = ".github/README.md"
requires-python = ">=3.12"
license = { file = "LICENSE" }
keywords = [
    "lean",
    "lean4",
    "theorem-proving",
    "theorem-prover",
    "programming-language",
]
authors = [
    { name = "Anqur", email = "anqurvanillapy@gmail.com" },
]
classifiers = [
    "Intended Audience :: Education",
    "Topic :: Education",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
    "Programming Language :: Python :: 3 :: Only",
]
dependencies = [
    "pyparsing==3.2.1"
]

[project.urls]
"Homepage" = "https://github.com/anqurvanillapy/TinyLean"
"Bug Reports" = "https://github.com/anqurvanillapy/TinyLean/issues"
"Source" = "https://github.com/anqurvanillapy/TinyLean"

[project.scripts]
tinylean = "TinyLean.__main__:main"

--------------------------------------------------------------------------------
/src/TinyLean/__init__.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass, field
from itertools import count

fresh = count(1).__next__


@dataclass(frozen=True)
class Name:
    text: str
    id: int = field(default_factory=fresh)

    def __str__(self):
        return self.text

    def is_unbound(self):
        return self.text == "_"


@dataclass(frozen=True)
class Param[T]:
    name: Name
    type: T
    is_implicit: bool
    is_class: bool = False

    def __str__(self):
        l, r = "()" if not self.is_implicit else "{}" if not self.is_class else "[]"
        return f"{l}{self.name}: {self.type}{r}"


@dataclass(frozen=True)
class Decl:
    loc: int


@dataclass(frozen=True)
class Def[T](Decl):
    name: Name
    params: list[Param[T]]
    ret: T
    body: T


@dataclass(frozen=True)
class Sig[T](Decl):
    name: Name
    params: list[Param[T]]
    ret: T


@dataclass(frozen=True)
class Example[T](Decl):
    params: list[Param[T]]
    ret: T
    body: T


@dataclass(frozen=True)
class Ctor[T](Decl):
    name: Name
    params: list[Param[T]]
    ty_args: list[tuple[T, T]]
    ty_name: Name | None = None


@dataclass(frozen=True)
class Data[T](Decl):
    name: Name
    params: list[Param[T]]
    ctors: list[Ctor[T]]


@dataclass(frozen=True)
class Field[T](Decl):
    name: Name
    type: T
    cls_name: Name | None = None


@dataclass(frozen=True)
class Class[T](Decl):
    name: Name
    params: list[Param[T]]
    fields: list[Field[T]]
    instances: list[int] = field(default_factory=list)


@dataclass(frozen=True)
class Instance[T](Decl):
    type: T
    fields: list[tuple[T, T]]
    id: int = field(default_factory=fresh)

--------------------------------------------------------------------------------
/src/TinyLean/__main__.py:
--------------------------------------------------------------------------------
import sys
from pathlib import Path

from pyparsing import util, exceptions

from . import ast, ir


fatal = lambda m: sys.exit(int(not print(m)))


_F = Path(sys.argv[1]) if len(sys.argv) > 1 else None


def fatal_on(text: str, loc: int, m: str):
    fatal(f"{_F}:{util.lineno(loc, text)}:{util.col(loc, text)}: {m}")


def main(file=_F if _F else fatal("usage: tinylean FILE")):
    try:
        with open(file, encoding="utf-8") as f:
            text = f.read()
        ast.check_string(text, file.suffix == ".md")
    except OSError as e:
        fatal(e)
    except exceptions.ParseException as e:
        fatal_on(text, e.loc, str(e).split("(at char")[0].strip())
    except ast.UndefinedVariableError as e:
        v, loc = e.args
        fatal_on(text, loc, f"undefined variable '{v}'")
    except ast.DuplicateVariableError as e:
        v, loc = e.args
        fatal_on(text, loc, f"duplicate variable '{v}'")
    except ast.TypeMismatchError as e:
        want, got, loc = e.args
        fatal_on(text, loc, f"type mismatch:\nwant:\n {want}\n\ngot:\n {got}")
    except ast.UnsolvedPlaceholderError as e:
        name, ctx, ty, loc = e.args
        ty_msg = f" {name} : {ty}"
        ctx_msg = "".join([f"\n {p}" for p in ctx.values()]) if ctx else " (none)"
        fatal_on(text, loc, f"unsolved placeholder:\n{ty_msg}\n\ncontext:{ctx_msg}")
    except ast.UnknownCaseError as e:
        want, got, loc = e.args
        fatal_on(text, loc, f"cannot match case '{got}' of type '{want}'")
    except ast.DuplicateCaseError as e:
        name, loc = e.args
        fatal_on(text, loc, f"duplicate case '{name}'")
    except ast.CaseParamMismatchError as e:
        want, got, loc = e.args
        fatal_on(text, loc, f"want '{want}' case parameters, but got '{got}'")
    except ast.CaseMissError as e:
        miss, loc = e.args
        fatal_on(text, loc, f"missing case: {miss}")
    except ast.FieldMissError as e:
        miss, loc = e.args
        fatal_on(text, loc, f"missing field: {miss}")
    except ast.UnknownFieldError as e:
        want, got, loc = e.args
        fatal_on(text, loc, f"unknown field '{got}' of class '{want}'")
    except ir.NoInstanceError as e:
        name, loc = e.args
        fatal_on(text, loc, f"no such instance for class '{name}'")
    except RecursionError as e:
        print("Program too complex or oops you just got '⊥'! Please report this issue:")
        raise e
    except Exception as e:
        print("Internal compiler error! Please report this issue:")
        raise e


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/src/TinyLean/ast.py:
--------------------------------------------------------------------------------
from functools import reduce
from itertools import chain
from dataclasses import dataclass, field
from typing import OrderedDict, cast as _c

from pyparsing import ParseResults

from . import (
    Name,
    Param,
    Decl,
    ir,
    grammar as _g,
    fresh,
    Def,
    Example,
    Ctor,
    Data,
    Sig,
    Field,
    Class,
    Instance,
)


@dataclass(frozen=True)
class Node:
    loc: int


@dataclass(frozen=True)
class Type(Node): ...


@dataclass(frozen=True)
class Ref(Node):
    name: Name


@dataclass(frozen=True)
class FnType(Node):
    param: Param[Node]
    ret: Node


@dataclass(frozen=True)
class Fn(Node):
    param: Name
    body: Node


@dataclass(frozen=True)
class Call(Node):
    callee: Node
    arg: Node
    implicit: str | bool


@dataclass(frozen=True)
class Placeholder(Node):
    is_user: bool


@dataclass(frozen=True)
class Nomatch(Node):
    arg: Node


@dataclass(frozen=True)
class Case(Node):
    ctor: Ref
    params: list[Name]
    body: Node


@dataclass(frozen=True)
class Match(Node):
    arg: Node
    cases: list[Case]


_g.name.add_parse_action(lambda r: Name(r[0][0]))

_ops = {"+": "add", "-": "sub", "*": "mul", "/": "div"}


def _infix(loc: int, ret: ParseResults):
    r = ret[0]
    if not isinstance(r, ParseResults):
        return r
    return Call(loc, Call(loc, Ref(loc, Name(_ops[r[1]])), r[0], False), r[2], False)


_g.expr.add_parse_action(_infix)
_g.type_.add_parse_action(lambda l, r: Type(l))
_g.ph.add_parse_action(lambda l, r: Placeholder(l, True))
_g.ref.add_parse_action(lambda l, r: Ref(l, r[0][0]))
_g.i_param.add_parse_action(lambda r: Param(r[0], r[1], True))
_g.e_param.add_parse_action(lambda r: Param(r[0], r[1], False))
_g.c_param.add_parse_action(lambda r: Param(r[0], r[1], True, True))
_g.fn_type.add_parse_action(lambda l, r: FnType(l, r[0], r[1]))
_g.fn.add_parse_action(
    lambda l, r: reduce(lambda a, n: Fn(l, n, a), reversed(r[0]), r[1])
)
_g.match.add_parse_action(lambda l, r: Match(l, r[0], list(r[1])))
_g.case.add_parse_action(lambda r: Case(r[0].loc, r[0], r[1], r[2]))
_g.nomatch.add_parse_action(lambda l, r: Nomatch(l, r[0][0]))
_g.i_arg.add_parse_action(lambda l, r: (r[1], r[0]))
_g.e_arg.add_parse_action(lambda l, r: (r[0], False))
_g.call.add_parse_action(
    lambda l, r: reduce(lambda a, b: Call(l, a, b[0], b[1]), r[1:], r[0])
)
_g.p_expr.add_parse_action(lambda r: r[0])

_g.return_type.add_parse_action(lambda l, r: r[0] if len(r) else Placeholder(l, False))
_g.def_.add_parse_action(lambda r: Def(r[0].loc, r[0].name, list(r[1]), r[2], r[3]))
_g.example.add_parse_action(lambda l, r: Example(l, list(r[0]), r[1], r[2]))
_g.type_arg.add_parse_action(lambda r: (r[0], r[1]))
_g.ctor.add_parse_action(lambda r: Ctor(r[0].loc, r[0].name, list(r[1]), list(r[2])))
_g.data.add_condition(
    lambda r: r[0].name.text == r[3], message="open and datatype name mismatch"
).add_parse_action(lambda r: Data(r[0].loc, r[0].name, list(r[1]), list(r[2])))
_g.c_field.add_parse_action(lambda l, r: Field(l, r[0], r[1]))
_g.class_.add_condition(
    lambda r: r[0].name.text == r[3], message="open and class name mismatch"
).add_parse_action(lambda r: Class(r[0].loc, r[0].name, list(r[1]), list(r[2])))
_g.i_field.add_parse_action(lambda r: (r[0], r[1]))
_g.inst.add_parse_action(lambda l, r: Instance(l, r[0], list(r[1])))


@dataclass(frozen=True)
class Parser:
    is_markdown: bool = False

    def __ror__(self, s: str):
        if not self.is_markdown:
            return list(_g.program.parse_string(s, parse_all=True))
        return chain.from_iterable(r[0] for r in _g.markdown.scan_string(s))


class DuplicateVariableError(Exception): ...


class UndefinedVariableError(Exception): ...


@dataclass(frozen=True)
class NameResolver:
    locals: dict[str, Name] = field(default_factory=dict)
    globals: dict[str, Name] = field(default_factory=dict)

    def __ror__(self, decls: list[Decl]):
        return [self._decl(d) for d in decls]

    def _decl(self, decl: Decl) -> Decl:
        self.locals.clear()

        if isinstance(decl, Def) or isinstance(decl, Data):
            self._insert_global(decl.loc, decl.name)

        if isinstance(decl, Def) or isinstance(decl, Example):
            return self._def_or_example(decl)

        if isinstance(decl, Data):
            return self._data(decl)

        if isinstance(decl, Class):
            return self._class(decl)

        return self._inst(_c(Instance, decl))

    def _def_or_example(self, d: Def[Node] | Example[Node]):
        params = self._params(d.params)
        ret = self.expr(d.ret)
        body = self.expr(d.body)
        if isinstance(d, Example):
            return Example(d.loc, params, ret, body)
        return Def(d.loc, d.name, params, ret, body)

    def _data(self, d: Data[Node]):
        params = self._params(d.params)
        ctors = [self._ctor(c, d.name) for c in d.ctors]
        return Data(d.loc, d.name, params, ctors)

    def _ctor(self, c: Ctor[Node], ty_name: Name):
        params = self._params(c.params)
        ty_args = [(self.expr(n), self.expr(t)) for n, t in c.ty_args]
        for p in params:
            del self.locals[p.name.text]
        self._insert_global(c.loc, c.name)
        return Ctor(c.loc, c.name, params, ty_args, ty_name)

    def _class(self, c: Class[Node]):
        params = self._params(c.params)
        fields = []
        for f in c.fields:
            self._insert_global(f.loc, f.name)
            fields.append(Field(f.loc, f.name, self.expr(f.type)))
        self._insert_global(c.loc, c.name)
        return Class(c.loc, c.name, params, fields)

    def _inst(self, i: Instance[Node]):
        t = self.expr(i.type)
        fields = []
        field_ids = set()
        for n, v in i.fields:
            n = _c(Ref, self.expr(n))
            if n.name.id in field_ids:
                raise DuplicateVariableError(n.name.text, n.loc)
            field_ids.add(n.name.id)
            fields.append((n, (self.expr(v))))
        return Instance(i.loc, t, fields)

    def _params(self, params: list[Param[Node]]):
        ret = []
        for p in params:
            self._insert_local(p.name)
            ret.append(Param(p.name, self.expr(p.type), p.is_implicit, p.is_class))
        return ret

    def expr(self, n: Node) -> Node:
        if isinstance(n, Ref):
            if v := self.locals.get(n.name.text, self.globals.get(n.name.text)):
                return Ref(n.loc, v)
            raise UndefinedVariableError(n.name.text, n.loc)
        if isinstance(n, FnType):
            typ = self.expr(n.param.type)
            b = self._with_locals(n.ret, n.param.name)
            p = Param(n.param.name, typ, n.param.is_implicit, n.param.is_class)
            return FnType(n.loc, p, b)
        if isinstance(n, Fn):
            return Fn(n.loc, n.param, self._with_locals(n.body, n.param))
        if isinstance(n, Call):
            return Call(n.loc, self.expr(n.callee), self.expr(n.arg), n.implicit)
        if isinstance(n, Nomatch):
            return Nomatch(n.loc, self.expr(n.arg))
        if isinstance(n, Match):
            arg = self.expr(n.arg)
            cases = []
            for c in n.cases:
                ctor = _c(Ref, self.expr(c.ctor))
                body = self._with_locals(c.body, *c.params)
                cases.append(Case(c.loc, ctor, c.params, body))
            return Match(n.loc, arg, cases)
        assert isinstance(n, Type) or isinstance(n, Placeholder)
        return n

    def _with_locals(self, node: Node, *names: Name):
        olds = [(v, self._insert_local(v)) for v in names]
        ret = self.expr(node)
        for v, old in olds:
            if old:
                self._insert_local(old)
            elif not v.is_unbound():
                del self.locals[v.text]
        return ret

    def _insert_local(self, v: Name):
        if v.is_unbound():
            return None
        old = self.locals.get(v.text)
        self.locals[v.text] = v
        return old

    def _insert_global(self, loc: int, name: Name):
        if not name.is_unbound():
            if name.text in self.globals:
                raise DuplicateVariableError(name.text, loc)
            self.globals[name.text] = name


class TypeMismatchError(Exception): ...


class UnsolvedPlaceholderError(Exception): ...


class UnknownCaseError(Exception): ...


class DuplicateCaseError(Exception): ...


class CaseParamMismatchError(Exception): ...


class CaseMissError(Exception): ...


class FieldMissError(Exception): ...


class UnknownFieldError(Exception): ...


@dataclass(frozen=True)
class TypeChecker:
    globals: dict[int, Decl] = field(default_factory=dict)
    locals: dict[int, Param[ir.IR]] = field(default_factory=dict)
    holes: OrderedDict[int, ir.Hole] = field(default_factory=OrderedDict)
    recur_ids: set[int] = field(default_factory=set)

    def __ror__(self, ds: list[Decl]):
        ret = [self._run(d) for d in ds]
        for i, h in self.holes.items():
            if h.answer.is_unsolved():
                ty = self._inliner().run(h.answer.type)
                if _is_solved_class(ty):
                    continue
                p = ir.Placeholder(i, h.is_user)
                raise UnsolvedPlaceholderError(str(p), h.locals, ty, h.loc)
        return ret

    def _run(self, decl: Decl) -> Decl:
        self.locals.clear()
        if isinstance(decl, Def) or isinstance(decl, Example):
            return self._def_or_example(decl)
        if isinstance(decl, Data):
            return self._data(decl)
        if isinstance(decl, Instance):
            return self._inst(decl)
        return self._class(_c(Class, decl))

    def _def_or_example(self, d: Def[Node] | Example[Node]):
        params = self._params(d.params)
        ret = self.check(d.ret, ir.Type())

        if isinstance(d, Def):
self.globals[d.name.id] = Sig(d.loc, d.name, params, ret) 330 | body = self.check(d.body, ret) 331 | 332 | if isinstance(d, Example): 333 | return Example(d.loc, params, ret, body) 334 | 335 | checked = Def(d.loc, d.name, params, ret, body) 336 | self.globals[d.name.id] = checked 337 | return checked 338 | 339 | def _data(self, d: Data[Node]): 340 | params = self._params(d.params) 341 | data = Data(d.loc, d.name, params, []) 342 | self.globals[d.name.id] = data 343 | data.ctors.extend(self._ctor(c) for c in d.ctors) 344 | return data 345 | 346 | def _ctor(self, c: Ctor[Node]): 347 | params = self._params(c.params) 348 | ty_args: list[tuple[ir.IR, ir.IR]] = [] 349 | for x, v in c.ty_args: 350 | x_val, x_ty = self.infer(x) 351 | v_val = self.check(v, x_ty) 352 | ty_args.append((x_val, v_val)) 353 | ctor = Ctor(c.loc, c.name, params, ty_args, c.ty_name) 354 | self.globals[c.name.id] = ctor 355 | return ctor 356 | 357 | def _class(self, c: Class[Node]): 358 | params = self._params(c.params) 359 | fs = [ 360 | Field(f.loc, f.name, self.check(f.type, ir.Type()), c.name) 361 | for f in c.fields 362 | ] 363 | self.globals.update({f.name.id: f for f in fs}) 364 | cls = Class(c.loc, c.name, params, fs) 365 | self.globals[c.name.id] = cls 366 | return cls 367 | 368 | def _inst(self, i: Instance[Node]): 369 | ty = self.check(i.type, ir.Type()) 370 | if not isinstance(ty, ir.Class): 371 | raise TypeMismatchError("class", str(ty), i.type.loc) 372 | c = _c(Class, self.globals[ty.name.id]) 373 | vals = {_c(Ref, n).name.id: (n, f) for n, f in i.fields} 374 | fields = [] 375 | for f in c.fields: 376 | nv = vals.pop(f.name.id, None) 377 | if not nv: 378 | raise FieldMissError(f.name.text, i.loc) 379 | f_decl = _c(Field, self.globals[f.name.id]) 380 | field_ty = ir.from_field(f_decl, c, False)[1] 381 | env = [] 382 | for ty_arg in ty.args: 383 | assert isinstance(field_ty, ir.FnType) 384 | env.append((field_ty.param.name, ty_arg)) 385 | field_ty = field_ty.ret 386 | f_type = 
self._inliner().run_with(field_ty, *env) 387 | fields.append((ir.Ref(f.name), self.check(nv[1], f_type))) 388 | for n, _ in vals.values(): 389 | assert isinstance(n, Ref) 390 | raise UnknownFieldError(c.name.text, n.name.text, n.loc) 391 | c.instances.append(i.id) 392 | inst = Instance(i.loc, _c(ir.IR, ty), fields, i.id) 393 | self.globals[i.id] = inst 394 | return inst 395 | 396 | def _params(self, params: list[Param[Node]]): 397 | ret = [] 398 | for p in params: 399 | t = self.check(p.type, ir.Type()) 400 | if p.is_class: 401 | t = self._inliner().run(t) 402 | assert p.is_implicit 403 | if not isinstance(t, ir.Class): 404 | raise TypeMismatchError("class", str(t), p.type.loc) 405 | param = Param(p.name, t, p.is_implicit, p.is_class) 406 | self.locals[p.name.id] = param 407 | ret.append(param) 408 | return ret 409 | 410 | def check(self, n: Node, typ: ir.IR) -> ir.IR: 411 | if isinstance(n, Fn): 412 | t = self._inliner().run(typ) 413 | if not isinstance(t, ir.FnType): 414 | raise TypeMismatchError(str(t), "function", n.loc) 415 | ret = self._inliner().run_with(t.ret, (t.param.name, ir.Ref(n.param))) 416 | p = Param(n.param, t.param.type, t.param.is_implicit, t.param.is_class) 417 | return ir.Fn(p, self._check_with(n.body, ret, p)) 418 | 419 | holes_len = len(self.holes) 420 | val, got = self.infer(n) 421 | got = self._inliner().run(got) 422 | want = self._inliner().run(typ) 423 | 424 | if _can_insert_placeholders(want): 425 | if new_f := _with_placeholders(n, got, False): 426 | # FIXME: No valid tests yet. 
427 | assert len(self.holes) == holes_len 428 | val, got = self.infer(new_f) 429 | 430 | if not self._eq(got, want): 431 | raise TypeMismatchError(str(want), str(got), n.loc) 432 | 433 | return val 434 | 435 | def infer(self, n: Node) -> tuple[ir.IR, ir.IR]: 436 | if isinstance(n, Ref): 437 | if param := self.locals.get(n.name.id): 438 | return ir.Ref(param.name), param.type 439 | d = self.globals[n.name.id] 440 | if isinstance(d, Def): 441 | return ir.from_def(d) 442 | if isinstance(d, Sig): 443 | self.recur_ids.add(d.name.id) 444 | return ir.from_sig(d) 445 | if isinstance(d, Data): 446 | return ir.from_data(d) 447 | if isinstance(d, Ctor): 448 | data_decl = _c(Data, self.globals[d.ty_name.id]) 449 | return ir.from_ctor(d, data_decl) 450 | if isinstance(d, Field): 451 | return ir.from_field(d, _c(Class, self.globals[d.cls_name.id])) 452 | return ir.from_class(_c(Class, d)) 453 | if isinstance(n, FnType): 454 | p_typ = self.check(n.param.type, ir.Type()) 455 | p = Param(n.param.name, p_typ, n.param.is_implicit, n.param.is_class) 456 | b_val = self._check_with(n.ret, ir.Type(), p) 457 | return ir.FnType(p, b_val), ir.Type() 458 | if isinstance(n, Call): 459 | holes_len = len(self.holes) 460 | f_val, got = self.infer(n.callee) 461 | 462 | if implicit_f := _with_placeholders(n.callee, got, n.implicit): 463 | [self.holes.popitem() for _ in range(len(self.holes) - holes_len)] 464 | return self.infer(Call(n.loc, implicit_f, n.arg, n.implicit)) 465 | 466 | if not isinstance(got, ir.FnType): 467 | raise TypeMismatchError("function", str(got), n.callee.loc) 468 | 469 | x_tm = self._check_with(n.arg, got.param.type, got.param) 470 | typ = self._inliner().run_with(got.ret, (got.param.name, x_tm)) 471 | val = self._inliner().apply(f_val, x_tm) 472 | return val, typ 473 | if isinstance(n, Placeholder): 474 | ty = self._insert_hole(n.loc, n.is_user, ir.Type()) 475 | v = self._insert_hole(n.loc, n.is_user, ty) 476 | return v, ty 477 | if isinstance(n, Nomatch): 478 | _, got = 
self.infer(n.arg) 479 | if not isinstance(got, ir.Data): 480 | raise TypeMismatchError("datatype", str(got), n.arg.loc) 481 | data = _c(Data, self.globals[got.name.id]) 482 | for c in data.ctors: 483 | self._exhaust(n.arg.loc, c, data, got) 484 | return ir.Nomatch(), self._insert_hole(n.loc, False, ir.Type()) 485 | if isinstance(n, Match): 486 | arg, arg_ty = self.infer(n.arg) 487 | if not isinstance(arg_ty, ir.Data): 488 | raise TypeMismatchError("datatype", str(arg_ty), n.arg.loc) 489 | data = _c(Data, self.globals[arg_ty.name.id]) 490 | ctors = {c.name.id: c for c in data.ctors} 491 | ty: ir.IR | None = None 492 | cases: dict[int, ir.Case] = {} 493 | for c in n.cases: 494 | ctor = ctors.get(c.ctor.name.id) 495 | if not ctor: 496 | raise UnknownCaseError(data.name.text, c.ctor.name.text, c.loc) 497 | with ir.dirty_holes(self.holes): 498 | c_ty = self._ctor_return_type(c.loc, ctor, data) 499 | if not self._eq(c_ty, arg_ty): 500 | raise TypeMismatchError(str(arg_ty), str(c_ty), c.loc) 501 | if ctor.name.id in cases: 502 | raise DuplicateCaseError(ctor.name.text, c.loc) 503 | if len(c.params) != len(ctor.params): 504 | raise CaseParamMismatchError(len(ctor.params), len(c.params), c.loc) 505 | ps = [Param(n, p.type, False) for n, p in zip(c.params, ctor.params)] 506 | if ty is None: 507 | body, ty = self._infer_with(c.body, *ps) 508 | else: 509 | body = self._check_with(c.body, ty, *ps) 510 | cases[ctor.name.id] = ir.Case(ctor.name, ps, body) 511 | for c in [c for i, c in ctors.items() if i not in cases]: 512 | self._exhaust(n.loc, c, data, arg_ty) 513 | return ir.Match(arg, cases), ty 514 | assert isinstance(n, Type) 515 | return ir.Type(), ir.Type() 516 | 517 | def _inliner(self): 518 | return ir.Inliner(self.holes, self.globals) 519 | 520 | def _eq(self, got: ir.IR, want: ir.IR): 521 | return ir.Converter(self.holes, self.globals).eq(got, want) 522 | 523 | def _check_with(self, n: Node, typ: ir.IR, *ps: Param[ir.IR]): 524 | self.locals.update({p.name.id: p for p 
in ps}) 525 | ret = self.check(n, typ) 526 | [self.locals.pop(p.name.id, None) for p in ps] 527 | return ret 528 | 529 | def _infer_with(self, n: Node, *ps: Param[ir.IR]): 530 | self.locals.update({p.name.id: p for p in ps}) 531 | v, ty = self.infer(n) 532 | [self.locals.pop(p.name.id, None) for p in ps] 533 | return v, ty 534 | 535 | def _insert_hole(self, loc: int, is_user: bool, typ: ir.IR): 536 | i = fresh() 537 | self.holes[i] = ir.Hole(loc, is_user, self.locals.copy(), ir.Answer(typ)) 538 | return ir.Placeholder(i, is_user) 539 | 540 | def _ctor_return_type(self, loc: int, c: Ctor[ir.IR], d: Data[ir.IR]): 541 | _, ty = ir.from_ctor(c, d) 542 | while isinstance(ty, ir.FnType): 543 | p = ty.param 544 | x = ( 545 | self._insert_hole(loc, False, p.type) 546 | if p.is_implicit 547 | else ir.Ref(p.name) 548 | ) 549 | ty = self._inliner().run_with(ty.ret, (p.name, x)) 550 | return ty 551 | 552 | def _exhaust(self, loc: int, c: Ctor[ir.IR], d: Data[ir.IR], want: ir.IR): 553 | with ir.dirty_holes(self.holes): 554 | if self._eq(self._ctor_return_type(loc, c, d), want): 555 | raise CaseMissError(c.name.text, loc) 556 | 557 | 558 | def _is_solved_class(ty: ir.IR): 559 | if isinstance(ty, ir.Class): 560 | return all(not isinstance(a, ir.Ref) for a in ty.args) 561 | return False 562 | 563 | 564 | def _can_insert_placeholders(ty: ir.IR): 565 | return not isinstance(ty, ir.FnType) or not ty.param.is_implicit 566 | 567 | 568 | def _with_placeholders(f: Node, f_ty: ir.IR, implicit: str | bool) -> Node | None: 569 | if not isinstance(f_ty, ir.FnType): 570 | return None 571 | 572 | if isinstance(implicit, bool): 573 | if not f_ty.param.is_implicit: 574 | return None 575 | return _call_placeholder(f) if not implicit else None 576 | 577 | pending = 0 578 | while True: 579 | if not isinstance(f_ty, ir.FnType) or not f_ty.param.is_implicit: 580 | raise UndefinedVariableError(implicit, f.loc) 581 | if f_ty.param.name.text == implicit: 582 | break 583 | pending += 1 584 | f_ty = 
f_ty.ret 585 | 586 | if not pending: 587 | return None 588 | 589 | for _ in range(pending): 590 | f = _call_placeholder(f) 591 | return f 592 | 593 | 594 | def _call_placeholder(f: Node): 595 | return Call(f.loc, f, Placeholder(f.loc, False), True) 596 | 597 | 598 | check_string = lambda s, md=False: s | Parser(md) | NameResolver() | TypeChecker() 599 | -------------------------------------------------------------------------------- /src/TinyLean/grammar.py: -------------------------------------------------------------------------------- 1 | from pyparsing import * 2 | 3 | ParserElement.enable_packrat() 4 | 5 | COMMENT = Regex(r"/\-(?:[^-]|\-(?!/))*\-\/").set_name("comment") 6 | 7 | IDENT = unicode_set.identifier() 8 | 9 | DEF, EXAMPLE, IND, WHERE, OPEN, TYPE, NOMATCH, MATCH, WITH, UNDER, CLASS, INST = map( 10 | lambda w: Suppress(Keyword(w)), 11 | "def example inductive where open Type nomatch match with _ class instance".split(), 12 | ) 13 | 14 | ASSIGN, ARROW, FUN, TO = map( 15 | lambda s: Suppress(s[0]) | Suppress(s[1:]), "≔:= →-> λfun ↦=>".split() 16 | ) 17 | 18 | LPAREN, RPAREN, LBRACE, RBRACE, LBRACKET, RBRACKET, COLON, BAR, NEWLINE = map( 19 | Suppress, "(){}[]:|\n" 20 | ) 21 | INLINE_WHITE = Opt(Suppress(White(" \t\r"))).set_name("inline_whitespace") 22 | 23 | forwards = lambda names: map(lambda n: Forward().set_name(n), names.split()) 24 | 25 | expr, atom, fn_type, fn, match, nomatch, call, p_expr, type_, ph, ref = forwards( 26 | "expr atom fn_type fn match nomatch call paren_expr type placeholder ref" 27 | ) 28 | case, i_arg, e_arg = forwards("case implicit_arg explicit_arg") 29 | 30 | infix_op = lambda s: (one_of(s), 2, OpAssoc.LEFT) 31 | expr <<= infix_notation(atom, [infix_op("* /"), infix_op("+ -")]) 32 | atom <<= fn_type | fn | match | nomatch | call | p_expr | type_ | ph | ref 33 | 34 | name = Group(IDENT).set_name("name") 35 | i_param = (LBRACE + name + COLON + expr + RBRACE).set_name("implicit_param") 36 | e_param = (LPAREN + name + COLON + expr 
+ RPAREN).set_name("explicit_param") 37 | c_param = (LBRACKET + name + COLON + expr + RBRACKET).set_name("class_param") 38 | param = (i_param | e_param | c_param).set_name("param") 39 | fn_type <<= param + ARROW + expr 40 | fn <<= FUN - Group(OneOrMore(name)) + TO + expr 41 | match <<= MATCH - (type_ | ref | p_expr) + WITH + Group(OneOrMore(case)) 42 | case <<= BAR - ref + Group(ZeroOrMore(name)) + TO + expr 43 | nomatch <<= (NOMATCH - INLINE_WHITE + e_arg).leave_whitespace() 44 | callee = ref | p_expr 45 | call <<= (callee + OneOrMore(INLINE_WHITE + (i_arg | e_arg))).leave_whitespace() 46 | i_arg <<= LPAREN + IDENT + ASSIGN + expr + RPAREN 47 | e_arg <<= (type_ | ph | ref | p_expr).leave_whitespace() 48 | p_expr <<= LPAREN + expr + RPAREN 49 | type_ <<= Group(TYPE) 50 | ph <<= Group(UNDER) 51 | ref <<= Group(name) 52 | 53 | return_type = Opt(COLON + expr) 54 | params = Group(ZeroOrMore(param)) 55 | def_ = (DEF - ref + params + return_type + ASSIGN + expr).set_name("definition") 56 | example = (EXAMPLE - params + return_type + ASSIGN + expr).set_name("example") 57 | type_arg = (LPAREN + ref + ASSIGN + expr + RPAREN).set_name("type_arg") 58 | ctor = (BAR - ref + params + Group(ZeroOrMore(type_arg))).set_name("constructor") 59 | data = (IND - ref + params + WHERE + Group(ZeroOrMore(ctor)) + OPEN + IDENT).set_name( 60 | "datatype" 61 | ) 62 | c_field = (name + COLON + expr).set_name("class_field") 63 | class_ = ( 64 | CLASS - ref + params + WHERE + Group(ZeroOrMore(c_field)) + OPEN + IDENT 65 | ).set_name("class") 66 | i_field = (ref + ASSIGN + expr).set_name("instance_field") 67 | inst = (INST - COLON + expr + WHERE + Group(ZeroOrMore(i_field))).set_name("instance") 68 | declaration = (def_ | example | data | class_ | inst).set_name("declaration") 69 | 70 | program = ZeroOrMore(declaration).ignore(COMMENT).set_name("program") 71 | 72 | line_exact = lambda w: Suppress(AtLineStart(w) + LineEnd()) 73 | markdown = line_exact("```lean") + program + line_exact("```") 74 | 
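The `COMMENT` pattern in `grammar.py` above is an ordinary regular expression, so what it accepts can be checked without pyparsing. The sketch below assumes only Python's standard `re` module; `strip_comments` is a hypothetical helper for illustration, not part of TinyLean, and only approximates the effect of `program.ignore(COMMENT)`, which lets the parser skip comments wherever whitespace may appear.

```python
import re

# Same pattern string as COMMENT in grammar.py: "/-", then a run of
# characters (newlines included) that never contains "-/", then "-/".
COMMENT = re.compile(r"/\-(?:[^-]|\-(?!/))*\-\/")


def strip_comments(src: str) -> str:
    """Hypothetical helper: remove every block comment from `src`."""
    return COMMENT.sub("", src)


if __name__ == "__main__":
    src = (
        "def f (T: Type) {U: Type}: Type := U\n"
        "/- Cannot insert placeholders\n   for an implicit function type. -/\n"
        "example: {U: Type} -> Type := f Type\n"
    )
    print(strip_comments(src))
```

The pattern does not support nesting: in `/- a /- b -/ c -/` the first `-/` closes the comment, leaving ` c -/` behind.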
-------------------------------------------------------------------------------- /src/TinyLean/ir.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from dataclasses import dataclass, field 3 | from functools import reduce as _r 4 | from typing import Optional, cast as _c, OrderedDict 5 | 6 | from . import ( 7 | Name, 8 | Param, 9 | Def, 10 | Data as DataDecl, 11 | Ctor as CtorDecl, 12 | Sig, 13 | Decl, 14 | Class as ClassDecl, 15 | Field as FieldDecl, 16 | Instance, 17 | ) 18 | 19 | 20 | @dataclass(frozen=True) 21 | class IR: ... 22 | 23 | 24 | @dataclass(frozen=True) 25 | class Type(IR): 26 | def __str__(self): 27 | return "Type" 28 | 29 | 30 | @dataclass(frozen=True) 31 | class Ref(IR): 32 | name: Name 33 | 34 | def __str__(self): 35 | return str(self.name) 36 | 37 | 38 | @dataclass(frozen=True) 39 | class FnType(IR): 40 | param: Param[IR] 41 | ret: IR 42 | 43 | def __str__(self): 44 | return f"{self.param} → {self.ret}" 45 | 46 | 47 | @dataclass(frozen=True) 48 | class Fn(IR): 49 | param: Param[IR] 50 | body: IR 51 | 52 | def __str__(self): 53 | return f"λ {self.param} ↦ {self.body}" 54 | 55 | 56 | @dataclass(frozen=True) 57 | class Call(IR): 58 | callee: IR 59 | arg: IR 60 | 61 | def __str__(self): 62 | return f"({self.callee} {self.arg})" 63 | 64 | 65 | @dataclass(frozen=True) 66 | class Placeholder(IR): 67 | id: int 68 | is_user: bool 69 | 70 | def __str__(self): 71 | t = "u" if self.is_user else "m" 72 | return f"?{t}.{self.id}" 73 | 74 | 75 | @dataclass(frozen=True) 76 | class Data(IR): 77 | name: Name 78 | args: list[IR] 79 | 80 | def __str__(self): 81 | s = " ".join(str(x) for x in [self.name, *self.args]) 82 | return f"({s})" if len(self.args) else s 83 | 84 | 85 | @dataclass(frozen=True) 86 | class Ctor(IR): 87 | ty_name: Name 88 | name: Name 89 | args: list[IR] 90 | 91 | def __str__(self): 92 | n = f"{self.ty_name}.{self.name}" 93 | s = " ".join(str(x) for x in [n, 
*self.args]) 94 | return f"({s})" if len(self.args) else s 95 | 96 | 97 | @dataclass(frozen=True) 98 | class Nomatch(IR): 99 | def __str__(self): 100 | return "nomatch" 101 | 102 | 103 | @dataclass(frozen=True) 104 | class Case(IR): 105 | ctor: Name 106 | params: list[Param[IR]] 107 | body: IR 108 | 109 | def __str__(self): 110 | s = " ".join([str(self.ctor), *map(str, self.params)]) 111 | return f"| {s} ↦ {self.body}" 112 | 113 | 114 | @dataclass(frozen=True) 115 | class Match(IR): 116 | arg: IR 117 | cases: dict[int, Case] 118 | 119 | def __str__(self): 120 | cs = " ".join(map(str, self.cases.values())) 121 | return f"match {self.arg} with {cs}" 122 | 123 | 124 | @dataclass(frozen=True) 125 | class Recur(IR): 126 | name: Name 127 | 128 | def __str__(self): 129 | return str(self.name) 130 | 131 | 132 | @dataclass(frozen=True) 133 | class Class(IR): 134 | name: Name 135 | args: list[IR] 136 | 137 | def __str__(self): 138 | s = " ".join(str(x) for x in [self.name, *self.args]) 139 | return f"({s})" if len(self.args) else s 140 | 141 | def is_unsolved(self): 142 | return any(isinstance(a, Ref) for a in self.args) 143 | 144 | 145 | @dataclass(frozen=True) 146 | class Field(IR): 147 | name: Name 148 | type: IR 149 | 150 | def __str__(self): 151 | return str(self.name) 152 | 153 | 154 | @dataclass(frozen=True) 155 | class Renamer: 156 | locals: dict[int, int] = field(default_factory=dict) 157 | 158 | def run(self, v: IR) -> IR: 159 | if isinstance(v, Ref): 160 | if v.name.id in self.locals: 161 | return Ref(Name(v.name.text, self.locals[v.name.id])) 162 | return v 163 | if isinstance(v, Call): 164 | return Call(self.run(v.callee), self.run(v.arg)) 165 | if isinstance(v, Fn): 166 | return Fn(self._param(v.param), self.run(v.body)) 167 | if isinstance(v, FnType): 168 | return FnType(self._param(v.param), self.run(v.ret)) 169 | if isinstance(v, Data): 170 | return Data(v.name, [self.run(x) for x in v.args]) 171 | if isinstance(v, Ctor): 172 | return Ctor(v.ty_name, v.name, 
[self.run(x) for x in v.args]) 173 | if isinstance(v, Match): 174 | arg = self.run(v.arg) 175 | cases = { 176 | i: Case(c.ctor, [self._param(p) for p in c.params], self.run(c.body)) 177 | for i, c in v.cases.items() 178 | } 179 | return Match(arg, cases) 180 | if isinstance(v, Class): 181 | return Class(v.name, [self.run(x) for x in v.args]) 182 | if isinstance(v, Field): 183 | return Field(v.name, self.run(v.type)) 184 | assert any(isinstance(v, c) for c in (Type, Placeholder, Nomatch, Recur)) 185 | return v 186 | 187 | def _param(self, p: Param[IR]): 188 | name = Name(p.name.text) 189 | self.locals[p.name.id] = name.id 190 | return Param(name, self.run(p.type), p.is_implicit, p.is_class) 191 | 192 | 193 | _rn = lambda v: Renamer().run(v) 194 | 195 | 196 | def _to(p: list[Param[IR]], v: IR, t=False): 197 | return _r(lambda a, q: _c(IR, FnType(q, a) if t else Fn(q, a)), reversed(p), v) 198 | 199 | 200 | def from_def(d: Def[IR]): 201 | return _rn(_to(d.params, d.body)), _rn(_to(d.params, d.ret, True)) 202 | 203 | 204 | def from_sig(s: Sig[IR]): 205 | return Recur(s.name), _rn(_to(s.params, s.ret, True)) 206 | 207 | 208 | def from_data(d: DataDecl[IR]): 209 | args = [Ref(p.name) for p in d.params] 210 | return _rn(_to(d.params, Data(d.name, args))), _rn(_to(d.params, Type(), True)) 211 | 212 | 213 | def from_ctor(c: CtorDecl[IR], d: DataDecl[IR]): 214 | adhoc = {x.name.id: v for x, v in _c(dict[Ref, IR], c.ty_args)} 215 | miss = [Param(p.name, p.type, True) for p in d.params if p.name.id not in adhoc] 216 | 217 | v = _to(c.params, Ctor(d.name, c.name, [Ref(p.name) for p in c.params])) 218 | v = _to(miss, v) 219 | 220 | ty_args = [adhoc.get(p.name.id, Ref(p.name)) for p in d.params] 221 | ty = _to(c.params, Data(d.name, ty_args), True) 222 | ty = _to(miss, ty, True) 223 | 224 | return _rn(v), _rn(ty) 225 | 226 | 227 | def from_class(c: ClassDecl[IR]): 228 | args = [Ref(p.name) for p in c.params] 229 | return _rn(_to(c.params, Class(c.name, args))), _rn(_to(c.params, 
Type(), True)) 230 | 231 | 232 | def from_field(f: FieldDecl[IR], c: ClassDecl[IR], has_c_param=True): 233 | t = Class(c.name, [Ref(p.name) for p in c.params]) 234 | ps = [*c.params, Param(Name("inst"), t, True, True)] if has_c_param else c.params 235 | return _rn(_to(ps, Field(f.name, t))), _rn(_to(ps, f.type, True)) 236 | 237 | 238 | @dataclass 239 | class Answer: 240 | type: IR 241 | value: Optional[IR] = None 242 | 243 | def is_unsolved(self): 244 | return self.value is None 245 | 246 | 247 | @dataclass(frozen=True) 248 | class Hole: 249 | loc: int 250 | is_user: bool 251 | locals: dict[int, Param[IR]] 252 | answer: Answer 253 | 254 | 255 | @contextmanager 256 | def dirty_holes(holes: OrderedDict[int, Hole]): 257 | n = len(holes) 258 | try: 259 | yield 260 | finally: 261 | [holes.popitem() for _ in range(len(holes) - n)] 262 | 263 | 264 | class NoInstanceError(Exception): ... 265 | 266 | 267 | @dataclass 268 | class Inliner: 269 | holes: OrderedDict[int, Hole] 270 | globals: dict[int, Decl] 271 | can_recurse: bool = True 272 | env: dict[int, IR] = field(default_factory=dict) 273 | 274 | def run(self, v: IR) -> IR: 275 | if isinstance(v, Ref): 276 | return self.run(_rn(self.env[v.name.id])) if v.name.id in self.env else v 277 | if isinstance(v, Call): 278 | f = self.run(v.callee) 279 | x = self.run(v.arg) 280 | if isinstance(f, Fn): 281 | return self.run_with(f.body, (f.param.name, x)) 282 | return Call(f, x) 283 | if isinstance(v, Fn): 284 | return Fn(self._param(v.param), self.run(v.body)) 285 | if isinstance(v, FnType): 286 | return FnType(self._param(v.param), self.run(v.ret)) 287 | if isinstance(v, Placeholder): 288 | h = self.holes[v.id] 289 | h.answer.type = self.run(h.answer.type) 290 | return v if h.answer.is_unsolved() else self.run(h.answer.value) 291 | if isinstance(v, Ctor): 292 | return Ctor(v.ty_name, v.name, [self.run(x) for x in v.args]) 293 | if isinstance(v, Data): 294 | return Data(v.name, [self.run(x) for x in v.args]) 295 | if isinstance(v,
Match): 296 | arg = self.run(v.arg) 297 | can_recurse = self.can_recurse 298 | self.can_recurse = False 299 | cases = { 300 | i: Case(c.ctor, [self._param(p) for p in c.params], self.run(c.body)) 301 | for i, c in v.cases.items() 302 | } 303 | self.can_recurse = can_recurse 304 | if not isinstance(arg, Ctor): 305 | return Match(arg, cases) 306 | c = cases[arg.name.id] 307 | env = [(x.name, v) for x, v in zip(c.params, arg.args)] 308 | return self.run_with(c.body, *env) 309 | if isinstance(v, Recur): 310 | if self.can_recurse: 311 | d = self.globals[v.name.id] 312 | if isinstance(d, Def): 313 | return from_def(d)[0] 314 | assert isinstance(d, Sig) 315 | return v 316 | if isinstance(v, Class): 317 | return Class(v.name, [self.run(t) for t in v.args]) 318 | if isinstance(v, Field): 319 | c = _c(Class, self.run(v.type)) 320 | if c.is_unsolved(): 321 | return Field(v.name, c) 322 | i = self._resolve_instance(c) 323 | val = next(val for n, val in i.fields if _c(Ref, n).name.id == v.name.id) 324 | return self.run(val) 325 | assert isinstance(v, Type) or isinstance(v, Nomatch) 326 | return v 327 | 328 | def run_with(self, x: IR, *env: tuple[Name, IR]): 329 | self.env.update({n.id: v for n, v in env}) 330 | return self.run(x) 331 | 332 | def apply(self, f: IR, *args: IR): 333 | ret = f 334 | for x in args: 335 | if isinstance(ret, Fn): 336 | ret = self.run_with(ret.body, (ret.param.name, x)) 337 | else: 338 | ret = Call(ret, x) 339 | return ret 340 | 341 | def _param(self, param: Param[IR]): 342 | p = Param(param.name, self.run(param.type), param.is_implicit, param.is_class) 343 | if not p.is_class: 344 | return p 345 | ty = _c(Class, p.type) 346 | if not ty.is_unsolved() and not self._resolve_instance(ty): 347 | raise NoInstanceError(str(ty), self.globals[ty.name.id].loc) 348 | return p 349 | 350 | def _resolve_instance(self, c: Class) -> Optional[Instance[IR]]: 351 | cls = _c(ClassDecl, self.globals[c.name.id]) 352 | for inst_id in cls.instances: 353 | i = _c(Instance, 
self.globals[inst_id]) 354 | with dirty_holes(self.holes): 355 | if Converter(self.holes, self.globals).eq(c, i.type): 356 | return i 357 | return None 358 | 359 | 360 | @dataclass(frozen=True) 361 | class Converter: 362 | holes: OrderedDict[int, Hole] 363 | globals: dict[int, Decl] 364 | 365 | def eq(self, lhs: IR, rhs: IR): 366 | match lhs, rhs: 367 | case Placeholder() as x, y: 368 | return self._solve(x, y) 369 | case x, Placeholder() as y: 370 | return self._solve(y, x) 371 | case Ref(x), Ref(y): 372 | return x.id == y.id 373 | case Call(f, x), Call(g, y): 374 | return self.eq(f, g) and self.eq(x, y) 375 | case Fn(p, b), Fn(q, c): 376 | env = [(q.name, Ref(p.name))] 377 | return self.eq(b, Inliner(self.holes, self.globals).run_with(c, *env)) 378 | case FnType(p, b), FnType(q, c): 379 | if not self.eq(p.type, q.type): 380 | return False 381 | env = [(q.name, Ref(p.name))] 382 | return self.eq(b, Inliner(self.holes, self.globals).run_with(c, *env)) 383 | case Data(x, xs), Data(y, ys): 384 | return x.id == y.id and self._args(xs, ys) 385 | case Ctor(t, x, xs), Ctor(u, y, ys): 386 | return t.id == u.id and x.id == y.id and self._args(xs, ys) 387 | case Type(), Type(): 388 | return True 389 | case Class(x, xs), Class(y, ys): 390 | return x.id == y.id and self._args(xs, ys) 391 | 392 | # FIXME: Following cases not seen in tests yet: 393 | assert not (isinstance(lhs, Placeholder) and isinstance(rhs, Placeholder)) 394 | assert not (isinstance(lhs, Match) and isinstance(rhs, Match)) 395 | assert not (isinstance(lhs, Field) and isinstance(rhs, Field)) 396 | 397 | return False 398 | 399 | def _solve(self, p: Placeholder, answer: IR): 400 | h = self.holes[p.id] 401 | if not h.answer.is_unsolved(): 402 | return self.eq(h.answer.value, answer) 403 | h.answer.value = answer 404 | 405 | if isinstance(answer, Ref): 406 | for param in h.locals.values(): 407 | if param.name.id == answer.name.id: 408 | assert self.eq(param.type, h.answer.type) # FIXME: will fail here? 
409 | 410 | return True 411 | 412 | def _args(self, xs: list[IR], ys: list[IR]): 413 | assert len(xs) == len(ys) 414 | return all(self.eq(x, y) for x, y in zip(xs, ys)) 415 | -------------------------------------------------------------------------------- /src/TinyLean/tests/__init__.py: -------------------------------------------------------------------------------- 1 | from pyparsing import ParserElement 2 | 3 | from .. import ast, grammar 4 | 5 | 6 | def parse(g: ParserElement, text: str): 7 | return g.parse_string(text, parse_all=True) 8 | 9 | 10 | def resolve(s: str): 11 | return s | ast.Parser() | ast.NameResolver() 12 | 13 | 14 | def resolve_md(s: str): 15 | return s | ast.Parser(True) | ast.NameResolver() 16 | 17 | 18 | def resolve_expr(s: str): 19 | return ast.NameResolver().expr(parse(grammar.expr, s)[0]) 20 | -------------------------------------------------------------------------------- /src/TinyLean/tests/onboard.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anqur/TinyLean/312b3eacfe8e0489971d871eaca886ee1d737246/src/TinyLean/tests/onboard.py -------------------------------------------------------------------------------- /src/TinyLean/tests/test_checker.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from . import resolve_expr 4 | from .. 
import ast, Name, Param, ir, Data, Example, Def, Class, Instance 5 | 6 | check_expr = lambda s, t: ast.TypeChecker().check(resolve_expr(s), t) 7 | infer_expr = lambda s: ast.TypeChecker().infer(resolve_expr(s)) 8 | 9 | 10 | class TestTypeChecker(TestCase): 11 | def test_check_expr_type(self): 12 | check_expr("Type", ir.Type()) 13 | check_expr("{a: Type} -> (b: Type) -> a", ir.Type()) 14 | 15 | def test_check_expr_type_failed(self): 16 | with self.assertRaises(ast.TypeMismatchError) as e: 17 | check_expr("fun a => a", ir.Type()) 18 | want, got, loc = e.exception.args 19 | self.assertEqual(0, loc) 20 | self.assertEqual("Type", want) 21 | self.assertEqual("function", got) 22 | 23 | def test_check_expr_function(self): 24 | check_expr( 25 | "fun a => a", 26 | ir.FnType(Param(Name("a"), ir.Type(), False), ir.Type()), 27 | ) 28 | 29 | def test_check_expr_on_infer(self): 30 | check_expr("Type", ir.Type()) 31 | 32 | def test_check_expr_on_infer_failed(self): 33 | with self.assertRaises(ast.TypeMismatchError) as e: 34 | check_expr("(a: Type) -> a", ir.Ref(Name("a"))) 35 | want, got, loc = e.exception.args 36 | self.assertEqual(0, loc) 37 | self.assertEqual("a", want) 38 | self.assertEqual("Type", got) 39 | 40 | def test_infer_expr_type(self): 41 | v, ty = infer_expr("Type") 42 | assert isinstance(v, ir.Type) 43 | assert isinstance(ty, ir.Type) 44 | 45 | def test_infer_expr_call_failed(self): 46 | with self.assertRaises(ast.TypeMismatchError) as e: 47 | infer_expr("(Type) Type") 48 | want, got, loc = e.exception.args 49 | self.assertEqual(1, loc) 50 | self.assertEqual("function", want) 51 | self.assertEqual("Type", got) 52 | 53 | def test_infer_expr_function_type(self): 54 | v, ty = infer_expr("{a: Type} -> a") 55 | assert isinstance(v, ir.FnType) 56 | self.assertEqual("{a: Type} → a", str(v)) 57 | assert isinstance(ty, ir.Type) 58 | 59 | def test_check_program(self): 60 | ast.check_string("def a: Type := Type") 61 | ast.check_string("def f (a: Type): Type := a") 62 | 
ast.check_string("def f: (_: Type) -> Type := fun a => a") 63 | ast.check_string("def id (T: Type) (a: T): T := a") 64 | 65 | def test_check_program_failed(self): 66 | with self.assertRaises(ast.TypeMismatchError) as e: 67 | ast.check_string("def id (a: Type): a := Type") 68 | want, got, loc = e.exception.args 69 | self.assertEqual(23, loc) 70 | self.assertEqual("a", want) 71 | self.assertEqual("Type", got) 72 | 73 | def test_check_program_call(self): 74 | ast.check_string( 75 | """ 76 | def f0 (a: Type): Type := a 77 | def f1: Type := f0 Type 78 | def f2: f0 Type := Type 79 | """ 80 | ) 81 | 82 | def test_check_program_call_failed(self): 83 | with self.assertRaises(ast.TypeMismatchError) as e: 84 | ast.check_string( 85 | """ 86 | def f0 (a: Type): Type := a 87 | def f1 (a: Type): Type := f0 88 | """ 89 | ) 90 | want, got, loc = e.exception.args 91 | self.assertEqual(87, loc) 92 | self.assertEqual("Type", want) 93 | self.assertEqual("(a: Type) → Type", got) 94 | 95 | def test_check_program_placeholder(self): 96 | ast.check_string( 97 | """ 98 | def a := Type 99 | def b: Type := a 100 | """ 101 | ) 102 | 103 | def test_check_program_placeholder_locals(self): 104 | ast.check_string("def f (T: Type) (a: T) := a") 105 | 106 | def test_check_program_placeholder_unsolved(self): 107 | with self.assertRaises(ast.UnsolvedPlaceholderError) as e: 108 | ast.check_string("def a: Type := _") 109 | name, ctx, ty, loc = e.exception.args 110 | self.assertTrue(name.startswith("?u")) 111 | self.assertEqual(0, len(ctx)) 112 | assert isinstance(ty, ir.Type) 113 | self.assertEqual(15, loc) 114 | 115 | def test_check_program_call_implicit_arg(self): 116 | _, _, example = ast.check_string( 117 | """ 118 | def id {T: Type} (a: T): T := a 119 | def f := id (T := Type) Type 120 | example := f 121 | """ 122 | ) 123 | assert isinstance(example.body, ir.Type) 124 | 125 | def test_check_program_call_implicit_arg_failed(self): 126 | with self.assertRaises(ast.UndefinedVariableError) as e: 127 | 
ast.check_string( 128 | """ 129 | def id {T: Type} (a: T): T := a 130 | def f := id (U := Type) Type 131 | """ 132 | ) 133 | name, loc = e.exception.args 134 | self.assertEqual("U", name) 135 | self.assertEqual(74, loc) 136 | 137 | def test_check_program_call_implicit_arg_long(self): 138 | ast.check_string( 139 | """ 140 | def f {T: Type} {U: Type} (a: U): Type := T 141 | def g: f (U := Type) Type := Type 142 | """ 143 | ) 144 | 145 | def test_check_program_call_implicit(self): 146 | _, _, example = ast.check_string( 147 | """ 148 | def id {T: Type} (a: T): T := a 149 | def f := id Type 150 | example := f 151 | """ 152 | ) 153 | assert isinstance(example.body, ir.Type) 154 | 155 | def test_check_program_call_no_explicit_failed(self): 156 | with self.assertRaises(ast.UnsolvedPlaceholderError) as e: 157 | ast.check_string( 158 | """ 159 | def f {T: Type}: Type := T 160 | def g: Type := f 161 | """ 162 | ) 163 | name, ctx, ty, loc = e.exception.args 164 | self.assertTrue(name.startswith("?m")) 165 | self.assertEqual(1, len(ctx)) 166 | assert isinstance(ty, ir.Type) 167 | self.assertEqual(75, loc) 168 | 169 | def test_check_program_call_mixed_implicit(self): 170 | ast.check_string( 171 | """ 172 | def f (T: Type) {U: Type}: Type := U 173 | /- Cannot insert placeholders for an implicit function type. 
-/ 174 | example: {U: Type} -> Type := f Type 175 | """ 176 | ) 177 | 178 | def test_check_program_datatype_nat(self): 179 | x, _, _, _2 = ast.check_string( 180 | """ 181 | inductive N where 182 | | Z 183 | | S (n: N) 184 | open N 185 | 186 | example: N := Z 187 | example: N := S Z 188 | example: N := S (S Z) 189 | """ 190 | ) 191 | assert isinstance(x, Data) 192 | self.assertEqual(2, len(x.ctors)) 193 | 194 | n_v, n_ty = ir.from_data(x) 195 | self.assertEqual("N", str(n_v)) 196 | self.assertEqual("Type", str(n_ty)) 197 | 198 | z_v, z_ty = ir.from_ctor(x.ctors[0], x) 199 | self.assertEqual("N.Z", str(z_v)) 200 | self.assertEqual("N", str(z_ty)) 201 | 202 | s_v, s_ty = ir.from_ctor(x.ctors[1], x) 203 | self.assertEqual("λ (n: N) ↦ (N.S n)", str(s_v)) 204 | self.assertEqual("(n: N) → N", str(s_ty)) 205 | 206 | assert isinstance(_2, Example) 207 | self.assertEqual("(N.S (N.S N.Z))", str(_2.body)) 208 | 209 | def test_check_program_datatype_nat_failed(self): 210 | with self.assertRaises(ast.TypeMismatchError) as e: 211 | ast.check_string( 212 | """ 213 | inductive N where 214 | | Z 215 | | S (n: N) 216 | open N 217 | 218 | example: Type := Z 219 | """ 220 | ) 221 | want, got, loc = e.exception.args 222 | self.assertEqual("Type", want) 223 | self.assertEqual("N", got) 224 | self.assertEqual(139, loc) 225 | 226 | def test_check_program_datatype_maybe(self): 227 | x = ast.check_string( 228 | """ 229 | inductive Maybe (A: Type) where 230 | | Nothing 231 | | Just (a: A) 232 | open Maybe 233 | 234 | example: Maybe Type := Nothing 235 | example: Maybe Type := Just Type 236 | """ 237 | )[0] 238 | assert isinstance(x, Data) 239 | self.assertEqual(2, len(x.ctors)) 240 | 241 | maybe_v, maybe_ty = ir.from_data(x) 242 | self.assertEqual("λ (A: Type) ↦ (Maybe A)", str(maybe_v)) 243 | self.assertEqual("(A: Type) → Type", str(maybe_ty)) 244 | 245 | nothing_v, nothing_ty = ir.from_ctor(x.ctors[0], x) 246 | self.assertEqual("λ {A: Type} ↦ Maybe.Nothing", str(nothing_v)) 247 | 
self.assertEqual("{A: Type} → (Maybe A)", str(nothing_ty)) 248 | 249 | just_v, just_ty = ir.from_ctor(x.ctors[1], x) 250 | self.assertEqual("λ {A: Type} ↦ λ (a: A) ↦ (Maybe.Just a)", str(just_v)) 251 | self.assertEqual("{A: Type} → (a: A) → (Maybe A)", str(just_ty)) 252 | 253 | assert isinstance(just_v, ir.Fn) 254 | assert isinstance(just_v.body, ir.Fn) 255 | assert isinstance(just_v.body.body, ir.Ctor) 256 | just_arg = just_v.body.body.args[0] 257 | assert isinstance(just_arg, ir.Ref) 258 | self.assertEqual(just_v.body.param.name.id, just_arg.name.id) 259 | 260 | assert isinstance(just_ty, ir.FnType) 261 | assert isinstance(just_ty.ret, ir.FnType) 262 | assert isinstance(just_ty.ret.ret, ir.Data) 263 | just_ty_arg = just_ty.ret.ret.args[0] 264 | assert isinstance(just_ty_arg, ir.Ref) 265 | self.assertEqual(just_ty.param.name.id, just_ty_arg.name.id) 266 | 267 | def test_check_program_datatype_maybe_unsolved(self): 268 | with self.assertRaises(ast.UnsolvedPlaceholderError) as e: 269 | ast.check_string( 270 | """ 271 | inductive Maybe (A: Type) where 272 | | Nothing 273 | | Just (a: A) 274 | open Maybe 275 | 276 | example := Nothing 277 | """ 278 | ) 279 | name, ctx, ty, loc = e.exception.args 280 | self.assertTrue(name.startswith("?m")) 281 | self.assertEqual(1, len(ctx)) 282 | assert isinstance(ty, ir.Type) 283 | self.assertEqual(160, loc) 284 | 285 | def test_check_program_datatype_maybe_failed(self): 286 | with self.assertRaises(ast.TypeMismatchError) as e: 287 | ast.check_string( 288 | """ 289 | inductive Maybe (A: Type) where 290 | | Nothing 291 | | Just (a: A) 292 | open Maybe 293 | 294 | inductive A where | AA open A 295 | inductive B where | BB open B 296 | example: Maybe B := Just AA 297 | """ 298 | ) 299 | want, got, _ = e.exception.args 300 | self.assertEqual("(Maybe B)", want) 301 | self.assertEqual("(Maybe A)", got) 302 | 303 | def test_check_program_datatype_vec(self): 304 | x = ast.check_string( 305 | """ 306 | inductive N where 307 | | Z 308 | | S 
(n: N) 309 | open N 310 | 311 | inductive Vec (A: Type) (n: N) where 312 | | Nil (n := Z) 313 | | Cons {m: N} (a: A) (v: Vec A m) (n := S m) 314 | open Vec 315 | 316 | /- This will emit some "dirty" placeholders. -/ 317 | def v1: Vec N (S Z) := Cons Z Nil 318 | """ 319 | )[1] 320 | assert isinstance(x, Data) 321 | self.assertEqual(2, len(x.ctors)) 322 | 323 | vec_v, vec_ty = ir.from_data(x) 324 | self.assertEqual("λ (A: Type) ↦ λ (n: N) ↦ (Vec A n)", str(vec_v)) 325 | self.assertEqual("(A: Type) → (n: N) → Type", str(vec_ty)) 326 | 327 | nil_v, nil_ty = ir.from_ctor(x.ctors[0], x) 328 | self.assertEqual("λ {A: Type} ↦ Vec.Nil", str(nil_v)) 329 | self.assertEqual("{A: Type} → (Vec A N.Z)", str(nil_ty)) 330 | 331 | cons_v, cons_ty = ir.from_ctor(x.ctors[1], x) 332 | self.assertEqual( 333 | "λ {A: Type} ↦ λ {m: N} ↦ λ (a: A) ↦ λ (v: (Vec A m)) ↦ (Vec.Cons m a v)", 334 | str(cons_v), 335 | ) 336 | self.assertEqual( 337 | "{A: Type} → {m: N} → (a: A) → (v: (Vec A m)) → (Vec A (N.S m))", 338 | str(cons_ty), 339 | ) 340 | 341 | def test_check_program_ctor_eq(self): 342 | ast.check_string( 343 | """ 344 | def Eq {T: Type} (a: T) (b: T): Type := (p: (v: T) -> Type) -> (pa: p a) -> p b 345 | def refl {T: Type} (a: T): Eq a a := fun p pa => pa 346 | inductive A where | AA open A 347 | example: Eq AA AA := refl AA 348 | """ 349 | ) 350 | 351 | def test_check_program_nomatch(self): 352 | _, _, e = ast.check_string( 353 | """ 354 | inductive Bottom where open Bottom 355 | def elimBot {A: Type} (x: Bottom): A := nomatch x 356 | example (x: Bottom): Type := elimBot x 357 | """ 358 | ) 359 | assert isinstance(e, Example) 360 | assert isinstance(e.body, ir.Nomatch) 361 | self.assertEqual("nomatch", str(e.body)) 362 | 363 | def test_check_program_nomatch_non_data_failed(self): 364 | with self.assertRaises(ast.TypeMismatchError) as e: 365 | ast.check_string("example := nomatch Type") 366 | want, got, loc = e.exception.args 367 | self.assertEqual("datatype", want) 368 | 
self.assertEqual("Type", got) 369 | self.assertEqual(19, loc) 370 | 371 | def test_check_program_nomatch_non_empty_failed(self): 372 | with self.assertRaises(ast.CaseMissError) as e: 373 | ast.check_string( 374 | """ 375 | inductive A where | AA open A 376 | example := nomatch AA 377 | """ 378 | ) 379 | name, loc = e.exception.args 380 | self.assertEqual("AA", name) 381 | self.assertEqual(82, loc) 382 | 383 | def test_check_program_nomatch_dpm(self): 384 | ast.check_string( 385 | """ 386 | inductive N where 387 | | Z 388 | | S (n: N) 389 | open N 390 | 391 | inductive T (n: N) where 392 | | MkT (n := Z) 393 | open T 394 | 395 | example (x: T (S Z)): Type := nomatch x 396 | """ 397 | ) 398 | 399 | def test_check_program_nomatch_eq_failed(self): 400 | text = """ 401 | def Eq {T: Type} (a: T) (b: T): Type := (p: (v: T) -> Type) -> (pa: p a) -> p b 402 | def refl {T: Type} (a: T): Eq a a := fun p pa => pa 403 | inductive Bottom where open Bottom 404 | def a (x: Bottom): Type := nomatch x 405 | def b (x: Bottom): Type := nomatch x 406 | /- (a x) and (b x) should not be the same thing, apparently. 
-/ 407 | example (x: Bottom): Eq (a x) (b x) := refl (a x) 408 | """ 409 | with self.assertRaises(ast.TypeMismatchError) as e: 410 | ast.check_string(text) 411 | want, got, loc = e.exception.args 412 | self.assertEqual( 413 | "(p: (v: Type) → Type) → (pa: (p nomatch)) → (p nomatch)", want 414 | ) 415 | self.assertEqual("(p: (v: Type) → Type) → (pa: (p nomatch)) → (p nomatch)", got) 416 | self.assertEqual(text.index("refl (a x)"), loc) 417 | 418 | def test_check_program_match(self): 419 | _, f, e = ast.check_string( 420 | """ 421 | inductive V where 422 | | A (x: Type) (y: Type) 423 | | B (x: Type) 424 | open V 425 | 426 | def f (v: V): Type := 427 | match v with 428 | | A x y => x 429 | | B x => x 430 | 431 | example := f (A Type Type) 432 | """ 433 | ) 434 | assert isinstance(f, Def) 435 | self.assertEqual( 436 | "match v with | A (x: Type) (y: Type) ↦ x | B (x: Type) ↦ x", str(f.body) 437 | ) 438 | assert isinstance(e, Example) 439 | assert isinstance(e.body, ir.Type) 440 | 441 | def test_check_program_match_dpm(self): 442 | ast.check_string( 443 | """ 444 | inductive N where 445 | | Z 446 | | S (n: N) 447 | open N 448 | 449 | inductive Vec (A: Type) (n: N) where 450 | | Nil (n := Z) 451 | | Cons {m: N} (a: A) (v: Vec A m) (n := S m) 452 | open Vec 453 | 454 | def v0: Vec N Z := Nil 455 | 456 | example := 457 | match v0 with 458 | | Nil => Z 459 | """ 460 | ) 461 | 462 | def test_check_program_match_dpm_failed(self): 463 | text = """ 464 | inductive N where 465 | | Z 466 | | S (n: N) 467 | open N 468 | 469 | inductive Vec (A: Type) (n: N) where 470 | | Nil (n := Z) 471 | | Cons {m: N} (a: A) (v: Vec A m) (n := S m) 472 | open Vec 473 | 474 | def v0: Vec N Z := Nil 475 | 476 | example := 477 | match v0 with 478 | | Nil => Z 479 | | Cons a v => Z 480 | """ 481 | want_loc = text.index("| Cons a v") + 2 482 | with self.assertRaises(ast.TypeMismatchError) as e: 483 | ast.check_string(text) 484 | want, got, loc = e.exception.args 485 | self.assertIn("N.Z", want) 486 | 
self.assertIn("N.S", got) 487 | self.assertEqual(want_loc, loc) 488 | 489 | def test_check_program_match_type_failed(self): 490 | with self.assertRaises(ast.TypeMismatchError) as e: 491 | ast.check_string( 492 | """ 493 | inductive A where | AA open A 494 | example := 495 | match Type with 496 | | AA => AA 497 | """ 498 | ) 499 | want, got, loc = e.exception.args 500 | self.assertEqual("datatype", want) 501 | self.assertEqual("Type", got) 502 | self.assertEqual(98, loc) 503 | 504 | def test_check_program_match_unknown_case_failed(self): 505 | with self.assertRaises(ast.UnknownCaseError) as e: 506 | ast.check_string( 507 | """ 508 | inductive A where | AA open A 509 | inductive B where | BB open B 510 | example (x: A) := 511 | match x with 512 | | BB => AA 513 | """ 514 | ) 515 | want, got, loc = e.exception.args 516 | self.assertEqual("A", want) 517 | self.assertEqual("BB", got) 518 | self.assertEqual(178, loc) 519 | 520 | def test_check_program_match_duplicate_case_failed(self): 521 | text = """ 522 | inductive A where | AA open A 523 | 524 | example (x: A): Type := 525 | match x with 526 | | AA => (a: Type) -> Type 527 | | AA => Type 528 | """ 529 | with self.assertRaises(ast.DuplicateCaseError) as e: 530 | ast.check_string(text) 531 | name, loc = e.exception.args 532 | self.assertEqual("AA", name) 533 | self.assertEqual(text.index("AA => Type"), loc) 534 | 535 | def test_check_program_match_param_mismatch_failed(self): 536 | text = """ 537 | inductive A where | AA open A 538 | example (x: A): Type := 539 | match x with 540 | | AA a => AA 541 | """ 542 | with self.assertRaises(ast.CaseParamMismatchError) as e: 543 | ast.check_string(text) 544 | want, got, loc = e.exception.args 545 | self.assertEqual(0, want) 546 | self.assertEqual(1, got) 547 | self.assertEqual(text.index("AA a"), loc) 548 | 549 | def test_check_program_match_miss_failed(self): 550 | text = """ 551 | inductive A where | AA | BB open A 552 | example (x: A): Type := 553 | match x with 554 | | AA 
=> AA 555 | """ 556 | with self.assertRaises(ast.CaseMissError) as e: 557 | ast.check_string(text) 558 | name, loc = e.exception.args 559 | self.assertEqual("BB", name) 560 | self.assertEqual(text.index("match x with"), loc) 561 | 562 | def test_check_program_match_inline(self): 563 | ast.check_string( 564 | """ 565 | inductive A where | AA open A 566 | def f (x: A) := 567 | match x with 568 | | AA => AA 569 | def g (x: A) := f x /- match expression not inlined yet -/ 570 | """ 571 | ) 572 | 573 | def test_check_program_eq(self): 574 | ast.check_string( 575 | """ 576 | inductive N where 577 | | Z 578 | | S (n: N) 579 | open N 580 | 581 | inductive Eq {T: Type} (a: T) (b: T) where 582 | | Refl (a := b) 583 | open Eq 584 | 585 | example: Eq (S Z) (S Z) := Refl (T := N) 586 | """ 587 | ) 588 | 589 | def test_check_program_eq_failed(self): 590 | text = """ 591 | inductive N where 592 | | Z 593 | | S (n: N) 594 | open N 595 | 596 | inductive Eq {T: Type} (a: T) (b: T) where 597 | | Refl (a := b) 598 | open Eq 599 | 600 | example: Eq Z (S Z) := Refl (T := N) 601 | """ 602 | with self.assertRaises(ast.TypeMismatchError) as e: 603 | ast.check_string(text) 604 | want, got, loc = e.exception.args 605 | self.assertEqual("(Eq N N.Z (N.S N.Z))", want) 606 | got = " ".join(["_" if "?m." 
in s else s for s in got[1:-1].split()]) 607 | self.assertEqual("Eq N _ _", got) 608 | self.assertEqual(text.index("Refl (T := N)"), loc) 609 | 610 | def test_check_program_recurse(self): 611 | _, add, e = ast.check_string( 612 | """ 613 | inductive N where 614 | | Z 615 | | S (n: N) 616 | open N 617 | 618 | def add (n: N) (m: N): N := 619 | match n with 620 | | Z => m 621 | | S pred => S (add pred m) 622 | 623 | example := add (S Z) (S Z) 624 | """ 625 | ) 626 | assert isinstance(add, Def) 627 | self.assertEqual( 628 | "match n with | Z ↦ m | S (pred: N) ↦ (N.S ((add pred) m))", str(add.body) 629 | ) 630 | assert isinstance(e, Example) 631 | self.assertEqual("(N.S (N.S N.Z))", str(e.body)) 632 | 633 | def test_check_program_class(self): 634 | ast.check_string( 635 | """ 636 | class C where open C 637 | example [p: C] := Type 638 | """ 639 | ) 640 | 641 | def test_check_program_class_param_failed(self): 642 | text = "example [p: Type] := Type" 643 | with self.assertRaises(ast.TypeMismatchError) as e: 644 | ast.check_string(text) 645 | want, got, loc = e.exception.args 646 | self.assertEqual("class", want) 647 | self.assertEqual("Type", got) 648 | self.assertEqual(text.index("Type]"), loc) 649 | 650 | def test_check_program_class_stuck(self): 651 | _, _, e = ast.check_string( 652 | """ 653 | class Default (T: Type) where 654 | default: T 655 | open Default 656 | def f (U: Type) [p: Default U] := default U (inst := p) 657 | example (V: Type) [q: Default V] := f V (p := q) 658 | """ 659 | ) 660 | assert isinstance(e, Example) 661 | self.assertEqual("default", str(e.body)) 662 | 663 | def test_check_program_class_failed(self): 664 | text = """ 665 | class C where open C 666 | def f [p: C] := Type 667 | example := f 668 | """ 669 | with self.assertRaises(ir.NoInstanceError) as e: 670 | ast.check_string(text) 671 | got, loc = e.exception.args 672 | self.assertEqual("C", got) 673 | self.assertEqual(text.index("C where"), loc) 674 | 675 | def 
test_check_program_instance(self): 676 | c, i, _, _ = ast.check_string( 677 | """ 678 | class C where open C 679 | instance: C 680 | where 681 | def f [p: C] := Type 682 | example := f 683 | """ 684 | ) 685 | assert isinstance(c, Class) 686 | assert isinstance(i, Instance) 687 | self.assertEqual(c.instances[0], i.id) 688 | 689 | def test_check_program_instance_miss_failed(self): 690 | text = """ 691 | class C where 692 | c: Type 693 | open C 694 | instance: C 695 | where 696 | """ 697 | with self.assertRaises(ast.FieldMissError) as e: 698 | ast.check_string(text) 699 | name, loc = e.exception.args 700 | self.assertEqual("c", name) 701 | self.assertEqual(text.index("instance: C"), loc) 702 | 703 | def test_check_program_instance_unknown_failed(self): 704 | text = """ 705 | class A where 706 | a: Type 707 | open A 708 | class B where 709 | b: Type 710 | open B 711 | instance: A 712 | where 713 | a := Type 714 | b := Type 715 | """ 716 | with self.assertRaises(ast.UnknownFieldError) as e: 717 | ast.check_string(text) 718 | want, got, loc = e.exception.args 719 | self.assertEqual("A", want) 720 | self.assertEqual("b", got) 721 | self.assertEqual(text.index("b :="), loc) 722 | 723 | def test_check_program_instance_mismatch_failed(self): 724 | text = "instance: Type\nwhere" 725 | with self.assertRaises(ast.TypeMismatchError) as e: 726 | ast.check_string(text) 727 | want, got, loc = e.exception.args 728 | self.assertEqual("class", want) 729 | self.assertEqual("Type", got) 730 | self.assertEqual(text.index("Type"), loc) 731 | 732 | def test_check_program_field(self): 733 | _, _, _, e = ast.check_string( 734 | """ 735 | inductive Void where open Void 736 | 737 | class C where 738 | c: Type 739 | open C 740 | 741 | instance: C 742 | where 743 | c := Void 744 | 745 | example: Type := c 746 | """ 747 | ) 748 | assert isinstance(e, Example) 749 | assert isinstance(e.body, ir.Data) 750 | self.assertEqual("Void", e.body.name.text) 751 | 752 | def 
test_check_program_field_parametric(self): 753 | _, _, _, d = ast.check_string( 754 | """ 755 | class Default (T: Type) where 756 | default: T 757 | open Default 758 | 759 | inductive Data where 760 | | A 761 | | B 762 | open Data 763 | 764 | instance: Default Data 765 | where 766 | default := A 767 | 768 | def f := default Data 769 | """ 770 | ) 771 | assert isinstance(d, Def) 772 | self.assertEqual("Data.A", str(d.body)) 773 | 774 | def test_check_program_class_add(self): 775 | _, _, _, _, f = ast.check_string( 776 | """ 777 | inductive N where 778 | | Z 779 | | S (n: N) 780 | open N 781 | 782 | def addN (a: N) (b: N): N := 783 | match a with 784 | | Z => b 785 | | S pred => S (addN pred b) 786 | 787 | class Add {T: Type} where 788 | add: (a: T) -> (b: T) -> T 789 | open Add 790 | 791 | instance: Add (T := N) 792 | where 793 | add := addN 794 | 795 | def f := add (S Z) (S Z) 796 | """ 797 | ) 798 | assert isinstance(f, Def) 799 | self.assertEqual("(N.S (N.S N.Z))", str(f.body)) 800 | -------------------------------------------------------------------------------- /src/TinyLean/tests/test_main.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from unittest import TestCase 3 | 4 | from .. 
import ast, ir 5 | 6 | 7 | def nat_to_int(v: ir.IR): 8 | n = 0 9 | while True: 10 | if isinstance(v, ir.Fn): 11 | v = v.body 12 | else: 13 | break 14 | while True: 15 | if isinstance(v, ir.Call): 16 | assert isinstance(v.callee, ir.Ref) 17 | assert v.callee.name.text == "S" 18 | v = v.arg 19 | n += 1 20 | else: 21 | assert isinstance(v, ir.Ref) 22 | assert v.name.text == "Z" 23 | break 24 | return n 25 | 26 | 27 | class TestMain(TestCase): 28 | def test_nat(self): 29 | _, _, _, _3, _6, _9 = ast.check_string( 30 | """ 31 | def Nat: Type := 32 | (T: Type) -> (S: (n: T) -> T) -> (Z: T) -> T 33 | 34 | def add (a: Nat) (b: Nat): Nat := 35 | fun T S Z => (a T S) (b T S Z) 36 | 37 | def mul (a: Nat) (b: Nat): Nat := 38 | fun T S Z => (a T) (b T S) Z 39 | 40 | def _3: Nat := fun T S Z => S (S (S Z)) 41 | 42 | def _6: Nat := add _3 _3 43 | 44 | def _9: Nat := mul _3 _3 45 | """ 46 | ) 47 | self.assertEqual(3, nat_to_int(_3.body)) 48 | self.assertEqual(6, nat_to_int(_6.body)) 49 | self.assertEqual(9, nat_to_int(_9.body)) 50 | 51 | def test_leibniz_equality(self): 52 | ast.check_string( 53 | """ 54 | def Eq (T: Type) (a: T) (b: T): Type := 55 | (p: (v: T) -> Type) -> (pa: p a) -> p b 56 | 57 | def refl (T: Type) (a: T): Eq T a a := 58 | fun p pa => pa 59 | 60 | def sym (T: Type) (a: T) (b: T) (p: Eq T a b): Eq T b a := 61 | (p (fun b => Eq T b a)) (refl T a) 62 | 63 | def A: Type := Type 64 | 65 | def B: Type := Type 66 | 67 | def lemma: Eq Type A B := refl Type A 68 | 69 | def theorem (p: Eq Type A B): Eq Type B A := sym Type A B lemma 70 | """ 71 | ) 72 | 73 | def test_leibniz_equality_failed(self): 74 | with self.assertRaises(ast.TypeMismatchError) as e: 75 | ast.check_string( 76 | """ 77 | def Eq (T: Type) (a: T) (b: T): Type := (p: (v: T) -> Type) -> (pa: p a) -> p b 78 | def refl (T: Type) (a: T): Eq T a a := fun p => fun pa => pa 79 | def A: Type := (a: Type) -> Type 80 | def B: Type := (a: (b: Type) -> Type) -> Type 81 | def _: Eq Type A B := refl Type A 82 | /- ^~~^ 
failed here -/ 83 | """ 84 | ) 85 | want, got, loc = e.exception.args 86 | self.assertEqual(323, loc) 87 | self.assertEqual( 88 | "(p: (v: Type) → Type) → (pa: (p (a: Type) → Type)) → (p (a: (b: Type) → Type) → Type)", 89 | str(want), 90 | ) 91 | self.assertEqual( 92 | "(p: (v: Type) → Type) → (pa: (p (a: Type) → Type)) → (p (a: Type) → Type)", 93 | str(got), 94 | ) 95 | 96 | def test_markdown(self): 97 | results = ast.check_string( 98 | """\ 99 | # Heading 1 100 | 101 | ```lean 102 | def Eq (T: Type) (a: T) (b: T): Type := (p: (v: T) -> Type) -> (pa: p a) -> p b 103 | 104 | def refl (T: Type) (a: T): Eq T a a := fun p => fun pa => pa 105 | ``` 106 | 107 | ```lean 108 | def sym (T: Type) (a: T) (b: T) (p: Eq T a b): Eq T b a := (p (fun b => Eq T b a)) (refl T a) 109 | ``` 110 | 111 | ```lean4 112 | def A: Type := Type 113 | ``` 114 | 115 | ```python 116 | print("Hello, world!") 117 | ``` 118 | 119 | ``` 120 | Broken code. 121 | ``````` 122 | 123 | Footer. 124 | """, 125 | True, 126 | ) 127 | self.assertEqual(3, len(results)) 128 | eq, refl, sym = results 129 | self.assertEqual("Eq", eq.name.text) 130 | self.assertEqual("refl", refl.name.text) 131 | self.assertEqual("sym", sym.name.text) 132 | 133 | def test_readme(self): 134 | p = Path(__file__).parent / ".." / ".." / ".." 
/ ".github" / "README.md" 135 | with open(p, encoding="utf-8") as f: 136 | results = ast.check_string(f.read(), True) 137 | self.assertGreater(len(results), 1) 138 | 139 | def test_example(self): 140 | ast.check_string( 141 | """ 142 | def T: Type := Type 143 | example: Type := T 144 | """ 145 | ) 146 | 147 | def test_leibniz_equality_implicit(self): 148 | ast.check_string( 149 | """ 150 | def Eq {T: Type} (a: T) (b: T): Type := 151 | (p: (v: T) -> Type) -> (pa: p a) -> p b 152 | 153 | def refl {T: Type} (a: T): Eq a a := 154 | fun p pa => pa 155 | 156 | def sym {T: Type} (a: T) (b: T) (p: Eq a b): Eq b a := 157 | (p (fun b => Eq b a)) (refl a) 158 | 159 | def A: Type := Type 160 | 161 | def B: Type := Type 162 | 163 | def lemma: Eq A B := refl A 164 | 165 | def theorem (p: Eq A B): Eq B A := sym A B lemma 166 | """ 167 | ) 168 | 169 | def test_operator_overloading(self): 170 | ast.check_string( 171 | """ 172 | inductive N where 173 | | Z 174 | | S (n: N) 175 | open N 176 | 177 | def addN (a: N) (b: N): N := 178 | match a with 179 | | Z => b 180 | | S pred => S (addN pred b) 181 | 182 | class Add {T: Type} where 183 | add: (a: T) -> (b: T) -> T 184 | open Add 185 | 186 | instance: Add (T := N) 187 | where 188 | add := addN 189 | 190 | def f := (S Z) + (S Z) 191 | """ 192 | ) 193 | -------------------------------------------------------------------------------- /src/TinyLean/tests/test_parser.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from pyparsing import ParseException 4 | 5 | from . import parse 6 | from .. 
import ( 7 | ast, 8 | Name, 9 | grammar, 10 | Param, 11 | Decl, 12 | Def, 13 | Example, 14 | Data, 15 | Ctor, 16 | Class, 17 | Instance, 18 | ) 19 | 20 | 21 | class TestParser(TestCase): 22 | def test_fresh(self): 23 | self.assertNotEqual(Name("i").id, Name("j").id) 24 | 25 | def test_parse_name(self): 26 | x = parse(grammar.name, " hello")[0] 27 | assert isinstance(x, Name) 28 | self.assertEqual("hello", x.text) 29 | 30 | def test_parse_name_unbound(self): 31 | x = parse(grammar.name, "_")[0] 32 | self.assertTrue(x.is_unbound()) 33 | 34 | def test_parse_type(self): 35 | x = parse(grammar.type_, " Type")[0] 36 | assert isinstance(x, ast.Type) 37 | self.assertEqual(2, x.loc) 38 | 39 | def test_parse_reference(self): 40 | x = parse(grammar.ref, " hello")[0] 41 | assert isinstance(x, ast.Ref) 42 | self.assertEqual(2, x.loc) 43 | self.assertEqual("hello", x.name.text) 44 | 45 | def test_parse_paren_expr(self): 46 | x = parse(grammar.p_expr, "(hello)")[0] 47 | assert isinstance(x, ast.Ref) 48 | self.assertEqual(1, x.loc) 49 | self.assertEqual("hello", x.name.text) 50 | 51 | def test_parse_implicit_param(self): 52 | x = parse(grammar.i_param, " {a: b}")[0] 53 | assert isinstance(x, Param) 54 | self.assertTrue(x.is_implicit) 55 | self.assertEqual("a", x.name.text) 56 | assert isinstance(x.type, ast.Ref) 57 | self.assertEqual(5, x.type.loc) 58 | 59 | def test_parse_explicit_param(self): 60 | x = parse(grammar.e_param, " (a : Type)")[0] 61 | assert isinstance(x, Param) 62 | self.assertFalse(x.is_implicit) 63 | self.assertEqual("a", x.name.text) 64 | assert isinstance(x.type, ast.Type) 65 | self.assertEqual(6, x.type.loc) 66 | 67 | def test_parse_call(self): 68 | x = parse(grammar.call, "a b")[0] 69 | assert isinstance(x, ast.Call) 70 | self.assertEqual(0, x.loc) 71 | assert isinstance(x.callee, ast.Ref) 72 | self.assertEqual(0, x.callee.loc) 73 | self.assertEqual("a", x.callee.name.text) 74 | self.assertEqual(2, x.arg.loc) 75 | assert isinstance(x.arg, ast.Ref) 76 | 
self.assertEqual("b", x.arg.name.text) 77 | 78 | def test_parse_call_paren(self): 79 | x = parse(grammar.call, "(a) b (Type)")[0] 80 | assert isinstance(x, ast.Call) 81 | self.assertEqual(0, x.loc) 82 | assert isinstance(x.callee, ast.Call) 83 | assert isinstance(x.callee.callee, ast.Ref) 84 | self.assertEqual(1, x.callee.callee.loc) 85 | self.assertEqual("a", x.callee.callee.name.text) 86 | assert isinstance(x.callee.arg, ast.Ref) 87 | self.assertEqual(4, x.callee.arg.loc) 88 | self.assertEqual("b", x.callee.arg.name.text) 89 | assert isinstance(x.arg, ast.Type) 90 | self.assertEqual(7, x.arg.loc) 91 | 92 | def test_parse_call_paren_function(self): 93 | x = parse(grammar.call, "(fun _ => Type) Type")[0] 94 | assert isinstance(x, ast.Call) 95 | self.assertEqual(0, x.loc) 96 | assert isinstance(x.callee, ast.Fn) 97 | self.assertEqual(1, x.callee.loc) 98 | self.assertTrue(x.callee.param.is_unbound()) 99 | assert isinstance(x.callee.body, ast.Type) 100 | self.assertEqual(10, x.callee.body.loc) 101 | assert isinstance(x.arg, ast.Type) 102 | self.assertEqual(16, x.arg.loc) 103 | 104 | def test_parse_function_type(self): 105 | x = parse(grammar.fn_type, " (a : Type) -> a")[0] 106 | assert isinstance(x, ast.FnType) 107 | assert isinstance(x.param, Param) 108 | self.assertEqual("a", x.param.name.text) 109 | assert isinstance(x.param.type, ast.Type) 110 | self.assertEqual(7, x.param.type.loc) 111 | assert isinstance(x.ret, ast.Ref) 112 | self.assertEqual("a", x.ret.name.text) 113 | self.assertEqual(16, x.ret.loc) 114 | 115 | def test_parse_function_type_long(self): 116 | x = parse(grammar.fn_type, " {a : Type} -> (b: Type) -> a")[0] 117 | assert isinstance(x, ast.FnType) 118 | assert isinstance(x.param, Param) 119 | self.assertEqual("a", x.param.name.text) 120 | assert isinstance(x.param.type, ast.Type) 121 | self.assertEqual(6, x.param.type.loc) 122 | assert isinstance(x.ret, ast.FnType) 123 | assert isinstance(x.ret.param, Param) 124 | self.assertEqual("b", 
x.ret.param.name.text) 125 | assert isinstance(x.ret.param.type, ast.Type) 126 | self.assertEqual(19, x.ret.param.type.loc) 127 | assert isinstance(x.ret.ret, ast.Ref) 128 | self.assertEqual("a", x.ret.ret.name.text) 129 | self.assertEqual(28, x.ret.ret.loc) 130 | 131 | def test_parse_function(self): 132 | x = parse(grammar.fn, " fun a => a")[0] 133 | assert isinstance(x, ast.Fn) 134 | self.assertEqual(2, x.loc) 135 | assert isinstance(x.param, Name) 136 | self.assertEqual("a", x.param.text) 137 | assert isinstance(x.body, ast.Ref) 138 | self.assertEqual("a", x.body.name.text) 139 | self.assertEqual(11, x.body.loc) 140 | 141 | def test_parse_function_long(self): 142 | x = parse(grammar.fn, " fun a => fun b => a b")[0] 143 | assert isinstance(x, ast.Fn) 144 | self.assertEqual(3, x.loc) 145 | assert isinstance(x.param, Name) 146 | self.assertEqual("a", x.param.text) 147 | assert isinstance(x.body, ast.Fn) 148 | self.assertEqual(12, x.body.loc) 149 | assert isinstance(x.body.param, Name) 150 | self.assertEqual("b", x.body.param.text) 151 | assert isinstance(x.body.body, ast.Call) 152 | self.assertEqual(21, x.body.body.loc) 153 | assert isinstance(x.body.body.callee, ast.Ref) 154 | self.assertEqual("a", x.body.body.callee.name.text) 155 | assert isinstance(x.body.body.arg, ast.Ref) 156 | self.assertEqual("b", x.body.body.arg.name.text) 157 | 158 | def test_parse_function_multi(self): 159 | x = parse(grammar.fn, " fun c d => c d")[0] 160 | assert isinstance(x, ast.Fn) 161 | self.assertEqual(2, x.loc) 162 | assert isinstance(x.param, Name) 163 | self.assertEqual("c", x.param.text) 164 | assert isinstance(x.body, ast.Fn) 165 | self.assertEqual(2, x.body.loc) 166 | assert isinstance(x.body.param, Name) 167 | self.assertEqual("d", x.body.param.text) 168 | 169 | def test_parse_definition_constant(self): 170 | x = parse(grammar.def_, " def f : Type := Type")[0] 171 | assert isinstance(x, Def) 172 | self.assertEqual(6, x.loc) 173 | self.assertEqual("f", x.name.text) 174 | 
self.assertEqual(0, len(x.params)) 175 | assert isinstance(x.ret, ast.Type) 176 | self.assertEqual(10, x.ret.loc) 177 | assert isinstance(x.body, ast.Type) 178 | self.assertEqual(18, x.body.loc) 179 | 180 | def test_parse_definition(self): 181 | x = parse(grammar.def_, " def f {a: Type} (b: Type): Type := a")[0] 182 | assert isinstance(x, Def) 183 | self.assertEqual(6, x.loc) 184 | self.assertEqual("f", x.name.text) 185 | assert isinstance(x.params, list) 186 | self.assertEqual(2, len(x.params)) 187 | assert isinstance(x.params[0], ast.Param) 188 | self.assertTrue(x.params[0].is_implicit) 189 | self.assertEqual("a", x.params[0].name.text) 190 | assert isinstance(x.params[0].type, ast.Type) 191 | self.assertEqual(12, x.params[0].type.loc) 192 | assert isinstance(x.params[1], ast.Param) 193 | self.assertFalse(x.params[1].is_implicit) 194 | self.assertEqual("b", x.params[1].name.text) 195 | assert isinstance(x.params[1].type, ast.Type) 196 | self.assertEqual(22, x.params[1].type.loc) 197 | 198 | def test_parse_program(self): 199 | x = list( 200 | parse( 201 | grammar.program, 202 | """ 203 | def a: Type := Type 204 | def b: Type := Type 205 | """, 206 | ) 207 | ) 208 | self.assertEqual(2, len(x)) 209 | assert isinstance(x[0], Decl) 210 | self.assertEqual("a", x[0].name.text) 211 | assert isinstance(x[1], Decl) 212 | self.assertEqual("b", x[1].name.text) 213 | 214 | def test_parse_example(self): 215 | x = parse(grammar.example, " example: Type := Type")[0] 216 | assert isinstance(x, Example) 217 | self.assertEqual(2, x.loc) 218 | self.assertEqual(0, len(x.params)) 219 | assert isinstance(x.ret, ast.Type) 220 | assert isinstance(x.body, ast.Type) 221 | 222 | def test_parse_placeholder(self): 223 | x = parse(grammar.fn, " fun _ => _")[0] 224 | assert isinstance(x, ast.Fn) 225 | self.assertTrue(x.param.is_unbound()) 226 | assert isinstance(x.body, ast.Placeholder) 227 | self.assertEqual(10, x.body.loc) 228 | 229 | def test_parse_return_type(self): 230 | x = 
parse(grammar.return_type, ": Type")[0] 231 | assert isinstance(x, ast.Type) 232 | self.assertEqual(2, x.loc) 233 | 234 | def test_parse_return_placeholder(self): 235 | x = parse(grammar.return_type, "")[0] 236 | assert isinstance(x, ast.Placeholder) 237 | self.assertFalse(x.is_user) 238 | 239 | def test_parse_definition_no_return(self): 240 | x = parse(grammar.def_, "def a := Type")[0] 241 | assert isinstance(x, Def) 242 | assert isinstance(x.ret, ast.Placeholder) 243 | self.assertFalse(x.ret.is_user) 244 | 245 | def test_parse_call_implicit(self): 246 | x = parse(grammar.call, "a ( T := Nat )")[0] 247 | assert isinstance(x, ast.Call) 248 | assert isinstance(x.callee, ast.Ref) 249 | self.assertEqual("a", x.callee.name.text) 250 | self.assertEqual("T", x.implicit) 251 | assert isinstance(x.arg, ast.Ref) 252 | self.assertEqual("Nat", x.arg.name.text) 253 | 254 | def test_parse_call_explicit(self): 255 | x = parse(grammar.call, "a b")[0] 256 | assert isinstance(x, ast.Call) 257 | assert isinstance(x.callee, ast.Ref) 258 | self.assertEqual("a", x.callee.name.text) 259 | self.assertFalse(x.implicit) 260 | assert isinstance(x.arg, ast.Ref) 261 | self.assertEqual("b", x.arg.name.text) 262 | 263 | def test_parse_definition_call_implicit(self): 264 | x = parse( 265 | grammar.def_, 266 | """ 267 | def f: Type := a ( 268 | T := Nat 269 | ) b 270 | """, 271 | )[0] 272 | assert isinstance(x, Def) 273 | assert isinstance(x.body, ast.Call) 274 | assert isinstance(x.body.callee, ast.Call) 275 | assert isinstance(x.body.callee.callee, ast.Ref) 276 | self.assertEqual("a", x.body.callee.callee.name.text) 277 | assert isinstance(x.body.callee.arg, ast.Ref) 278 | self.assertEqual("Nat", x.body.callee.arg.name.text) 279 | self.assertEqual("T", x.body.callee.implicit) 280 | assert isinstance(x.body.arg, ast.Ref) 281 | self.assertEqual("b", x.body.arg.name.text) 282 | 283 | def test_parse_datatype_empty(self): 284 | x = parse( 285 | grammar.data, 286 | """ 287 | inductive Void where 288 
| open Void 289 | """, 290 | )[0] 291 | assert isinstance(x, Data) 292 | self.assertEqual("Void", x.name.text) 293 | self.assertEqual(0, len(x.params)) 294 | self.assertEqual(0, len(x.ctors)) 295 | 296 | def test_parse_datatype_empty_failed(self): 297 | with self.assertRaises(ParseException) as e: 298 | parse(grammar.data, "inductive Foo where open Bar") 299 | self.assertIn("open and datatype name mismatch", str(e.exception)) 300 | 301 | def test_parse_datatype_ctors(self): 302 | x = parse( 303 | grammar.data, 304 | """ 305 | inductive D {T: Type} (U: Type) where 306 | | A 307 | | B {X: Type} (Y: Type) (U := Type) 308 | | C (T := Type) 309 | open D 310 | """, 311 | )[0] 312 | assert isinstance(x, Data) 313 | self.assertEqual("D", x.name.text) 314 | self.assertEqual(2, len(x.params)) 315 | self.assertEqual("T", x.params[0].name.text) 316 | self.assertEqual("U", x.params[1].name.text) 317 | self.assertEqual(3, len(x.ctors)) 318 | a, b, c = x.ctors 319 | assert isinstance(a, Ctor) 320 | self.assertEqual("A", a.name.text) 321 | self.assertEqual(0, len(a.params)) 322 | assert isinstance(b, Ctor) 323 | self.assertEqual("B", b.name.text) 324 | self.assertEqual(2, len(b.params)) 325 | self.assertEqual(1, len(b.ty_args)) 326 | self.assertEqual("U", b.ty_args[0][0].name.text) 327 | assert isinstance(c, Ctor) 328 | self.assertEqual("C", c.name.text) 329 | self.assertEqual(0, len(c.params)) 330 | self.assertEqual(1, len(c.ty_args)) 331 | self.assertEqual("T", c.ty_args[0][0].name.text) 332 | 333 | def test_parse_expr_nomatch(self): 334 | x = parse(grammar.nomatch, "nomatch x")[0] 335 | assert isinstance(x, ast.Nomatch) 336 | assert isinstance(x.arg, ast.Ref) 337 | self.assertEqual("x", x.arg.name.text) 338 | 339 | def test_parse_expr_case(self): 340 | x = parse(grammar.case, "| A a b => a")[0] 341 | assert isinstance(x, ast.Case) 342 | self.assertEqual(2, x.loc) 343 | self.assertEqual("A", x.ctor.name.text) 344 | self.assertEqual(2, len(x.params)) 345 | self.assertEqual("a", 
x.params[0].text) 346 | self.assertEqual("b", x.params[1].text) 347 | assert isinstance(x.body, ast.Ref) 348 | self.assertEqual("a", x.body.name.text) 349 | 350 | def test_parse_expr_match(self): 351 | x = parse( 352 | grammar.match, 353 | """ 354 | match Type with 355 | | A a => a 356 | | B b => b 357 | | _ => x /- not actually a default case -/ 358 | """, 359 | )[0] 360 | assert isinstance(x, ast.Match) 361 | assert isinstance(x.arg, ast.Type) 362 | self.assertEqual(3, len(x.cases)) 363 | self.assertTrue(x.cases[2].ctor.name.is_unbound()) 364 | 365 | def test_parse_class_param(self): 366 | x = parse(grammar.c_param, "[p: GAdd Type]")[0] 367 | assert isinstance(x, Param) 368 | self.assertEqual("p", x.name.text) 369 | assert isinstance(x.type, ast.Call) 370 | assert isinstance(x.type.callee, ast.Ref) 371 | self.assertEqual("GAdd", x.type.callee.name.text) 372 | assert isinstance(x.type.arg, ast.Type) 373 | self.assertTrue(x.is_implicit) 374 | 375 | def test_parse_class_empty(self): 376 | x = parse( 377 | grammar.class_, 378 | """ 379 | class A where 380 | open A 381 | """, 382 | )[0] 383 | assert isinstance(x, Class) 384 | self.assertEqual("A", x.name.text) 385 | self.assertEqual(0, len(x.params)) 386 | self.assertEqual(0, len(x.fields)) 387 | 388 | def test_parse_class_empty_failed(self): 389 | with self.assertRaises(ParseException) as e: 390 | parse(grammar.class_, "class A where open B") 391 | self.assertIn("open and class name mismatch", str(e.exception)) 392 | 393 | def test_parse_class_fields(self): 394 | x = parse( 395 | grammar.class_, 396 | """ 397 | class Op {T: Type} where 398 | add: (a: T) -> (b: T) -> T 399 | mul: (a: T) -> (b: T) -> T 400 | open Op 401 | """, 402 | )[0] 403 | assert isinstance(x, Class) 404 | self.assertEqual(1, len(x.params)) 405 | self.assertEqual("T", x.params[0].name.text) 406 | self.assertEqual("Op", x.name.text) 407 | self.assertEqual(2, len(x.fields)) 408 | self.assertEqual("add", x.fields[0].name.text) 409 | assert 
isinstance(x.fields[0].type, ast.FnType) 410 | self.assertEqual("mul", x.fields[1].name.text) 411 | assert isinstance(x.fields[1].type, ast.FnType) 412 | 413 | def test_parse_instance_empty(self): 414 | x = parse(grammar.inst, "instance : Monad A\nwhere")[0] 415 | assert isinstance(x, Instance) 416 | assert isinstance(x.type, ast.Call) 417 | assert isinstance(x.type.callee, ast.Ref) 418 | self.assertEqual("Monad", x.type.callee.name.text) 419 | assert isinstance(x.type.arg, ast.Ref) 420 | self.assertEqual("A", x.type.arg.name.text) 421 | self.assertEqual(0, len(x.fields)) 422 | 423 | def test_parse_instance_fields(self): 424 | x = parse( 425 | grammar.inst, 426 | """ 427 | instance : AddOp T 428 | where 429 | add := (a: T) -> (b: T) -> T 430 | """, 431 | )[0] 432 | assert isinstance(x, Instance) 433 | self.assertEqual(1, len(x.fields)) 434 | f = x.fields[0] 435 | assert isinstance(f, tuple) 436 | n, v = f 437 | assert isinstance(n, ast.Ref) 438 | self.assertEqual("add", n.name.text) 439 | assert isinstance(v, ast.FnType) 440 | 441 | def test_parse_infix_op(self): 442 | x = parse(grammar.expr, "x + y")[0] 443 | assert isinstance(x, ast.Call) 444 | assert isinstance(x.callee, ast.Call) 445 | assert isinstance(x.callee.callee, ast.Ref) 446 | self.assertEqual("add", x.callee.callee.name.text) 447 | assert isinstance(x.callee.arg, ast.Ref) 448 | self.assertEqual("x", x.callee.arg.name.text) 449 | assert isinstance(x.arg, ast.Ref) 450 | self.assertEqual("y", x.arg.name.text) 451 | -------------------------------------------------------------------------------- /src/TinyLean/tests/test_resolver.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from . import resolve_expr, resolve, resolve_md 4 | from .. 
import ast, Data 5 | 6 | 7 | class TestNameResolver(TestCase): 8 | def test_resolve_expr_function(self): 9 | x = resolve_expr("fun a => fun b => a b") 10 | assert isinstance(x, ast.Fn) 11 | assert isinstance(x.body, ast.Fn) 12 | assert isinstance(x.body.body, ast.Call) 13 | assert isinstance(x.body.body.callee, ast.Ref) 14 | assert isinstance(x.body.body.arg, ast.Ref) 15 | self.assertEqual(x.param.id, x.body.body.callee.name.id) 16 | self.assertEqual(x.body.param.id, x.body.body.arg.name.id) 17 | 18 | def test_resolve_expr_function_shadowed(self): 19 | x = resolve_expr("fun a => fun a => a") 20 | assert isinstance(x, ast.Fn) 21 | assert isinstance(x.body, ast.Fn) 22 | assert isinstance(x.body.body, ast.Ref) 23 | self.assertNotEqual(x.param.id, x.body.body.name.id) 24 | self.assertEqual(x.body.param.id, x.body.body.name.id) 25 | 26 | def test_resolve_expr_function_failed(self): 27 | with self.assertRaises(ast.UndefinedVariableError) as e: 28 | resolve_expr("fun a => b") 29 | n, loc = e.exception.args 30 | self.assertEqual(9, loc) 31 | self.assertEqual("b", n) 32 | 33 | def test_resolve_expr_function_type(self): 34 | x = resolve_expr("{a: Type} -> (b: Type) -> a") 35 | assert isinstance(x, ast.FnType) 36 | assert isinstance(x.ret, ast.FnType) 37 | assert isinstance(x.ret.ret, ast.Ref) 38 | self.assertEqual(x.param.name.id, x.ret.ret.name.id) 39 | self.assertNotEqual(x.ret.param.name.id, x.ret.ret.name.id) 40 | 41 | def test_resolve_expr_function_type_failed(self): 42 | with self.assertRaises(ast.UndefinedVariableError) as e: 43 | resolve_expr("{a: Type} -> (b: Type) -> c") 44 | n, loc = e.exception.args 45 | self.assertEqual(26, loc) 46 | self.assertEqual("c", n) 47 | 48 | def test_resolve_program(self): 49 | resolve( 50 | """ 51 | def f0 (a: Type): Type := a 52 | def f1 (a: Type): Type := f0 a 53 | """ 54 | ) 55 | 56 | def test_resolve_program_failed(self): 57 | with self.assertRaises(ast.UndefinedVariableError) as e: 58 | resolve("def f (a: Type) (b: c): Type := 
Type") 59 | n, loc = e.exception.args 60 | self.assertEqual(20, loc) 61 | self.assertEqual("c", n) 62 | 63 | def test_resolve_program_duplicate(self): 64 | with self.assertRaises(ast.DuplicateVariableError) as e: 65 | resolve( 66 | """ 67 | def f0: Type := Type 68 | def f0: Type := Type 69 | """ 70 | ) 71 | n, loc = e.exception.args 72 | self.assertEqual(58, loc) 73 | self.assertEqual("f0", n) 74 | 75 | def test_resolve_md(self): 76 | resolve_md( 77 | """\ 78 | # Heading 79 | 80 | ```lean 81 | def a := Type 82 | ``` 83 | 84 | Some text. 85 | 86 | ```lean 87 | def b := a 88 | ``` 89 | 90 | Footer. 91 | """ 92 | ) 93 | 94 | def test_resolve_expr_placeholder(self): 95 | resolve_expr("{a: Type} -> (b: Type) -> _") 96 | 97 | def test_resolve_datatype_empty(self): 98 | resolve("inductive Void where open Void") 99 | 100 | def test_resolve_datatype_nat(self): 101 | x = resolve( 102 | """ 103 | inductive N where 104 | | Z 105 | | S (n: N) 106 | open N 107 | """ 108 | )[0] 109 | assert isinstance(x, Data) 110 | a = x.name.id 111 | b_typ = x.ctors[1].params[0].type 112 | assert isinstance(b_typ, ast.Ref) 113 | b = b_typ.name.id 114 | self.assertEqual(a, b) 115 | 116 | def test_resolve_datatype_maybe(self): 117 | x = resolve( 118 | """ 119 | inductive Maybe (A: Type) where 120 | | Nothing 121 | | Just (a: A) 122 | open Maybe 123 | """ 124 | )[0] 125 | assert isinstance(x, Data) 126 | a = x.params[0].name.id 127 | b_typ = x.ctors[1].params[0].type 128 | assert isinstance(b_typ, ast.Ref) 129 | b = b_typ.name.id 130 | self.assertEqual(a, b) 131 | 132 | def test_resolve_datatype_vec(self): 133 | resolve( 134 | """ 135 | inductive N where 136 | | Z 137 | | S (n: N) 138 | open N 139 | 140 | inductive Vec (A : Type) (n : N) where 141 | | Nil (n := Z) 142 | | Cons {m: N} (a: A) (v: Vec A m) (n := S m) 143 | open Vec 144 | """ 145 | ) 146 | 147 | def test_resolve_datatype_duplicate(self): 148 | with self.assertRaises(ast.DuplicateVariableError) as e: 149 | resolve("inductive A where | 
A open A") 150 | name, loc = e.exception.args 151 | self.assertEqual("A", name) 152 | self.assertEqual(20, loc) 153 | 154 | def test_resolve_match(self): 155 | resolve( 156 | """ 157 | inductive A where | AA (T: Type) open A 158 | example (x: A) := 159 | match x with 160 | | A t => t 161 | """ 162 | ) 163 | 164 | def test_resolve_match_failed(self): 165 | with self.assertRaises(ast.UndefinedVariableError) as e: 166 | resolve( 167 | """ 168 | inductive A where | AA open A 169 | example (x: A) := 170 | match x with 171 | | A => b 172 | """ 173 | ) 174 | name, loc = e.exception.args 175 | self.assertEqual("b", name) 176 | self.assertEqual(137, loc) 177 | 178 | def test_resolve_class(self): 179 | resolve( 180 | """ 181 | class A (T: Type) where 182 | a: T 183 | open A 184 | example := a 185 | """ 186 | ) 187 | 188 | def test_resolve_instance(self): 189 | resolve( 190 | """ 191 | inductive Void where open Void 192 | 193 | class A (T: Type) where 194 | a: T 195 | open A 196 | 197 | instance: A Void 198 | where 199 | a := Type 200 | """ 201 | ) 202 | 203 | def test_resolve_instance_failed(self): 204 | text = """ 205 | class C where 206 | c: Type 207 | open C 208 | instance: C 209 | where 210 | c := Type 211 | c := Type 212 | """ 213 | with self.assertRaises(ast.DuplicateVariableError) as e: 214 | resolve(text) 215 | name, loc = e.exception.args 216 | self.assertEqual("c", name) 217 | self.assertEqual(text.rindex("c :="), loc) 218 | --------------------------------------------------------------------------------