├── .github
│   ├── README.md
│   └── workflows
│       ├── release.yml
│       └── test.yml
├── .gitignore
├── LICENSE
├── pyproject.toml
└── src
    └── TinyLean
        ├── __init__.py
        ├── __main__.py
        ├── ast.py
        ├── grammar.py
        ├── ir.py
        └── tests
            ├── __init__.py
            ├── onboard.py
            ├── test_checker.py
            ├── test_main.py
            ├── test_parser.py
            └── test_resolver.py
/.github/README.md:
--------------------------------------------------------------------------------
1 | # TinyLean
2 |
3 | 
4 | 
5 | [Test](https://github.com/anqurvanillapy/TinyLean/actions/workflows/test.yml)
6 | [Codecov](https://codecov.io/gh/anqurvanillapy/TinyLean)
7 |
8 | A Lean 4-style theorem prover implemented in fewer than 1K lines of Python.
9 |
10 | ```lean
11 | def TinyLean ≔ {T: Type} → (a: T) → T
12 | ```
13 |
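14 | * Learn the basic principles of **theorem proving** and how they are implemented.
15 | * Extensive Chinese comments that favor **terminology common in industry**, helping you connect industrial and academic vocabulary.
16 | * Extensive **unit tests** with high coverage, useful for regression testing whenever you make a change, or as a reference when porting the project to another language.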
17 |
18 | > [!NOTE]
19 | > In pre-`v1` releases I used English throughout this project for convenience, so I am very sorry to the early
20 | > stargazers who might not expect **full Chinese content** here (including documentation and comments) now that I have decided
21 | > to retarget the audience. Please reach out if you run into any trouble.
22 |
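23 | ## ❓ FAQ
24 |
25 | > *“How ‘powerful’ is this theorem prover?”*
26 |
27 | If you know PLT (programming language theory), or are even very familiar with it, here are the language features this project implements: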
28 |
29 |
31 |
32 | - Dependently-typed lambda calculus
33 | - Holes/goals
34 | - Implicit arguments (no first-class polymorphism)
35 | - Inductive data type (à la pi-forall)
36 | - Dependent pattern matching (à la pi-forall)
37 | - Typeclass (no chained instances)
38 |
39 |
40 |
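41 | TinyLean may be the shortest implementation combining all of the above features that you can find. The **Explore** section at the end of this document has more resources.
42 |
43 | > *“I'm not interested in theorem proving. What else can it do?”*
44 |
45 | Imagine having a type system more powerful than those of C++, Java, TypeScript, Rust, and Haskell, implemented in under 1K lines and well suited to porting to other languages. Then think carefully about what you would like to do with it.
46 |
47 | I used this knowledge to implement the [RowScript] programming language, a JavaScript dialect that supports row polymorphism.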
48 |
49 | [RowScript]: https://github.com/rowscript/rowscript
50 |
51 | > *“Why is the project called TinyLean? Why not mini-lean, μLean, or some other name?”*
52 |
53 | “TinyLean” is a tribute to the [TinyIdris] project, a “minimal” implementation of the Idris programming language.
54 |
55 | [TinyIdris]: https://github.com/edwinb/SPLV20
56 |
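57 | > *“Why Chinese?”*
58 |
59 | Even in English, the simplest terms in PLT are full of ambiguity and obscurity. If PLT
60 | terminology still feels mystifying to you, go figure out the difference between [dependent sum type] and [sum type]; it is the 💩 every PLer has to eat while learning.
61 |
62 | I will use language and terms that are as simple and common as possible to help you demystify them.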
63 |
64 | [dependent sum type]: https://ncatlab.org/nlab/show/dependent+sum
65 |
66 | [sum type]: https://ncatlab.org/nlab/show/sum+type
67 |
68 | ## ⏬ Installation
69 |
70 | ### Try it
71 |
72 | If you just want to try out the project, install the complete implementation from PyPI:
73 |
74 | ```bash
75 | pip install TinyLean
76 | ```
77 |
78 | Run any `.lean` file with the `tinylean` command:
79 |
80 | ```bash
81 | tinylean example.lean
82 | ```
83 |
84 | You can even run a Markdown file: every code block fenced with `` ```lean `` is type-checked.
85 |
86 | > [!IMPORTANT]
87 | > The README you are reading is a valid TinyLean file!
88 |
89 | ```bash
90 | tinylean example.md
91 | ```
92 |
93 | ### Reading the source locally
94 |
95 | Clone the project:
96 |
97 | ```bash
98 | git clone https://github.com/anqurvanillapy/TinyLean
99 | cd TinyLean/
100 | ```
101 |
102 | Check any of your own `.lean`/`.md` files locally, for example this very file:
103 |
104 | ```bash
105 | python -m src.TinyLean .github/README.md
106 | ```
107 |
108 | Install `pytest` and run all unit tests:
109 |
110 | ```bash
111 | pip install pytest
112 | pytest
113 | ```
114 |
115 | ## 🧙 Guide
116 |
117 | Welcome to the world of theorem proving! Let's work step by step toward an elegant proof of that age-old hard problem, `1+1=2`.
118 |
119 | ### DTLC
120 |
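121 | In the beginning, this world contains only these things:
122 |
123 | * The *type of types* (also called a universe), i.e. `Type`
124 | * References, i.e. variable names, like `x`
125 | * Functions, like `λ x y ↦ y`
126 | * Function types, like `(x: Type) → (y: Type) → Type`
127 | * Calls (function applications), like `x y`
128 |
129 | This world has a name: DTLC (dependently-typed lambda calculus).
130 |
131 | Let's do some simple lambda calculus, such as defining an identity function `id` that takes an `a` and returns it: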
132 |
133 | ```lean
134 | def id (T: Type) (a: T): T := a
135 |
136 | def Hello: Type := Type
137 |
138 | example := id Type Hello
139 | ```
140 |
141 | ### ITP and ATP
142 |
143 | Theorem provers are usually *interactive* (ITP, interactive theorem proving): when you are not sure what information the current proof needs, you can ask the prover. For example:
144 |
145 | ```lean
146 | def myLemma
147 | (a: Type)
148 | (b: (_: Type) -> Type)
149 | (c: b a)
150 | : Type
151 | := Type
152 | /- ^~~^ try replacing this “Type” with “_” -/
153 |
154 | def myTheorem := myLemma Type (id Type) Type
155 | ```
156 |
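157 | When you replace `:= Type` with `:= _`, you leave a “hole” (also called a goal) in the code. The prover tells you what type of value is required at the `_`
158 | and which variables are available in the context. Its output looks something like this: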
159 |
160 | ```plaintext
161 | .github/README.md:?:?: unsolved placeholder:
162 | ?u.? : Type
163 |
164 | context:
165 | (a: Type)
166 | (b: (_: Type) → Type)
167 | (c: (b a))
168 | ```
169 |
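170 | *Automated theorem proving* (ATP), by contrast, fills in values that satisfy the type constraints automatically, using the variables available in the context.
171 |
172 | TinyLean implements part of ATP: it can fill some “obvious” holes. As you might expect, full ATP is a problem that suits AI very well.
173 |
174 | ### Implicit arguments
175 |
176 | TinyLean supports *implicit arguments*. If we rewrite our `id` function as below, we no longer need to pass the `T` argument; the type checker infers it.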
177 |
178 | ```lean
179 | def id1 {T: Type} (a: T): T := a
180 |
181 | example := id1 Hello
182 | ```
183 |
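184 | In fact, implicit arguments work by having the prover insert a `_` and then trying to fill it automatically from the context. The example above is equivalent to leaving a
185 | `_` in a call to `id`: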
186 |
187 | ```lean
188 | example := id _ Hello
189 | ```
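To make “the prover fills in the `_`” more concrete, here is a minimal Python sketch of the usual idea behind it. Each inserted `_` becomes a mutable hole, and checking the call unifies that hole with the type the argument actually has. This is not TinyLean's implementation: the `Hole`, `resolve`, and `unify` names are hypothetical, types are modeled as plain strings and tuples, and there is no occurs check.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # holes are compared by identity, not by their fields
class Hole:
    """A placeholder `_`; `answer` stays None until unification solves it."""

    name: str
    answer: object = None


def resolve(t: object) -> object:
    """Follow solved holes until reaching a non-hole or an unsolved hole."""
    while isinstance(t, Hole) and t.answer is not None:
        t = t.answer
    return t


def unify(a: object, b: object) -> bool:
    """Make `a` and `b` equal by solving holes; return False on a clash."""
    a, b = resolve(a), resolve(b)
    if a is b:
        return True
    if isinstance(a, Hole):
        a.answer = b
        return True
    if isinstance(b, Hole):
        b.answer = a
        return True
    if isinstance(a, tuple) and isinstance(b, tuple) and len(a) == len(b):
        return all(unify(x, y) for x, y in zip(a, b))
    return a == b


# `id1 {T: Type} (a: T): T` applied to `Hello : Type`:
T = Hole("T")  # the `_` inserted for the implicit argument
param_type, ret_type = T, T  # the types of `a` and of the result
assert unify(param_type, "Type")  # the argument `Hello` has type `Type`
print(resolve(ret_type))  # Type
```

Once the hole is solved, every occurrence of `T` (here, the result type as well) resolves to the same answer. That is why a single argument is enough to determine `id1 Hello : Type`.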
190 |
191 | To pass `id1`'s `T` argument explicitly instead of letting the prover fill it in, use the following syntax:
192 |
193 | ```lean
194 | example := id1 (T := Type) Hello
195 | ```
196 |
197 | ### Church numerals
198 |
199 | Even with DTLC alone, we can still express natural numbers, for example as [Church numerals][邱奇数].
200 |
201 | Define the `CN` type:
202 |
203 | ```lean
204 | def CN: Type :=
205 | (T: Type) -> (S: (n: T) -> T) -> (Z: T) -> T
206 | ```
207 |
208 | Define the number `3`, which has the form “the successor of the successor of the successor of zero”:
209 |
210 | ```lean
211 | def _3: CN := fun T S Z => S (S (S Z))
212 | ```
213 |
214 | Define addition and multiplication:
215 |
216 | ```lean
217 | def addCN (a: CN) (b: CN): CN :=
218 | fun T S Z => (a T S) (b T S Z)
219 |
220 | def mulCN (a: CN) (b: CN): CN :=
221 | fun T S Z => (a T) (b T S) Z
222 | ```
223 |
224 | Do a little computation:
225 |
226 | ```lean
227 | def _6: CN := addCN _3 _3
228 | def _9: CN := mulCN _3 _3
229 | ```
230 |
231 | [邱奇数]: https://en.wikipedia.org/wiki/Church_encoding
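The same encoding can be sketched in Python with curried functions. This is an illustration only. Python is dynamically typed, so the type argument `T` of `CN` is simply dropped. The names `zero`, `three`, `add`, `mul`, and `to_int` are made up for this sketch. `to_int` reads a numeral back by choosing `S` = “add 1” and `Z` = `0`.

```python
def zero(S):
    # `fun T S Z => Z`, minus the type argument T
    return lambda Z: Z


def three(S):
    # `fun T S Z => S (S (S Z))`, minus the type argument T
    return lambda Z: S(S(S(Z)))


def add(a, b):
    # addCN: apply b's successors to Z, then a's successors on top of that
    return lambda S: lambda Z: a(S)(b(S)(Z))


def mul(a, b):
    # mulCN: use "apply S b times" as the successor step for a
    return lambda S: lambda Z: a(b(S))(Z)


def to_int(n) -> int:
    """Interpret a numeral with S = (+1) and Z = 0."""
    return n(lambda k: k + 1)(0)


print(to_int(add(three, three)))  # 6
print(to_int(mul(three, three)))  # 9
```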
232 |
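233 | ### Equality
234 |
235 | The most important tool for writing proofs is *equality*: having `1+1` but being unable to prove `1+1=2` would be absurd. With DTLC alone we
236 | can still express equations, for example via [Leibniz equality][Leibniz 等式].
237 |
238 | Define the `LEq` type, `lRefl` (reflexivity), and `lSym` (symmetry):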
239 |
240 | ```lean
241 | def LEq {T: Type} (a: T) (b: T): Type :=
242 | (p: (v: T) -> Type) -> (pa: p a) -> p b
243 |
244 | def lRefl {T: Type} (a: T): LEq a a :=
245 | fun p pa => pa
246 |
247 | def lSym {T: Type} (a: T) (b: T) (p: LEq a b): LEq b a :=
248 | (p (fun b => LEq b a)) (lRefl a)
249 | ```
250 |
251 | Let's prove `_9 = _3 + _6` for the numbers we just defined:
252 |
253 | ```lean
254 | example: LEq _9 (addCN _3 _6) := lRefl _9
255 | ```
256 |
257 | [Leibniz 等式]: https://en.wikipedia.org/wiki/Equality_(mathematics)
258 |
259 | ### Inductive data types
260 |
261 | We can use inductive data types to define new types. For instance, we can finally have a more intuitive definition of natural numbers:
262 |
263 | ```lean
264 | inductive N where
265 | | Z
266 | | S (n: N)
267 | open N
268 | ```
269 |
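270 | This definition is already very close to the natural numbers defined by the [Peano axioms][Peano 公理]:
271 |
272 | 1. 0 (`Z`) is a natural number (`N`).
273 | 2. For every natural number `n`, the successor of `n` (`S n`) is also a natural number.
274 |
275 | Addition, defined by recursion, also becomes more natural: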
276 |
277 | ```lean
278 | def addN (n: N) (m: N): N :=
279 | match n with
280 | | Z => m
281 | | S pred => S (addN pred m)
282 |
283 | example := addN (S Z) (S Z)
284 | ```
285 |
286 | If an inductive data type has no constructors at all, it is an empty type (the bottom type, ⊥):
287 |
288 | ```lean
289 | inductive Bot where
290 | open Bot
291 | ```
292 |
293 | The [principle of explosion][爆炸原理] (*ex falso*) says that anything can be derived from a contradiction. We can write it as a theorem using `nomatch`:
294 |
295 | ```lean
296 | def exFalso (T: Type) (x: Bot): T := nomatch x
297 | ```
298 |
299 | Here we pull a value of type `T` out of thin air.
300 |
301 | [Peano 公理]: https://en.wikipedia.org/wiki/Peano_axioms
302 |
303 | [爆炸原理]: https://en.wikipedia.org/wiki/Principle_of_explosion
304 |
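305 | ### Indexed types
306 |
307 | Inductive data types can carry parameters. A type carrying such a parameter is called an indexed type, because it is “indexed by
308 | a value”. Such types are also called inductive families.
309 |
310 | In C++, for example, a [non-type template parameter][非类型模板参数] lets us write `std::array` types whose length, such as `3`, is part of the type;
311 | that `3` records the array's length and is just an ordinary value.
312 |
313 | Similarly, we can define a vector type that records its length in its type: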
314 |
315 | ```lean
316 | inductive Vec (A: Type) (n: N) where
317 | | Nil (n := Z)
318 | | Cons {m: N} (a: A) (v: Vec A m) (n := S m)
319 | open Vec
320 | ```
321 |
322 | Here `(n := Z)` means that when we build an empty vector with `Nil`, its type argument `n` is filled in as `Z`, i.e. its length is 0.
323 |
324 | Some vectors of different lengths:
325 |
326 | ```lean
327 | def v0: Vec Type Z := Nil
328 | def v1: Vec Type (S Z) := Cons N v0
329 | def v2: Vec Type (S (S Z)) := Cons CN v1
330 | ```
331 |
332 | [非类型模板参数]: https://en.cppreference.com/w/cpp/language/template_parameters#Non-type_template_parameter
333 |
334 | ### Dependent pattern matching
335 |
336 | Indexed types help us rule out impossible patterns. For example, when we build an empty vector with `Nil` and `match` on it,
337 | we clearly do not need to consider the `Cons` case. This feature is called dependent pattern matching.
338 |
339 | ```lean
340 | example :=
341 | match v0 with
342 | | Nil => Z
343 | ```
344 |
345 | If we add a `Cons` case anyway, the prover reports an error like this:
346 |
347 | ```plaintext
348 | .github/README.md:?:?: type mismatch:
349 | want:
350 | (Vec Type N.Z)
351 |
352 | got:
353 | (Vec ?m.? (N.S ?m.?))
354 | ```
355 |
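356 | So an empty type is not necessarily a type without constructors; it can also be a type whose values simply cannot be constructed, for example: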
357 |
358 | ```lean
359 | inductive Weird (n: N) where
360 | | MkWeird (n := Z)
361 | open Weird
362 |
363 | example (A: Type) (x: Weird (S Z)): A := nomatch x
364 | ```
365 |
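366 | Here `Weird (S Z)` is also an empty type, because there is no way to construct a value of that type.
367 |
368 | ### A new equality type
369 |
370 | With indexed types, we can now define an equality type that is easier to understand: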
371 |
372 | ```lean
373 | inductive Eq {T: Type} (a: T) (b: T) where
374 | | Refl (a := b)
375 | open Eq
376 | ```
377 |
378 | Let's test our `1+1=2` with `addN` and `Eq`:
379 |
380 | ```lean
381 | example: Eq (addN (S Z) (S Z)) (S (S Z)) := Refl (T := N)
382 | ```
383 |
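384 | ### Classes
385 |
386 | In the type system we have built so far, all types live under `Type`, and we have no way to further “classify” them: `Type` suddenly becomes “the new
387 | `any`”. The downside: I want the default value of `int` to be `0` and that of `string` to be `""`, and I want a single function
388 | `default::` to produce the default value of a given type. How can we do that?
389 |
390 | Typeclasses (also called traits) solve this problem nicely: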
391 |
392 | ```lean
393 | class Default (T: Type) where
394 | default: T
395 | open Default
396 | ```
397 |
398 | With the `Default` class in place, we can define `Default` instances for different types.
399 |
400 | ### Instances
401 |
402 | Define `Z` as the default value of type `N`:
403 |
404 | ```lean
405 | instance: Default N
406 | where
407 | default := Z
408 | ```
409 |
410 | > [!CAUTION]
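411 | > Note that the `where` keyword must go on its own line here. Lean 4 syntax is very flexible, and to keep TinyLean's grammar file concise, many syntactic ambiguities are not handled yet.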
412 |
413 | Let's write a proof of `(default N) = Z`:
414 |
415 | ```lean
416 | example: Eq Z (default N) := Refl (T := N)
417 | ```
418 |
419 | ### Class parameters
420 |
421 | We can use class parameters to check whether a type meets the requirements of a class, for example:
422 |
423 | ```lean
424 | def mustBeDefault (T: Type) [p: Default T] := Type
425 | ```
426 |
427 | When calling `mustBeDefault`, we require the argument `T` to satisfy the constraint of the `Default` class.
428 |
429 | ```lean
430 | example := mustBeDefault N
431 | ```
432 |
433 | Clearly, `N` satisfies this constraint. When we pass another type, such as `Bot`, the prover tells us it cannot find a matching instance:
434 |
435 | ```plaintext
436 | .github/README.md:?:?: no such instance for class '(Default Bot)'
437 | ```
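At its simplest, finding an instance for a class parameter is a lookup keyed by the class and the type, and a failed lookup produces an error like the one above. The Python sketch below shows only that idea. `INSTANCES`, `instance`, and `resolve_instance` are hypothetical. The `NoInstanceError` here is a standalone class that merely reuses the message format from `__main__.py`. The sketch does not model how TinyLean actually searches for instances.

```python
class NoInstanceError(Exception):
    pass


# Hypothetical registry: (class name, type name) -> field implementations.
INSTANCES: dict[tuple[str, str], dict[str, object]] = {}


def instance(cls: str, ty: str, **fields: object) -> None:
    """Register an instance, like `instance: Default N where default := Z`."""
    INSTANCES[(cls, ty)] = fields


def resolve_instance(cls: str, ty: str) -> dict[str, object]:
    """Find the instance that a class parameter such as `[p: Default T]` needs."""
    try:
        return INSTANCES[(cls, ty)]
    except KeyError:
        raise NoInstanceError(f"no such instance for class '({cls} {ty})'") from None


instance("Default", "N", default="Z")
print(resolve_instance("Default", "N")["default"])  # Z
```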
438 |
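439 | ### Operator overloading
440 |
441 | With classes, operator overloading is easy to implement too. In TinyLean, the infix operators `+`, `-`, `*`, and `/` are simply translated into
442 | calls to `add`, `sub`, `mul`, and `div`, so we first define the corresponding class and class method: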
443 |
444 | ```lean
445 | class Add {T: Type} where
446 | add: (a: T) -> (b: T) -> T
447 | open Add
448 | ```
449 |
450 | > [!NOTE]
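451 | > Note that this `add` is *homogeneous*: its inputs and output all have the same type. A better definition would be *heterogeneous*, i.e. something like
452 | > `T → U → V`; we skip heterogeneous addition here.
453 |
454 | Define the corresponding instance for `N`: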
455 |
456 | ```lean
457 | instance: Add (T := N)
458 | where
459 | add := addN
460 | ```
461 |
462 | Now we can use the infix operator in our proof of `1+1=2`:
463 |
464 | ```lean
465 | example
466 | : Eq (S (S Z)) ((S Z) + (S Z))
467 | := Refl (T := N)
468 | ```
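Translating infix operators into calls is a purely syntactic rewrite. `_infix` in `src/TinyLean/ast.py` performs it using TinyLean's own `Call` and `Ref` nodes, which also carry a source location and an implicit-argument flag. The simplified, standalone Python sketch below drops those fields, and its `desugar_infix` name is hypothetical.

```python
from dataclasses import dataclass

OPS = {"+": "add", "-": "sub", "*": "mul", "/": "div"}


@dataclass(frozen=True)
class Ref:
    name: str


@dataclass(frozen=True)
class Call:
    callee: object
    arg: object


def desugar_infix(lhs, op: str, rhs):
    """Rewrite `lhs op rhs` into the curried call `(add lhs) rhs`, etc."""
    return Call(Call(Ref(OPS[op]), lhs), rhs)


one = Call(Ref("S"), Ref("Z"))  # `S Z`
expr = desugar_infix(one, "+", one)  # `(S Z) + (S Z)`
print(expr.callee.callee.name)  # add
```

Because `a + b` becomes the call `add a b`, the meaning of `+` is decided by whatever `add` resolves to; for `N`, that is the `Add` instance above.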
469 |
470 | ## 🔍 Explore
471 |
472 | From here, you can go on to explore the following:
473 |
474 | ### Source code
475 |
476 | Start reading the source code from [`tests/onboard.py`].
477 |
478 | [`tests/onboard.py`]: ../src/TinyLean/tests/onboard.py
479 |
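480 | ### The unknown
481 |
482 | If you are still confused after the “Guide”, or cannot make sense of what happened at all, that is normal. The “Guide” is really a showcase of TinyLean's
483 | features rather than a proper theorem-proving tutorial, because plenty of excellent tutorials already exist, for example: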
484 |
485 | * [Theorem Proving in Lean 4](https://leanprover.github.io/theorem_proving_in_lean4/)
486 | * [Programming Language Foundations in Agda](https://plfa.github.io/)
487 |
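488 | I did not understand these tutorials and books fully on first reading; it took three to four years of rereading parts of them, piece by piece, before they clicked.
489 |
490 | And I have to confess: what really made me understand type theory was implementing one type theory after another myself.
491 |
492 | ### Going further
493 |
494 | TODO
495 |
496 | ## 🫡 Acknowledgements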
497 |
498 | TODO
499 |
500 | ---
501 |
502 | MIT License Copyright © Anqur
503 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 | on:
3 |   push:
4 |     tags:
5 |       - v*
6 | jobs:
7 |   build-and-publish:
8 |     runs-on: ubuntu-latest
9 |     permissions:
10 |       id-token: write
11 |     steps:
12 |       - name: Checkout
13 |         uses: actions/checkout@v4
14 |       - name: Set up Python
15 |         uses: actions/setup-python@v4
16 |         with:
17 |           python-version: '3.x'
18 |       - name: Install build dependencies
19 |         run: pip install --upgrade build check-manifest twine
20 |       - name: Build
21 |         run: |
22 |           check-manifest
23 |           python -m build
24 |           python -m twine check dist/*
25 |       - name: Publish
26 |         uses: pypa/gh-action-pypi-publish@release/v1
27 |         with:
28 |           skip-existing: true
29 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Test
2 | on:
3 |   - push
4 |   - pull_request
5 | jobs:
6 |   test:
7 |     strategy:
8 |       matrix:
9 |         python:
10 |           - '3.12'
11 |           - '3.13'
12 |         platform:
13 |           - ubuntu-latest
14 |           - macos-latest
15 |           - windows-latest
16 |     runs-on: ${{ matrix.platform }}
17 |     steps:
18 |       - uses: actions/checkout@v4
19 |       - name: Set up Python ${{ matrix.python }}
20 |         uses: actions/setup-python@v5
21 |         with:
22 |           python-version: ${{ matrix.python }}
23 |       - name: Install test dependencies
24 |         run: |
25 |           pip install --upgrade build check-manifest twine black pytest pytest-cov
26 |           pip install .
27 |       - name: Lint
28 |         run: black . --check
29 |       - name: Build
30 |         run: |
31 |           check-manifest
32 |           python -m build
33 |           python -m twine check dist/*
34 |       - name: Test and report coverage
35 |         run: pytest --cov --cov-branch --cov-report=xml
36 |       - name: Upload coverage reports to Codecov
37 |         uses: codecov/codecov-action@v5
38 |         with:
39 |           token: ${{ secrets.CODECOV_TOKEN }}
40 |   cloc:
41 |     runs-on: ubuntu-latest
42 |     steps:
43 |       - uses: actions/checkout@v4
44 |       - name: Install dependencies
45 |         run: |
46 |           sudo apt-get update
47 |           sudo apt-get install cloc jq
48 |       - name: Run CLOC
49 |         run: cloc --exclude-dir=dist,venv,build,tests --json . | jq '"CLOC="+(.Python.code|tostring)' -r >> $GITHUB_ENV
50 |       - name: Create CLOC badge
51 |         uses: schneegans/dynamic-badges-action@v1.7.0
52 |         with:
53 |           auth: ${{ secrets.CLOC_GIST_SECRET }}
54 |           gistID: 5d8f9b1d4b414b7076cf84f4eae089d9
55 |           filename: cloc.json
56 |           label: Python lines
57 |           message: ${{ env.CLOC }}
58 |           color: orange
59 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea/
3 | .vscode/
4 |
5 | build/
6 | dist/
7 | *.egg-info/
8 | *.egg
9 | *.py[cod]
10 | __pycache__/
11 | *.so
12 | *~
13 | venv/
14 |
15 | .nox
16 | .cache
17 | .coverage
18 | coverage.*
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Anqur
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "TinyLean"
7 | version = "1.0.0"
8 | description = "Tiny theorem prover with syntax like Lean 4"
9 | readme = ".github/README.md"
10 | requires-python = ">=3.12"
11 | license = { file = "LICENSE" }
12 | keywords = [
13 |     "lean",
14 |     "lean4",
15 |     "theorem-proving",
16 |     "theorem-prover",
17 |     "programming-language",
18 | ]
19 | authors = [
20 |     { name = "Anqur", email = "anqurvanillapy@gmail.com" },
21 | ]
22 | classifiers = [
23 |     "Intended Audience :: Education",
24 |     "Topic :: Education",
25 |     "License :: OSI Approved :: MIT License",
26 |     "Programming Language :: Python :: 3",
27 |     "Programming Language :: Python :: 3.12",
28 |     "Programming Language :: Python :: 3.13",
29 |     "Programming Language :: Python :: 3 :: Only",
30 | ]
31 | dependencies = [
32 |     "pyparsing==3.2.1"
33 | ]
34 |
35 | [project.urls]
36 | "Homepage" = "https://github.com/anqurvanillapy/TinyLean"
37 | "Bug Reports" = "https://github.com/anqurvanillapy/TinyLean/issues"
38 | "Source" = "https://github.com/anqurvanillapy/TinyLean"
39 |
40 | [project.scripts]
41 | tinylean = "TinyLean.__main__:main"
42 |
--------------------------------------------------------------------------------
/src/TinyLean/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from itertools import count
3 |
4 | fresh = count(1).__next__
5 |
6 |
7 | @dataclass(frozen=True)
8 | class Name:
9 |     text: str
10 |     id: int = field(default_factory=fresh)
11 |
12 |     def __str__(self):
13 |         return self.text
14 |
15 |     def is_unbound(self):
16 |         return self.text == "_"
17 |
18 |
19 | @dataclass(frozen=True)
20 | class Param[T]:
21 |     name: Name
22 |     type: T
23 |     is_implicit: bool
24 |     is_class: bool = False
25 |
26 |     def __str__(self):
27 |         l, r = "()" if not self.is_implicit else "{}" if not self.is_class else "[]"
28 |         return f"{l}{self.name}: {self.type}{r}"
29 |
30 |
31 | @dataclass(frozen=True)
32 | class Decl:
33 |     loc: int
34 |
35 |
36 | @dataclass(frozen=True)
37 | class Def[T](Decl):
38 |     name: Name
39 |     params: list[Param[T]]
40 |     ret: T
41 |     body: T
42 |
43 |
44 | @dataclass(frozen=True)
45 | class Sig[T](Decl):
46 |     name: Name
47 |     params: list[Param[T]]
48 |     ret: T
49 |
50 |
51 | @dataclass(frozen=True)
52 | class Example[T](Decl):
53 |     params: list[Param[T]]
54 |     ret: T
55 |     body: T
56 |
57 |
58 | @dataclass(frozen=True)
59 | class Ctor[T](Decl):
60 |     name: Name
61 |     params: list[Param[T]]
62 |     ty_args: list[tuple[T, T]]
63 |     ty_name: Name | None = None
64 |
65 |
66 | @dataclass(frozen=True)
67 | class Data[T](Decl):
68 |     name: Name
69 |     params: list[Param[T]]
70 |     ctors: list[Ctor[T]]
71 |
72 |
73 | @dataclass(frozen=True)
74 | class Field[T](Decl):
75 |     name: Name
76 |     type: T
77 |     cls_name: Name | None = None
78 |
79 |
80 | @dataclass(frozen=True)
81 | class Class[T](Decl):
82 |     name: Name
83 |     params: list[Param[T]]
84 |     fields: list[Field[T]]
85 |     instances: list[int] = field(default_factory=list)
86 |
87 |
88 | @dataclass(frozen=True)
89 | class Instance[T](Decl):
90 |     type: T
91 |     fields: list[tuple[T, T]]
92 |     id: int = field(default_factory=fresh)
93 |
--------------------------------------------------------------------------------
/src/TinyLean/__main__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 |
4 | from pyparsing import util, exceptions
5 |
6 | from . import ast, ir
7 |
8 |
9 | fatal = lambda m: sys.exit(int(not print(m)))
10 |
11 |
12 | _F = Path(sys.argv[1]) if len(sys.argv) > 1 else None
13 |
14 |
15 | def fatal_on(text: str, loc: int, m: str):
16 |     fatal(f"{_F}:{util.lineno(loc, text)}:{util.col(loc, text)}: {m}")
17 |
18 |
19 | def main(file=_F if _F else fatal("usage: tinylean FILE")):
20 |     try:
21 |         with open(file, encoding="utf-8") as f:
22 |             text = f.read()
23 |         ast.check_string(text, file.suffix == ".md")
24 |     except OSError as e:
25 |         fatal(e)
26 |     except exceptions.ParseException as e:
27 |         fatal_on(text, e.loc, str(e).split("(at char")[0].strip())
28 |     except ast.UndefinedVariableError as e:
29 |         v, loc = e.args
30 |         fatal_on(text, loc, f"undefined variable '{v}'")
31 |     except ast.DuplicateVariableError as e:
32 |         v, loc = e.args
33 |         fatal_on(text, loc, f"duplicate variable '{v}'")
34 |     except ast.TypeMismatchError as e:
35 |         want, got, loc = e.args
36 |         fatal_on(text, loc, f"type mismatch:\nwant:\n {want}\n\ngot:\n {got}")
37 |     except ast.UnsolvedPlaceholderError as e:
38 |         name, ctx, ty, loc = e.args
39 |         ty_msg = f" {name} : {ty}"
40 |         ctx_msg = "".join([f"\n {p}" for p in ctx.values()]) if ctx else " (none)"
41 |         fatal_on(text, loc, f"unsolved placeholder:\n{ty_msg}\n\ncontext:{ctx_msg}")
42 |     except ast.UnknownCaseError as e:
43 |         want, got, loc = e.args
44 |         fatal_on(text, loc, f"cannot match case '{got}' of type '{want}'")
45 |     except ast.DuplicateCaseError as e:
46 |         name, loc = e.args
47 |         fatal_on(text, loc, f"duplicate case '{name}'")
48 |     except ast.CaseParamMismatchError as e:
49 |         want, got, loc = e.args
50 |         fatal_on(text, loc, f"want '{want}' case parameters, but got '{got}'")
51 |     except ast.CaseMissError as e:
52 |         miss, loc = e.args
53 |         fatal_on(text, loc, f"missing case: {miss}")
54 |     except ast.FieldMissError as e:
55 |         miss, loc = e.args
56 |         fatal_on(text, loc, f"missing field: {miss}")
57 |     except ast.UnknownFieldError as e:
58 |         want, got, loc = e.args
59 |         fatal_on(text, loc, f"unknown field '{got}' of class '{want}'")
60 |     except ir.NoInstanceError as e:
61 |         name, loc = e.args
62 |         fatal_on(text, loc, f"no such instance for class '{name}'")
63 |     except RecursionError as e:
64 |         print("Program too complex or oops you just got '⊥'! Please report this issue:")
65 |         raise e
66 |     except Exception as e:
67 |         print("Internal compiler error! Please report this issue:")
68 |         raise e
69 |
70 |
71 | if __name__ == "__main__":
72 |     main()
73 |
--------------------------------------------------------------------------------
/src/TinyLean/ast.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | from itertools import chain
3 | from dataclasses import dataclass, field
4 | from typing import OrderedDict, cast as _c
5 |
6 | from pyparsing import ParseResults
7 |
8 | from . import (
9 |     Name,
10 |     Param,
11 |     Decl,
12 |     ir,
13 |     grammar as _g,
14 |     fresh,
15 |     Def,
16 |     Example,
17 |     Ctor,
18 |     Data,
19 |     Sig,
20 |     Field,
21 |     Class,
22 |     Instance,
23 | )
24 |
25 |
26 | @dataclass(frozen=True)
27 | class Node:
28 |     loc: int
29 |
30 |
31 | @dataclass(frozen=True)
32 | class Type(Node): ...
33 |
34 |
35 | @dataclass(frozen=True)
36 | class Ref(Node):
37 |     name: Name
38 |
39 |
40 | @dataclass(frozen=True)
41 | class FnType(Node):
42 |     param: Param[Node]
43 |     ret: Node
44 |
45 |
46 | @dataclass(frozen=True)
47 | class Fn(Node):
48 |     param: Name
49 |     body: Node
50 |
51 |
52 | @dataclass(frozen=True)
53 | class Call(Node):
54 |     callee: Node
55 |     arg: Node
56 |     implicit: str | bool
57 |
58 |
59 | @dataclass(frozen=True)
60 | class Placeholder(Node):
61 |     is_user: bool
62 |
63 |
64 | @dataclass(frozen=True)
65 | class Nomatch(Node):
66 |     arg: Node
67 |
68 |
69 | @dataclass(frozen=True)
70 | class Case(Node):
71 |     ctor: Ref
72 |     params: list[Name]
73 |     body: Node
74 |
75 |
76 | @dataclass(frozen=True)
77 | class Match(Node):
78 |     arg: Node
79 |     cases: list[Case]
80 |
81 |
82 | _g.name.add_parse_action(lambda r: Name(r[0][0]))
83 |
84 | _ops = {"+": "add", "-": "sub", "*": "mul", "/": "div"}
85 |
86 |
87 | def _infix(loc: int, ret: ParseResults):
88 |     r = ret[0]
89 |     if not isinstance(r, ParseResults):
90 |         return r
91 |     return Call(loc, Call(loc, Ref(loc, Name(_ops[r[1]])), r[0], False), r[2], False)
92 |
93 |
94 | _g.expr.add_parse_action(_infix)
95 | _g.type_.add_parse_action(lambda l, r: Type(l))
96 | _g.ph.add_parse_action(lambda l, r: Placeholder(l, True))
97 | _g.ref.add_parse_action(lambda l, r: Ref(l, r[0][0]))
98 | _g.i_param.add_parse_action(lambda r: Param(r[0], r[1], True))
99 | _g.e_param.add_parse_action(lambda r: Param(r[0], r[1], False))
100 | _g.c_param.add_parse_action(lambda r: Param(r[0], r[1], True, True))
101 | _g.fn_type.add_parse_action(lambda l, r: FnType(l, r[0], r[1]))
102 | _g.fn.add_parse_action(
103 |     lambda l, r: reduce(lambda a, n: Fn(l, n, a), reversed(r[0]), r[1])
104 | )
105 | _g.match.add_parse_action(lambda l, r: Match(l, r[0], list(r[1])))
106 | _g.case.add_parse_action(lambda r: Case(r[0].loc, r[0], r[1], r[2]))
107 | _g.nomatch.add_parse_action(lambda l, r: Nomatch(l, r[0][0]))
108 | _g.i_arg.add_parse_action(lambda l, r: (r[1], r[0]))
109 | _g.e_arg.add_parse_action(lambda l, r: (r[0], False))
110 | _g.call.add_parse_action(
111 |     lambda l, r: reduce(lambda a, b: Call(l, a, b[0], b[1]), r[1:], r[0])
112 | )
113 | _g.p_expr.add_parse_action(lambda r: r[0])
114 |
115 | _g.return_type.add_parse_action(lambda l, r: r[0] if len(r) else Placeholder(l, False))
116 | _g.def_.add_parse_action(lambda r: Def(r[0].loc, r[0].name, list(r[1]), r[2], r[3]))
117 | _g.example.add_parse_action(lambda l, r: Example(l, list(r[0]), r[1], r[2]))
118 | _g.type_arg.add_parse_action(lambda r: (r[0], r[1]))
119 | _g.ctor.add_parse_action(lambda r: Ctor(r[0].loc, r[0].name, list(r[1]), list(r[2])))
120 | _g.data.add_condition(
121 |     lambda r: r[0].name.text == r[3], message="open and datatype name mismatch"
122 | ).add_parse_action(lambda r: Data(r[0].loc, r[0].name, list(r[1]), list(r[2])))
123 | _g.c_field.add_parse_action(lambda l, r: Field(l, r[0], r[1]))
124 | _g.class_.add_condition(
125 |     lambda r: r[0].name.text == r[3], message="open and class name mismatch"
126 | ).add_parse_action(lambda r: Class(r[0].loc, r[0].name, list(r[1]), list(r[2])))
127 | _g.i_field.add_parse_action(lambda r: (r[0], r[1]))
128 | _g.inst.add_parse_action(lambda l, r: Instance(l, r[0], list(r[1])))
129 |
130 |
131 | @dataclass(frozen=True)
132 | class Parser:
133 |     is_markdown: bool = False
134 |
135 |     def __ror__(self, s: str):
136 |         if not self.is_markdown:
137 |             return list(_g.program.parse_string(s, parse_all=True))
138 |         return chain.from_iterable(r[0] for r in _g.markdown.scan_string(s))
139 |
140 |
141 | class DuplicateVariableError(Exception): ...
142 |
143 |
144 | class UndefinedVariableError(Exception): ...
145 |
146 |
147 | @dataclass(frozen=True)
148 | class NameResolver:
149 |     locals: dict[str, Name] = field(default_factory=dict)
150 |     globals: dict[str, Name] = field(default_factory=dict)
151 |
152 |     def __ror__(self, decls: list[Decl]):
153 |         return [self._decl(d) for d in decls]
154 |
155 |     def _decl(self, decl: Decl) -> Decl:
156 |         self.locals.clear()
157 |
158 |         if isinstance(decl, Def) or isinstance(decl, Data):
159 |             self._insert_global(decl.loc, decl.name)
160 |
161 |         if isinstance(decl, Def) or isinstance(decl, Example):
162 |             return self._def_or_example(decl)
163 |
164 |         if isinstance(decl, Data):
165 |             return self._data(decl)
166 |
167 |         if isinstance(decl, Class):
168 |             return self._class(decl)
169 |
170 |         return self._inst(_c(Instance, decl))
171 |
172 |     def _def_or_example(self, d: Def[Node] | Example[Node]):
173 |         params = self._params(d.params)
174 |         ret = self.expr(d.ret)
175 |         body = self.expr(d.body)
176 |         if isinstance(d, Example):
177 |             return Example(d.loc, params, ret, body)
178 |         return Def(d.loc, d.name, params, ret, body)
179 |
180 |     def _data(self, d: Data[Node]):
181 |         params = self._params(d.params)
182 |         ctors = [self._ctor(c, d.name) for c in d.ctors]
183 |         return Data(d.loc, d.name, params, ctors)
184 |
185 |     def _ctor(self, c: Ctor[Node], ty_name: Name):
186 |         params = self._params(c.params)
187 |         ty_args = [(self.expr(n), self.expr(t)) for n, t in c.ty_args]
188 |         for p in params:
189 |             del self.locals[p.name.text]
190 |         self._insert_global(c.loc, c.name)
191 |         return Ctor(c.loc, c.name, params, ty_args, ty_name)
192 |
193 |     def _class(self, c: Class[Node]):
194 |         params = self._params(c.params)
195 |         fields = []
196 |         for f in c.fields:
197 |             self._insert_global(f.loc, f.name)
198 |             fields.append(Field(f.loc, f.name, self.expr(f.type)))
199 |         self._insert_global(c.loc, c.name)
200 |         return Class(c.loc, c.name, params, fields)
201 |
202 |     def _inst(self, i: Instance[Node]):
203 |         t = self.expr(i.type)
204 |         fields = []
205 |         field_ids = set()
206 |         for n, v in i.fields:
207 |             n = _c(Ref, self.expr(n))
208 |             if n.name.id in field_ids:
209 |                 raise DuplicateVariableError(n.name.text, n.loc)
210 |             field_ids.add(n.name.id)
211 |             fields.append((n, (self.expr(v))))
212 |         return Instance(i.loc, t, fields)
213 |
214 |     def _params(self, params: list[Param[Node]]):
215 |         ret = []
216 |         for p in params:
217 |             self._insert_local(p.name)
218 |             ret.append(Param(p.name, self.expr(p.type), p.is_implicit, p.is_class))
219 |         return ret
220 |
221 |     def expr(self, n: Node) -> Node:
222 |         if isinstance(n, Ref):
223 |             if v := self.locals.get(n.name.text, self.globals.get(n.name.text)):
224 |                 return Ref(n.loc, v)
225 |             raise UndefinedVariableError(n.name.text, n.loc)
226 |         if isinstance(n, FnType):
227 |             typ = self.expr(n.param.type)
228 |             b = self._with_locals(n.ret, n.param.name)
229 |             p = Param(n.param.name, typ, n.param.is_implicit, n.param.is_class)
230 |             return FnType(n.loc, p, b)
231 |         if isinstance(n, Fn):
232 |             return Fn(n.loc, n.param, self._with_locals(n.body, n.param))
233 |         if isinstance(n, Call):
234 |             return Call(n.loc, self.expr(n.callee), self.expr(n.arg), n.implicit)
235 |         if isinstance(n, Nomatch):
236 |             return Nomatch(n.loc, self.expr(n.arg))
237 |         if isinstance(n, Match):
238 |             arg = self.expr(n.arg)
239 |             cases = []
240 |             for c in n.cases:
241 |                 ctor = _c(Ref, self.expr(c.ctor))
242 |                 body = self._with_locals(c.body, *c.params)
243 |                 cases.append(Case(c.loc, ctor, c.params, body))
244 |             return Match(n.loc, arg, cases)
245 |         assert isinstance(n, Type) or isinstance(n, Placeholder)
246 |         return n
247 |
248 |     def _with_locals(self, node: Node, *names: Name):
249 |         olds = [(v, self._insert_local(v)) for v in names]
250 |         ret = self.expr(node)
251 |         for v, old in olds:
252 |             if old:
253 |                 self._insert_local(old)
254 |             elif not v.is_unbound():
255 |                 del self.locals[v.text]
256 |         return ret
257 |
258 |     def _insert_local(self, v: Name):
259 |         if v.is_unbound():
260 |             return None
261 |         old = self.locals.get(v.text)
262 |         self.locals[v.text] = v
263 |         return old
264 |
265 |     def _insert_global(self, loc: int, name: Name):
266 |         if not name.is_unbound():
267 |             if name.text in self.globals:
268 |                 raise DuplicateVariableError(name.text, loc)
269 |             self.globals[name.text] = name
270 |
271 |
272 | class TypeMismatchError(Exception): ...
273 |
274 |
275 | class UnsolvedPlaceholderError(Exception): ...
276 |
277 |
278 | class UnknownCaseError(Exception): ...
279 |
280 |
281 | class DuplicateCaseError(Exception): ...
282 |
283 |
284 | class CaseParamMismatchError(Exception): ...
285 |
286 |
287 | class CaseMissError(Exception): ...
288 |
289 |
290 | class FieldMissError(Exception): ...
291 |
292 |
293 | class UnknownFieldError(Exception): ...
294 |
295 |
296 | @dataclass(frozen=True)
297 | class TypeChecker:
298 |     globals: dict[int, Decl] = field(default_factory=dict)
299 |     locals: dict[int, Param[ir.IR]] = field(default_factory=dict)
300 |     holes: OrderedDict[int, ir.Hole] = field(default_factory=OrderedDict)
301 |     recur_ids: set[int] = field(default_factory=set)
302 |
303 |     def __ror__(self, ds: list[Decl]):
304 |         ret = [self._run(d) for d in ds]
305 |         for i, h in self.holes.items():
306 |             if h.answer.is_unsolved():
307 |                 ty = self._inliner().run(h.answer.type)
308 |                 if _is_solved_class(ty):
309 |                     continue
310 |                 p = ir.Placeholder(i, h.is_user)
311 |                 raise UnsolvedPlaceholderError(str(p), h.locals, ty, h.loc)
312 |         return ret
313 |
314 |     def _run(self, decl: Decl) -> Decl:
315 |         self.locals.clear()
316 |         if isinstance(decl, Def) or isinstance(decl, Example):
317 |             return self._def_or_example(decl)
318 |         if isinstance(decl, Data):
319 |             return self._data(decl)
320 |         if isinstance(decl, Instance):
321 |             return self._inst(decl)
322 |         return self._class(_c(Class, decl))
323 |
324 |     def _def_or_example(self, d: Def[Node] | Example[Node]):
325 |         params = self._params(d.params)
326 |         ret = self.check(d.ret, ir.Type())
327 |
328 |         if isinstance(d, Def):
329 |             self.globals[d.name.id] = Sig(d.loc, d.name, params, ret)
330 |         body = self.check(d.body, ret)
331 |
332 |         if isinstance(d, Example):
333 |             return Example(d.loc, params, ret, body)
334 |
335 |         checked = Def(d.loc, d.name, params, ret, body)
336 |         self.globals[d.name.id] = checked
337 |         return checked
338 |
339 |     def _data(self, d: Data[Node]):
340 |         params = self._params(d.params)
341 |         data = Data(d.loc, d.name, params, [])
342 |         self.globals[d.name.id] = data
343 |         data.ctors.extend(self._ctor(c) for c in d.ctors)
344 |         return data
345 |
346 |     def _ctor(self, c: Ctor[Node]):
347 |         params = self._params(c.params)
348 |         ty_args: list[tuple[ir.IR, ir.IR]] = []
349 |         for x, v in c.ty_args:
350 |             x_val, x_ty = self.infer(x)
351 |             v_val = self.check(v, x_ty)
352 |             ty_args.append((x_val, v_val))
353 | ctor = Ctor(c.loc, c.name, params, ty_args, c.ty_name)
354 | self.globals[c.name.id] = ctor
355 | return ctor
356 |
357 | def _class(self, c: Class[Node]):
358 | params = self._params(c.params)
359 | fs = [
360 | Field(f.loc, f.name, self.check(f.type, ir.Type()), c.name)
361 | for f in c.fields
362 | ]
363 | self.globals.update({f.name.id: f for f in fs})
364 | cls = Class(c.loc, c.name, params, fs)
365 | self.globals[c.name.id] = cls
366 | return cls
367 |
368 | def _inst(self, i: Instance[Node]):
369 | ty = self.check(i.type, ir.Type())
370 | if not isinstance(ty, ir.Class):
371 | raise TypeMismatchError("class", str(ty), i.type.loc)
372 | c = _c(Class, self.globals[ty.name.id])
373 | vals = {_c(Ref, n).name.id: (n, f) for n, f in i.fields}
374 | fields = []
375 | for f in c.fields:
376 | nv = vals.pop(f.name.id, None)
377 | if not nv:
378 | raise FieldMissError(f.name.text, i.loc)
379 | f_decl = _c(Field, self.globals[f.name.id])
380 | field_ty = ir.from_field(f_decl, c, False)[1]
381 | env = []
382 | for ty_arg in ty.args:
383 | assert isinstance(field_ty, ir.FnType)
384 | env.append((field_ty.param.name, ty_arg))
385 | field_ty = field_ty.ret
386 | f_type = self._inliner().run_with(field_ty, *env)
387 | fields.append((ir.Ref(f.name), self.check(nv[1], f_type)))
388 | for n, _ in vals.values():
389 | assert isinstance(n, Ref)
390 | raise UnknownFieldError(c.name.text, n.name.text, n.loc)
391 | c.instances.append(i.id)
392 | inst = Instance(i.loc, _c(ir.IR, ty), fields, i.id)
393 | self.globals[i.id] = inst
394 | return inst
395 |
396 | def _params(self, params: list[Param[Node]]):
397 | ret = []
398 | for p in params:
399 | t = self.check(p.type, ir.Type())
400 | if p.is_class:
401 | t = self._inliner().run(t)
402 | assert p.is_implicit
403 | if not isinstance(t, ir.Class):
404 | raise TypeMismatchError("class", str(t), p.type.loc)
405 | param = Param(p.name, t, p.is_implicit, p.is_class)
406 | self.locals[p.name.id] = param
407 | ret.append(param)
408 | return ret
409 |
410 | def check(self, n: Node, typ: ir.IR) -> ir.IR:
411 | if isinstance(n, Fn):
412 | t = self._inliner().run(typ)
413 | if not isinstance(t, ir.FnType):
414 | raise TypeMismatchError(str(t), "function", n.loc)
415 | ret = self._inliner().run_with(t.ret, (t.param.name, ir.Ref(n.param)))
416 | p = Param(n.param, t.param.type, t.param.is_implicit, t.param.is_class)
417 | return ir.Fn(p, self._check_with(n.body, ret, p))
418 |
419 | holes_len = len(self.holes)
420 | val, got = self.infer(n)
421 | got = self._inliner().run(got)
422 | want = self._inliner().run(typ)
423 |
424 | if _can_insert_placeholders(want):
425 | if new_f := _with_placeholders(n, got, False):
426 | # FIXME: No valid tests yet.
427 | assert len(self.holes) == holes_len
428 | val, got = self.infer(new_f)
429 |
430 | if not self._eq(got, want):
431 | raise TypeMismatchError(str(want), str(got), n.loc)
432 |
433 | return val
434 |
435 | def infer(self, n: Node) -> tuple[ir.IR, ir.IR]:
436 | if isinstance(n, Ref):
437 | if param := self.locals.get(n.name.id):
438 | return ir.Ref(param.name), param.type
439 | d = self.globals[n.name.id]
440 | if isinstance(d, Def):
441 | return ir.from_def(d)
442 | if isinstance(d, Sig):
443 | self.recur_ids.add(d.name.id)
444 | return ir.from_sig(d)
445 | if isinstance(d, Data):
446 | return ir.from_data(d)
447 | if isinstance(d, Ctor):
448 | data_decl = _c(Data, self.globals[d.ty_name.id])
449 | return ir.from_ctor(d, data_decl)
450 | if isinstance(d, Field):
451 | return ir.from_field(d, _c(Class, self.globals[d.cls_name.id]))
452 | return ir.from_class(_c(Class, d))
453 | if isinstance(n, FnType):
454 | p_typ = self.check(n.param.type, ir.Type())
455 | p = Param(n.param.name, p_typ, n.param.is_implicit, n.param.is_class)
456 | b_val = self._check_with(n.ret, ir.Type(), p)
457 | return ir.FnType(p, b_val), ir.Type()
458 | if isinstance(n, Call):
459 | holes_len = len(self.holes)
460 | f_val, got = self.infer(n.callee)
461 |
462 | if implicit_f := _with_placeholders(n.callee, got, n.implicit):
463 | [self.holes.popitem() for _ in range(len(self.holes) - holes_len)]
464 | return self.infer(Call(n.loc, implicit_f, n.arg, n.implicit))
465 |
466 | if not isinstance(got, ir.FnType):
467 | raise TypeMismatchError("function", str(got), n.callee.loc)
468 |
469 | x_tm = self._check_with(n.arg, got.param.type, got.param)
470 | typ = self._inliner().run_with(got.ret, (got.param.name, x_tm))
471 | val = self._inliner().apply(f_val, x_tm)
472 | return val, typ
473 | if isinstance(n, Placeholder):
474 | ty = self._insert_hole(n.loc, n.is_user, ir.Type())
475 | v = self._insert_hole(n.loc, n.is_user, ty)
476 | return v, ty
477 | if isinstance(n, Nomatch):
478 | _, got = self.infer(n.arg)
479 | if not isinstance(got, ir.Data):
480 | raise TypeMismatchError("datatype", str(got), n.arg.loc)
481 | data = _c(Data, self.globals[got.name.id])
482 | for c in data.ctors:
483 | self._exhaust(n.arg.loc, c, data, got)
484 | return ir.Nomatch(), self._insert_hole(n.loc, False, ir.Type())
485 | if isinstance(n, Match):
486 | arg, arg_ty = self.infer(n.arg)
487 | if not isinstance(arg_ty, ir.Data):
488 | raise TypeMismatchError("datatype", str(arg_ty), n.arg.loc)
489 | data = _c(Data, self.globals[arg_ty.name.id])
490 | ctors = {c.name.id: c for c in data.ctors}
491 | ty: ir.IR | None = None
492 | cases: dict[int, ir.Case] = {}
493 | for c in n.cases:
494 | ctor = ctors.get(c.ctor.name.id)
495 | if not ctor:
496 | raise UnknownCaseError(data.name.text, c.ctor.name.text, c.loc)
497 | with ir.dirty_holes(self.holes):
498 | c_ty = self._ctor_return_type(c.loc, ctor, data)
499 | if not self._eq(c_ty, arg_ty):
500 | raise TypeMismatchError(str(arg_ty), str(c_ty), c.loc)
501 | if ctor.name.id in cases:
502 | raise DuplicateCaseError(ctor.name.text, c.loc)
503 | if len(c.params) != len(ctor.params):
504 | raise CaseParamMismatchError(len(ctor.params), len(c.params), c.loc)
505 | ps = [Param(n, p.type, False) for n, p in zip(c.params, ctor.params)]
506 | if ty is None:
507 | body, ty = self._infer_with(c.body, *ps)
508 | else:
509 | body = self._check_with(c.body, ty, *ps)
510 | cases[ctor.name.id] = ir.Case(ctor.name, ps, body)
511 | for c in [c for i, c in ctors.items() if i not in cases]:
512 | self._exhaust(n.loc, c, data, arg_ty)
513 | return ir.Match(arg, cases), ty
514 | assert isinstance(n, Type)
515 | return ir.Type(), ir.Type()
516 |
517 | def _inliner(self):
518 | return ir.Inliner(self.holes, self.globals)
519 |
520 | def _eq(self, got: ir.IR, want: ir.IR):
521 | return ir.Converter(self.holes, self.globals).eq(got, want)
522 |
523 | def _check_with(self, n: Node, typ: ir.IR, *ps: Param[ir.IR]):
524 | self.locals.update({p.name.id: p for p in ps})
525 | ret = self.check(n, typ)
526 | [self.locals.pop(p.name.id, None) for p in ps]
527 | return ret
528 |
529 | def _infer_with(self, n: Node, *ps: Param[ir.IR]):
530 | self.locals.update({p.name.id: p for p in ps})
531 | v, ty = self.infer(n)
532 | [self.locals.pop(p.name.id, None) for p in ps]
533 | return v, ty
534 |
535 | def _insert_hole(self, loc: int, is_user: bool, typ: ir.IR):
536 | i = fresh()
537 | self.holes[i] = ir.Hole(loc, is_user, self.locals.copy(), ir.Answer(typ))
538 | return ir.Placeholder(i, is_user)
539 |
540 | def _ctor_return_type(self, loc: int, c: Ctor[ir.IR], d: Data[ir.IR]):
541 | _, ty = ir.from_ctor(c, d)
542 | while isinstance(ty, ir.FnType):
543 | p = ty.param
544 | x = (
545 | self._insert_hole(loc, False, p.type)
546 | if p.is_implicit
547 | else ir.Ref(p.name)
548 | )
549 | ty = self._inliner().run_with(ty.ret, (p.name, x))
550 | return ty
551 |
552 | def _exhaust(self, loc: int, c: Ctor[ir.IR], d: Data[ir.IR], want: ir.IR):
553 | with ir.dirty_holes(self.holes):
554 | if self._eq(self._ctor_return_type(loc, c, d), want):
555 | raise CaseMissError(c.name.text, loc)
556 |
557 |
558 | def _is_solved_class(ty: ir.IR):
559 | if isinstance(ty, ir.Class):
560 | return all(not isinstance(a, ir.Ref) for a in ty.args)
561 | return False
562 |
563 |
564 | def _can_insert_placeholders(ty: ir.IR):
565 | return not isinstance(ty, ir.FnType) or not ty.param.is_implicit
566 |
567 |
568 | def _with_placeholders(f: Node, f_ty: ir.IR, implicit: str | bool) -> Node | None:
569 | if not isinstance(f_ty, ir.FnType):
570 | return None
571 |
572 | if isinstance(implicit, bool):
573 | if not f_ty.param.is_implicit:
574 | return None
575 | return _call_placeholder(f) if not implicit else None
576 |
577 | pending = 0
578 | while True:
579 | if not isinstance(f_ty, ir.FnType) or not f_ty.param.is_implicit:
580 | raise UndefinedVariableError(implicit, f.loc)
581 | if f_ty.param.name.text == implicit:
582 | break
583 | pending += 1
584 | f_ty = f_ty.ret
585 |
586 | if not pending:
587 | return None
588 |
589 | for _ in range(pending):
590 | f = _call_placeholder(f)
591 | return f
592 |
593 |
594 | def _call_placeholder(f: Node):
595 | return Call(f.loc, f, Placeholder(f.loc, False), True)
596 |
597 |
598 | check_string = lambda s, md=False: s | Parser(md) | NameResolver() | TypeChecker()
599 |
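`_with_placeholders` above decides how many `_` arguments to insert in front of a call. The sketch below replays that counting on a simplified representation, and it is an illustration of the logic rather than TinyLean's code. Several names and behaviours are hypothetical stand-ins:

- A function type is a list of `FnParam(name, is_implicit)` instead of an `ir.FnType` chain.
- The result is a count instead of a rewritten `Call`, with 0 playing the role of `None`.
- `NameError` stands in for `UndefinedVariableError`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class FnParam:
    """Hypothetical stand-in for one parameter in an `ir.FnType` chain."""

    name: str
    is_implicit: bool


def placeholders_needed(params: list[FnParam], implicit: str | bool) -> int:
    """Count the `_` arguments to insert before a call.

    Mirrors the branches of `_with_placeholders`; 0 plays the role of `None`.
    """
    if not params:
        return 0  # the callee's type is not a function type
    if isinstance(implicit, bool):
        # Plain application `f x`: fill the leading implicit slot, unless this
        # call is itself an inserted placeholder call (implicit=True).
        return 1 if params[0].is_implicit and not implicit else 0
    # Named application `f (U := x)`: skip implicit params until `U` is found.
    pending = 0
    for p in params:
        if not p.is_implicit:
            break  # left the implicit prefix without finding the name
        if p.name == implicit:
            return pending
        pending += 1
    raise NameError(f"undefined implicit argument: {implicit}")


id_ty = [FnParam("T", True), FnParam("a", False)]  # {T: Type} → (a: T) → T
f_ty = [FnParam("T", True), FnParam("U", True), FnParam("a", False)]

print(placeholders_needed(id_ty, False))  # 1: `id Type` becomes `id _ Type`
print(placeholders_needed(id_ty, "T"))  # 0: `id (T := Type)` needs no filler
print(placeholders_needed(f_ty, "U"))  # 1: a `_` for T precedes `(U := Type)`
```

For plain application the real checker inserts one `_` per step. As far as this excerpt shows, `infer` then re-infers the rewritten `Call`, so a longer implicit prefix is filled one placeholder at a time.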
--------------------------------------------------------------------------------
/src/TinyLean/grammar.py:
--------------------------------------------------------------------------------
1 | from pyparsing import *
2 |
3 | ParserElement.enable_packrat()
4 |
5 | COMMENT = Regex(r"/\-(?:[^-]|\-(?!/))*\-\/").set_name("comment")
6 |
7 | IDENT = unicode_set.identifier()
8 |
9 | DEF, EXAMPLE, IND, WHERE, OPEN, TYPE, NOMATCH, MATCH, WITH, UNDER, CLASS, INST = map(
10 | lambda w: Suppress(Keyword(w)),
11 | "def example inductive where open Type nomatch match with _ class instance".split(),
12 | )
13 |
14 | ASSIGN, ARROW, FUN, TO = map(
15 | lambda s: Suppress(s[0]) | Suppress(s[1:]), "≔:= →-> λfun ↦=>".split()
16 | )
17 |
18 | LPAREN, RPAREN, LBRACE, RBRACE, LBRACKET, RBRACKET, COLON, BAR, NEWLINE = map(
19 | Suppress, "(){}[]:|\n"
20 | )
21 | INLINE_WHITE = Opt(Suppress(White(" \t\r"))).set_name("inline_whitespace")
22 |
23 | forwards = lambda names: map(lambda n: Forward().set_name(n), names.split())
24 |
25 | expr, atom, fn_type, fn, match, nomatch, call, p_expr, type_, ph, ref = forwards(
26 | "expr atom fn_type fn match nomatch call paren_expr type placeholder ref"
27 | )
28 | case, i_arg, e_arg = forwards("case implicit_arg explicit_arg")
29 |
30 | infix_op = lambda s: (one_of(s), 2, OpAssoc.LEFT)
31 | expr <<= infix_notation(atom, [infix_op("* /"), infix_op("+ -")])
32 | atom <<= fn_type | fn | match | nomatch | call | p_expr | type_ | ph | ref
33 |
34 | name = Group(IDENT).set_name("name")
35 | i_param = (LBRACE + name + COLON + expr + RBRACE).set_name("implicit_param")
36 | e_param = (LPAREN + name + COLON + expr + RPAREN).set_name("explicit_param")
37 | c_param = (LBRACKET + name + COLON + expr + RBRACKET).set_name("class_param")
38 | param = (i_param | e_param | c_param).set_name("param")
39 | fn_type <<= param + ARROW + expr
40 | fn <<= FUN - Group(OneOrMore(name)) + TO + expr
41 | match <<= MATCH - (type_ | ref | p_expr) + WITH + Group(OneOrMore(case))
42 | case <<= BAR - ref + Group(ZeroOrMore(name)) + TO + expr
43 | nomatch <<= (NOMATCH - INLINE_WHITE + e_arg).leave_whitespace()
44 | callee = ref | p_expr
45 | call <<= (callee + OneOrMore(INLINE_WHITE + (i_arg | e_arg))).leave_whitespace()
46 | i_arg <<= LPAREN + IDENT + ASSIGN + expr + RPAREN
47 | e_arg <<= (type_ | ph | ref | p_expr).leave_whitespace()
48 | p_expr <<= LPAREN + expr + RPAREN
49 | type_ <<= Group(TYPE)
50 | ph <<= Group(UNDER)
51 | ref <<= Group(name)
52 |
53 | return_type = Opt(COLON + expr)
54 | params = Group(ZeroOrMore(param))
55 | def_ = (DEF - ref + params + return_type + ASSIGN + expr).set_name("definition")
56 | example = (EXAMPLE - params + return_type + ASSIGN + expr).set_name("example")
57 | type_arg = (LPAREN + ref + ASSIGN + expr + RPAREN).set_name("type_arg")
58 | ctor = (BAR - ref + params + Group(ZeroOrMore(type_arg))).set_name("constructor")
59 | data = (IND - ref + params + WHERE + Group(ZeroOrMore(ctor)) + OPEN + IDENT).set_name(
60 | "datatype"
61 | )
62 | c_field = (name + COLON + expr).set_name("class_field")
63 | class_ = (
64 | CLASS - ref + params + WHERE + Group(ZeroOrMore(c_field)) + OPEN + IDENT
65 | ).set_name("class")
66 | i_field = (ref + ASSIGN + expr).set_name("instance_field")
67 | inst = (INST - COLON + expr + WHERE + Group(ZeroOrMore(i_field))).set_name("instance")
68 | declaration = (def_ | example | data | class_ | inst).set_name("declaration")
69 |
70 | program = ZeroOrMore(declaration).ignore(COMMENT).set_name("program")
71 |
72 | line_exact = lambda w: Suppress(AtLineStart(w) + LineEnd())
73 | markdown = line_exact("```lean") + program + line_exact("```")
74 |
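`COMMENT` accepts a lone `-` inside a comment through `\-(?!/)`, but the first `-/` always closes it, so block comments do not nest. The check below runs the same pattern string through Python's standard `re` module, which is enough to observe this without pyparsing. `leading_comment` is a hypothetical helper.

```python
import re
from typing import Optional

# The pattern string of `COMMENT` above, compiled with the standard `re` module.
COMMENT_RE = re.compile(r"/\-(?:[^-]|\-(?!/))*\-\/")


def leading_comment(text: str) -> Optional[str]:
    """Return the `/- ... -/` block comment at the start of `text`, if any."""
    m = COMMENT_RE.match(text)
    return m.group(0) if m else None


print(leading_comment("/- a - b -/ def x := Type"))  # /- a - b -/
print(leading_comment("/- outer /- inner -/ rest -/"))  # /- outer /- inner -/
print(leading_comment("def x := Type"))  # None
```

Because `[^-]` also matches newlines, a single comment may span several lines.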
--------------------------------------------------------------------------------
/src/TinyLean/ir.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from dataclasses import dataclass, field
3 | from functools import reduce as _r
4 | from typing import Optional, cast as _c, OrderedDict
5 |
6 | from . import (
7 | Name,
8 | Param,
9 | Def,
10 | Data as DataDecl,
11 | Ctor as CtorDecl,
12 | Sig,
13 | Decl,
14 | Class as ClassDecl,
15 | Field as FieldDecl,
16 | Instance,
17 | )
18 |
19 |
20 | @dataclass(frozen=True)
21 | class IR: ...
22 |
23 |
24 | @dataclass(frozen=True)
25 | class Type(IR):
26 | def __str__(self):
27 | return "Type"
28 |
29 |
30 | @dataclass(frozen=True)
31 | class Ref(IR):
32 | name: Name
33 |
34 | def __str__(self):
35 | return str(self.name)
36 |
37 |
38 | @dataclass(frozen=True)
39 | class FnType(IR):
40 | param: Param[IR]
41 | ret: IR
42 |
43 | def __str__(self):
44 | return f"{self.param} → {self.ret}"
45 |
46 |
47 | @dataclass(frozen=True)
48 | class Fn(IR):
49 | param: Param[IR]
50 | body: IR
51 |
52 | def __str__(self):
53 | return f"λ {self.param} ↦ {self.body}"
54 |
55 |
56 | @dataclass(frozen=True)
57 | class Call(IR):
58 | callee: IR
59 | arg: IR
60 |
61 | def __str__(self):
62 | return f"({self.callee} {self.arg})"
63 |
64 |
65 | @dataclass(frozen=True)
66 | class Placeholder(IR):
67 | id: int
68 | is_user: bool
69 |
70 | def __str__(self):
71 | t = "u" if self.is_user else "m"
72 | return f"?{t}.{self.id}"
73 |
74 |
75 | @dataclass(frozen=True)
76 | class Data(IR):
77 | name: Name
78 | args: list[IR]
79 |
80 | def __str__(self):
81 | s = " ".join(str(x) for x in [self.name, *self.args])
82 | return f"({s})" if len(self.args) else s
83 |
84 |
85 | @dataclass(frozen=True)
86 | class Ctor(IR):
87 | ty_name: Name
88 | name: Name
89 | args: list[IR]
90 |
91 | def __str__(self):
92 | n = f"{self.ty_name}.{self.name}"
93 | s = " ".join(str(x) for x in [n, *self.args])
94 | return f"({s})" if len(self.args) else s
95 |
96 |
97 | @dataclass(frozen=True)
98 | class Nomatch(IR):
99 | def __str__(self):
100 | return "nomatch"
101 |
102 |
103 | @dataclass(frozen=True)
104 | class Case(IR):
105 | ctor: Name
106 | params: list[Param[IR]]
107 | body: IR
108 |
109 | def __str__(self):
110 | s = " ".join([str(self.ctor), *map(str, self.params)])
111 | return f"| {s} ↦ {self.body}"
112 |
113 |
114 | @dataclass(frozen=True)
115 | class Match(IR):
116 | arg: IR
117 | cases: dict[int, Case]
118 |
119 | def __str__(self):
120 | cs = " ".join(map(str, self.cases.values()))
121 | return f"match {self.arg} with {cs}"
122 |
123 |
124 | @dataclass(frozen=True)
125 | class Recur(IR):
126 | name: Name
127 |
128 | def __str__(self):
129 | return str(self.name)
130 |
131 |
132 | @dataclass(frozen=True)
133 | class Class(IR):
134 | name: Name
135 | args: list[IR]
136 |
137 | def __str__(self):
138 | s = " ".join(str(x) for x in [self.name, *self.args])
139 | return f"({s})" if len(self.args) else s
140 |
141 | def is_unsolved(self):
142 | return any(isinstance(a, Ref) for a in self.args)
143 |
144 |
145 | @dataclass(frozen=True)
146 | class Field(IR):
147 | name: Name
148 | type: IR
149 |
150 | def __str__(self):
151 | return str(self.name)
152 |
153 |
154 | @dataclass(frozen=True)
155 | class Renamer:
156 | locals: dict[int, int] = field(default_factory=dict)
157 |
158 | def run(self, v: IR) -> IR:
159 | if isinstance(v, Ref):
160 | if v.name.id in self.locals:
161 | return Ref(Name(v.name.text, self.locals[v.name.id]))
162 | return v
163 | if isinstance(v, Call):
164 | return Call(self.run(v.callee), self.run(v.arg))
165 | if isinstance(v, Fn):
166 | return Fn(self._param(v.param), self.run(v.body))
167 | if isinstance(v, FnType):
168 | return FnType(self._param(v.param), self.run(v.ret))
169 | if isinstance(v, Data):
170 | return Data(v.name, [self.run(x) for x in v.args])
171 | if isinstance(v, Ctor):
172 | return Ctor(v.ty_name, v.name, [self.run(x) for x in v.args])
173 | if isinstance(v, Match):
174 | arg = self.run(v.arg)
175 | cases = {
176 | i: Case(c.ctor, [self._param(p) for p in c.params], self.run(c.body))
177 | for i, c in v.cases.items()
178 | }
179 | return Match(arg, cases)
180 | if isinstance(v, Class):
181 | return Class(v.name, [self.run(x) for x in v.args])
182 | if isinstance(v, Field):
183 | return Field(v.name, self.run(v.type))
184 | assert any(isinstance(v, c) for c in (Type, Placeholder, Nomatch, Recur))
185 | return v
186 |
187 | def _param(self, p: Param[IR]):
188 | name = Name(p.name.text)
189 | self.locals[p.name.id] = name.id
190 | return Param(name, self.run(p.type), p.is_implicit, p.is_class)
191 |
192 |
193 | _rn = lambda v: Renamer().run(v)
194 |
195 |
196 | def _to(p: list[Param[IR]], v: IR, t=False):
197 | return _r(lambda a, q: _c(IR, FnType(q, a) if t else Fn(q, a)), reversed(p), v)
198 |
199 |
200 | def from_def(d: Def[IR]):
201 | return _rn(_to(d.params, d.body)), _rn(_to(d.params, d.ret, True))
202 |
203 |
204 | def from_sig(s: Sig[IR]):
205 | return Recur(s.name), _rn(_to(s.params, s.ret, True))
206 |
207 |
208 | def from_data(d: DataDecl[IR]):
209 | args = [Ref(p.name) for p in d.params]
210 | return _rn(_to(d.params, Data(d.name, args))), _rn(_to(d.params, Type(), True))
211 |
212 |
213 | def from_ctor(c: CtorDecl[IR], d: DataDecl[IR]):
214 | adhoc = {x.name.id: v for x, v in _c(list[tuple[Ref, IR]], c.ty_args)}
215 | miss = [Param(p.name, p.type, True) for p in d.params if p.name.id not in adhoc]
216 |
217 | v = _to(c.params, Ctor(d.name, c.name, [Ref(p.name) for p in c.params]))
218 | v = _to(miss, v)
219 |
220 | ty_args = [adhoc.get(p.name.id, Ref(p.name)) for p in d.params]
221 | ty = _to(c.params, Data(d.name, ty_args), True)
222 | ty = _to(miss, ty, True)
223 |
224 | return _rn(v), _rn(ty)
225 |
226 |
227 | def from_class(c: ClassDecl[IR]):
228 | args = [Ref(p.name) for p in c.params]
229 | return _rn(_to(c.params, Class(c.name, args))), _rn(_to(c.params, Type(), True))
230 |
231 |
232 | def from_field(f: FieldDecl[IR], c: ClassDecl[IR], has_c_param=True):
233 | t = Class(c.name, [Ref(p.name) for p in c.params])
234 | ps = [*c.params, Param(Name("inst"), t, True, True)] if has_c_param else c.params
235 | return _rn(_to(ps, Field(f.name, t))), _rn(_to(ps, f.type, True))
236 |
237 |
238 | @dataclass
239 | class Answer:
240 | type: IR
241 | value: Optional[IR] = None
242 |
243 | def is_unsolved(self):
244 | return self.value is None
245 |
246 |
247 | @dataclass(frozen=True)
248 | class Hole:
249 | loc: int
250 | is_user: bool
251 | locals: dict[int, Param[IR]]
252 | answer: Answer
253 |
254 |
255 | @contextmanager
256 | def dirty_holes(holes: OrderedDict[int, Hole]):
257 | l = len(holes)
258 | try:
259 | yield
260 | finally:
261 | [holes.popitem() for _ in range(len(holes) - l)]
262 |
263 |
264 | class NoInstanceError(Exception): ...
265 |
266 |
267 | @dataclass
268 | class Inliner:
269 | holes: OrderedDict[int, Hole]
270 | globals: dict[int, Decl]
271 | can_recurse: bool = True
272 | env: dict[int, IR] = field(default_factory=dict)
273 |
274 | def run(self, v: IR) -> IR:
275 | if isinstance(v, Ref):
276 | return self.run(_rn(self.env[v.name.id])) if v.name.id in self.env else v
277 | if isinstance(v, Call):
278 | f = self.run(v.callee)
279 | x = self.run(v.arg)
280 | if isinstance(f, Fn):
281 | return self.run_with(f.body, (f.param.name, x))
282 | return Call(f, x)
283 | if isinstance(v, Fn):
284 | return Fn(self._param(v.param), self.run(v.body))
285 | if isinstance(v, FnType):
286 | return FnType(self._param(v.param), self.run(v.ret))
287 | if isinstance(v, Placeholder):
288 | h = self.holes[v.id]
289 | h.answer.type = self.run(h.answer.type)
290 | return v if h.answer.is_unsolved() else self.run(h.answer.value)
291 | if isinstance(v, Ctor):
292 | return Ctor(v.ty_name, v.name, [self.run(x) for x in v.args])
293 | if isinstance(v, Data):
294 | return Data(v.name, [self.run(x) for x in v.args])
295 | if isinstance(v, Match):
296 | arg = self.run(v.arg)
297 | can_recurse = self.can_recurse
298 | self.can_recurse = False
299 | cases = {
300 | i: Case(c.ctor, [self._param(p) for p in c.params], self.run(c.body))
301 | for i, c in v.cases.items()
302 | }
303 | self.can_recurse = can_recurse
304 | if not isinstance(arg, Ctor):
305 | return Match(arg, cases)
306 | c = cases[arg.name.id]
307 | env = [(x.name, v) for x, v in zip(c.params, arg.args)]
308 | return self.run_with(c.body, *env)
309 | if isinstance(v, Recur):
310 | if self.can_recurse:
311 | d = self.globals[v.name.id]
312 | if isinstance(d, Def):
313 | return from_def(d)[0]
314 | assert isinstance(d, Sig)
315 | return v
316 | if isinstance(v, Class):
317 | return Class(v.name, [self.run(t) for t in v.args])
318 | if isinstance(v, Field):
319 | c = _c(Class, self.run(v.type))
320 | if c.is_unsolved():
321 | return Field(v.name, c)
322 | i = self._resolve_instance(c)
323 | val = next(x for n, x in i.fields if _c(Ref, n).name.id == v.name.id)
324 | return self.run(val)
325 | assert isinstance(v, Type) or isinstance(v, Nomatch)
326 | return v
327 |
328 | def run_with(self, x: IR, *env: tuple[Name, IR]):
329 | self.env.update({n.id: v for n, v in env})
330 | return self.run(x)
331 |
332 | def apply(self, f: IR, *args: IR):
333 | ret = f
334 | for x in args:
335 | if isinstance(ret, Fn):
336 | ret = self.run_with(ret.body, (ret.param.name, x))
337 | else:
338 | ret = Call(ret, x)
339 | return ret
340 |
341 | def _param(self, param: Param[IR]):
342 | p = Param(param.name, self.run(param.type), param.is_implicit, param.is_class)
343 | if not p.is_class:
344 | return p
345 | ty = _c(Class, p.type)
346 | if not ty.is_unsolved() and not self._resolve_instance(ty):
347 | raise NoInstanceError(str(ty), self.globals[ty.name.id].loc)
348 | return p
349 |
350 | def _resolve_instance(self, c: Class) -> Optional[Instance[IR]]:
351 | cls = _c(ClassDecl, self.globals[c.name.id])
352 | for inst_id in cls.instances:
353 | i = _c(Instance, self.globals[inst_id])
354 | with dirty_holes(self.holes):
355 | if Converter(self.holes, self.globals).eq(c, i.type):
356 | return i
357 | return None
358 |
359 |
360 | @dataclass(frozen=True)
361 | class Converter:
362 | holes: OrderedDict[int, Hole]
363 | globals: dict[int, Decl]
364 |
365 | def eq(self, lhs: IR, rhs: IR):
366 | match lhs, rhs:
367 | case Placeholder() as x, y:
368 | return self._solve(x, y)
369 | case x, Placeholder() as y:
370 | return self._solve(y, x)
371 | case Ref(x), Ref(y):
372 | return x.id == y.id
373 | case Call(f, x), Call(g, y):
374 | return self.eq(f, g) and self.eq(x, y)
375 | case Fn(p, b), Fn(q, c):
376 | env = [(q.name, Ref(p.name))]
377 | return self.eq(b, Inliner(self.holes, self.globals).run_with(c, *env))
378 | case FnType(p, b), FnType(q, c):
379 | if not self.eq(p.type, q.type):
380 | return False
381 | env = [(q.name, Ref(p.name))]
382 | return self.eq(b, Inliner(self.holes, self.globals).run_with(c, *env))
383 | case Data(x, xs), Data(y, ys):
384 | return x.id == y.id and self._args(xs, ys)
385 | case Ctor(t, x, xs), Ctor(u, y, ys):
386 | return t.id == u.id and x.id == y.id and self._args(xs, ys)
387 | case Type(), Type():
388 | return True
389 | case Class(x, xs), Class(y, ys):
390 | return x.id == y.id and self._args(xs, ys)
391 |
392 | # FIXME: Following cases not seen in tests yet:
393 | assert not (isinstance(lhs, Placeholder) and isinstance(rhs, Placeholder))
394 | assert not (isinstance(lhs, Match) and isinstance(rhs, Match))
395 | assert not (isinstance(lhs, Field) and isinstance(rhs, Field))
396 |
397 | return False
398 |
399 | def _solve(self, p: Placeholder, answer: IR):
400 | h = self.holes[p.id]
401 | if not h.answer.is_unsolved():
402 | return self.eq(h.answer.value, answer)
403 | h.answer.value = answer
404 |
405 | if isinstance(answer, Ref):
406 | for param in h.locals.values():
407 | if param.name.id == answer.name.id:
408 | assert self.eq(param.type, h.answer.type) # FIXME: will fail here?
409 |
410 | return True
411 |
412 | def _args(self, xs: list[IR], ys: list[IR]):
413 | assert len(xs) == len(ys)
414 | return all(self.eq(x, y) for x, y in zip(xs, ys))
415 |
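`dirty_holes` lets the checker attempt a unification speculatively. It is used in `infer` for `Match`, in `_exhaust` and in `_resolve_instance`. On entry it records `len(holes)`, and on exit it calls `popitem()` for every entry added since, relying on the insertion order of `OrderedDict`.

As written, it only discards holes created inside the block. An answer assigned to a hole that already existed before the block is kept.

The sketch below reproduces that behaviour. It uses plain dicts in place of `ir.Hole`, and `new_hole` and `_ids` are hypothetical helpers.

```python
from collections import OrderedDict
from contextlib import contextmanager
from itertools import count

_ids = count()  # hypothetical stand-in for TinyLean's `fresh()`


@contextmanager
def dirty_holes(holes: OrderedDict):
    """Same shape as `ir.dirty_holes`: drop holes created inside the block."""
    before = len(holes)
    try:
        yield
    finally:
        for _ in range(len(holes) - before):
            holes.popitem()  # last=True by default: newest hole first


def new_hole(holes: OrderedDict) -> int:
    i = next(_ids)
    holes[i] = {"value": None}  # a dict stands in for `ir.Hole`/`ir.Answer`
    return i


holes: OrderedDict = OrderedDict()
kept = new_hole(holes)
with dirty_holes(holes):
    tmp = new_hole(holes)  # e.g. a hole for a constructor's implicit argument
    holes[tmp]["value"] = "Type"
    holes[kept]["value"] = "N"  # solving a hole that existed before the block
print(list(holes))  # [0]
print(holes[kept])  # {'value': 'N'}
```

Because the cleanup sits in `finally`, the temporary holes are also dropped when the speculative check raises, for example a `CaseMissError` from inside `_exhaust`.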
--------------------------------------------------------------------------------
/src/TinyLean/tests/__init__.py:
--------------------------------------------------------------------------------
1 | from pyparsing import ParserElement
2 |
3 | from .. import ast, grammar
4 |
5 |
6 | def parse(g: ParserElement, text: str):
7 | return g.parse_string(text, parse_all=True)
8 |
9 |
10 | def resolve(s: str):
11 | return s | ast.Parser() | ast.NameResolver()
12 |
13 |
14 | def resolve_md(s: str):
15 | return s | ast.Parser(True) | ast.NameResolver()
16 |
17 |
18 | def resolve_expr(s: str):
19 | return ast.NameResolver().expr(parse(grammar.expr, s)[0])
20 |
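The helpers above build pipelines such as `s | ast.Parser() | ast.NameResolver()`, in the same style as `check_string` in ast.py. This relies on Python's reflected operators: when the left operand has no `__or__` that handles the right one, Python calls the right operand's `__ror__`.

`TypeChecker.__ror__` in ast.py is one such stage, and the pipelines suggest that the parser and the resolver follow the same convention. Below is a minimal sketch with two hypothetical stages, `Tokenize` and `DropComments`.

```python
class Tokenize:
    """Hypothetical first stage: split source text into tokens."""

    def __ror__(self, text: str) -> list[str]:
        return text.split()


class DropComments:
    """Hypothetical second stage: drop tokens that start with '--'."""

    def __ror__(self, tokens: list[str]) -> list[str]:
        return [t for t in tokens if not t.startswith("--")]


# `str` defines no `__or__`, so `"..." | Tokenize()` falls back to
# `Tokenize.__ror__`; the resulting `list` has no `__or__` either,
# so the next `|` hands it to `DropComments.__ror__`.
result = "def a := Type --note" | Tokenize() | DropComments()
print(result)  # ['def', 'a', ':=', 'Type']
```

Since each stage is an ordinary object, a test can stop the pipeline at any point, as `resolve` does before type checking.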
--------------------------------------------------------------------------------
/src/TinyLean/tests/onboard.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anqur/TinyLean/312b3eacfe8e0489971d871eaca886ee1d737246/src/TinyLean/tests/onboard.py
--------------------------------------------------------------------------------
/src/TinyLean/tests/test_checker.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from . import resolve_expr
4 | from .. import ast, Name, Param, ir, Data, Example, Def, Class, Instance
5 |
6 | check_expr = lambda s, t: ast.TypeChecker().check(resolve_expr(s), t)
7 | infer_expr = lambda s: ast.TypeChecker().infer(resolve_expr(s))
8 |
9 |
10 | class TestTypeChecker(TestCase):
11 | def test_check_expr_type(self):
12 | check_expr("Type", ir.Type())
13 | check_expr("{a: Type} -> (b: Type) -> a", ir.Type())
14 |
15 | def test_check_expr_type_failed(self):
16 | with self.assertRaises(ast.TypeMismatchError) as e:
17 | check_expr("fun a => a", ir.Type())
18 | want, got, loc = e.exception.args
19 | self.assertEqual(0, loc)
20 | self.assertEqual("Type", want)
21 | self.assertEqual("function", got)
22 |
23 | def test_check_expr_function(self):
24 | check_expr(
25 | "fun a => a",
26 | ir.FnType(Param(Name("a"), ir.Type(), False), ir.Type()),
27 | )
28 |
29 | def test_check_expr_on_infer(self):
30 | check_expr("Type", ir.Type())
31 |
32 | def test_check_expr_on_infer_failed(self):
33 | with self.assertRaises(ast.TypeMismatchError) as e:
34 | check_expr("(a: Type) -> a", ir.Ref(Name("a")))
35 | want, got, loc = e.exception.args
36 | self.assertEqual(0, loc)
37 | self.assertEqual("a", want)
38 | self.assertEqual("Type", got)
39 |
40 | def test_infer_expr_type(self):
41 | v, ty = infer_expr("Type")
42 | assert isinstance(v, ir.Type)
43 | assert isinstance(ty, ir.Type)
44 |
45 | def test_infer_expr_call_failed(self):
46 | with self.assertRaises(ast.TypeMismatchError) as e:
47 | infer_expr("(Type) Type")
48 | want, got, loc = e.exception.args
49 | self.assertEqual(1, loc)
50 | self.assertEqual("function", want)
51 | self.assertEqual("Type", got)
52 |
53 | def test_infer_expr_function_type(self):
54 | v, ty = infer_expr("{a: Type} -> a")
55 | assert isinstance(v, ir.FnType)
56 | self.assertEqual("{a: Type} → a", str(v))
57 | assert isinstance(ty, ir.Type)
58 |
59 | def test_check_program(self):
60 | ast.check_string("def a: Type := Type")
61 | ast.check_string("def f (a: Type): Type := a")
62 | ast.check_string("def f: (_: Type) -> Type := fun a => a")
63 | ast.check_string("def id (T: Type) (a: T): T := a")
64 |
65 | def test_check_program_failed(self):
66 | with self.assertRaises(ast.TypeMismatchError) as e:
67 | ast.check_string("def id (a: Type): a := Type")
68 | want, got, loc = e.exception.args
69 | self.assertEqual(23, loc)
70 | self.assertEqual("a", want)
71 | self.assertEqual("Type", got)
72 |
73 | def test_check_program_call(self):
74 | ast.check_string(
75 | """
76 | def f0 (a: Type): Type := a
77 | def f1: Type := f0 Type
78 | def f2: f0 Type := Type
79 | """
80 | )
81 |
82 | def test_check_program_call_failed(self):
83 | with self.assertRaises(ast.TypeMismatchError) as e:
84 | ast.check_string(
85 | """
86 | def f0 (a: Type): Type := a
87 | def f1 (a: Type): Type := f0
88 | """
89 | )
90 | want, got, loc = e.exception.args
91 | self.assertEqual(87, loc)
92 | self.assertEqual("Type", want)
93 | self.assertEqual("(a: Type) → Type", got)
94 |
95 | def test_check_program_placeholder(self):
96 | ast.check_string(
97 | """
98 | def a := Type
99 | def b: Type := a
100 | """
101 | )
102 |
103 | def test_check_program_placeholder_locals(self):
104 | ast.check_string("def f (T: Type) (a: T) := a")
105 |
106 | def test_check_program_placeholder_unsolved(self):
107 | with self.assertRaises(ast.UnsolvedPlaceholderError) as e:
108 | ast.check_string("def a: Type := _")
109 | name, ctx, ty, loc = e.exception.args
110 | self.assertTrue(name.startswith("?u"))
111 | self.assertEqual(0, len(ctx))
112 | assert isinstance(ty, ir.Type)
113 | self.assertEqual(15, loc)
114 |
115 | def test_check_program_call_implicit_arg(self):
116 | _, _, example = ast.check_string(
117 | """
118 | def id {T: Type} (a: T): T := a
119 | def f := id (T := Type) Type
120 | example := f
121 | """
122 | )
123 | assert isinstance(example.body, ir.Type)
124 |
125 | def test_check_program_call_implicit_arg_failed(self):
126 | with self.assertRaises(ast.UndefinedVariableError) as e:
127 | ast.check_string(
128 | """
129 | def id {T: Type} (a: T): T := a
130 | def f := id (U := Type) Type
131 | """
132 | )
133 | name, loc = e.exception.args
134 | self.assertEqual("U", name)
135 | self.assertEqual(74, loc)
136 |
137 | def test_check_program_call_implicit_arg_long(self):
138 | ast.check_string(
139 | """
140 | def f {T: Type} {U: Type} (a: U): Type := T
141 | def g: f (U := Type) Type := Type
142 | """
143 | )
144 |
145 | def test_check_program_call_implicit(self):
146 | _, _, example = ast.check_string(
147 | """
148 | def id {T: Type} (a: T): T := a
149 | def f := id Type
150 | example := f
151 | """
152 | )
153 | assert isinstance(example.body, ir.Type)
154 |
155 | def test_check_program_call_no_explicit_failed(self):
156 | with self.assertRaises(ast.UnsolvedPlaceholderError) as e:
157 | ast.check_string(
158 | """
159 | def f {T: Type}: Type := T
160 | def g: Type := f
161 | """
162 | )
163 | name, ctx, ty, loc = e.exception.args
164 | self.assertTrue(name.startswith("?m"))
165 | self.assertEqual(1, len(ctx))
166 | assert isinstance(ty, ir.Type)
167 | self.assertEqual(75, loc)
168 |
169 | def test_check_program_call_mixed_implicit(self):
170 | ast.check_string(
171 | """
172 | def f (T: Type) {U: Type}: Type := U
173 | /- Cannot insert placeholders for an implicit function type. -/
174 | example: {U: Type} -> Type := f Type
175 | """
176 | )
177 |
178 | def test_check_program_datatype_nat(self):
179 | x, _, _, _2 = ast.check_string(
180 | """
181 | inductive N where
182 | | Z
183 | | S (n: N)
184 | open N
185 |
186 | example: N := Z
187 | example: N := S Z
188 | example: N := S (S Z)
189 | """
190 | )
191 | assert isinstance(x, Data)
192 | self.assertEqual(2, len(x.ctors))
193 |
194 | n_v, n_ty = ir.from_data(x)
195 | self.assertEqual("N", str(n_v))
196 | self.assertEqual("Type", str(n_ty))
197 |
198 | z_v, z_ty = ir.from_ctor(x.ctors[0], x)
199 | self.assertEqual("N.Z", str(z_v))
200 | self.assertEqual("N", str(z_ty))
201 |
202 | s_v, s_ty = ir.from_ctor(x.ctors[1], x)
203 | self.assertEqual("λ (n: N) ↦ (N.S n)", str(s_v))
204 | self.assertEqual("(n: N) → N", str(s_ty))
205 |
206 | assert isinstance(_2, Example)
207 | self.assertEqual("(N.S (N.S N.Z))", str(_2.body))
208 |
209 | def test_check_program_datatype_nat_failed(self):
210 | with self.assertRaises(ast.TypeMismatchError) as e:
211 | ast.check_string(
212 | """
213 | inductive N where
214 | | Z
215 | | S (n: N)
216 | open N
217 |
218 | example: Type := Z
219 | """
220 | )
221 | want, got, loc = e.exception.args
222 | self.assertEqual("Type", want)
223 | self.assertEqual("N", got)
224 | self.assertEqual(139, loc)
225 |
226 | def test_check_program_datatype_maybe(self):
227 | x = ast.check_string(
228 | """
229 | inductive Maybe (A: Type) where
230 | | Nothing
231 | | Just (a: A)
232 | open Maybe
233 |
234 | example: Maybe Type := Nothing
235 | example: Maybe Type := Just Type
236 | """
237 | )[0]
238 | assert isinstance(x, Data)
239 | self.assertEqual(2, len(x.ctors))
240 |
241 | maybe_v, maybe_ty = ir.from_data(x)
242 | self.assertEqual("λ (A: Type) ↦ (Maybe A)", str(maybe_v))
243 | self.assertEqual("(A: Type) → Type", str(maybe_ty))
244 |
245 | nothing_v, nothing_ty = ir.from_ctor(x.ctors[0], x)
246 | self.assertEqual("λ {A: Type} ↦ Maybe.Nothing", str(nothing_v))
247 | self.assertEqual("{A: Type} → (Maybe A)", str(nothing_ty))
248 |
249 | just_v, just_ty = ir.from_ctor(x.ctors[1], x)
250 | self.assertEqual("λ {A: Type} ↦ λ (a: A) ↦ (Maybe.Just a)", str(just_v))
251 | self.assertEqual("{A: Type} → (a: A) → (Maybe A)", str(just_ty))
252 |
253 | assert isinstance(just_v, ir.Fn)
254 | assert isinstance(just_v.body, ir.Fn)
255 | assert isinstance(just_v.body.body, ir.Ctor)
256 | just_arg = just_v.body.body.args[0]
257 | assert isinstance(just_arg, ir.Ref)
258 | self.assertEqual(just_v.body.param.name.id, just_arg.name.id)
259 |
260 | assert isinstance(just_ty, ir.FnType)
261 | assert isinstance(just_ty.ret, ir.FnType)
262 | assert isinstance(just_ty.ret.ret, ir.Data)
263 | just_ty_arg = just_ty.ret.ret.args[0]
264 | assert isinstance(just_ty_arg, ir.Ref)
265 | self.assertEqual(just_ty.param.name.id, just_ty_arg.name.id)
266 |
267 | def test_check_program_datatype_maybe_unsolved(self):
268 | with self.assertRaises(ast.UnsolvedPlaceholderError) as e:
269 | ast.check_string(
270 | """
271 | inductive Maybe (A: Type) where
272 | | Nothing
273 | | Just (a: A)
274 | open Maybe
275 |
276 | example := Nothing
277 | """
278 | )
279 | name, ctx, ty, loc = e.exception.args
280 | self.assertTrue(name.startswith("?m"))
281 | self.assertEqual(1, len(ctx))
282 | assert isinstance(ty, ir.Type)
283 | self.assertEqual(160, loc)
284 |
285 | def test_check_program_datatype_maybe_failed(self):
286 | with self.assertRaises(ast.TypeMismatchError) as e:
287 | ast.check_string(
288 | """
289 | inductive Maybe (A: Type) where
290 | | Nothing
291 | | Just (a: A)
292 | open Maybe
293 |
294 | inductive A where | AA open A
295 | inductive B where | BB open B
296 | example: Maybe B := Just AA
297 | """
298 | )
299 | want, got, _ = e.exception.args
300 | self.assertEqual("(Maybe B)", want)
301 | self.assertEqual("(Maybe A)", got)
302 |
303 | def test_check_program_datatype_vec(self):
304 | x = ast.check_string(
305 | """
306 | inductive N where
307 | | Z
308 | | S (n: N)
309 | open N
310 |
311 | inductive Vec (A: Type) (n: N) where
312 | | Nil (n := Z)
313 | | Cons {m: N} (a: A) (v: Vec A m) (n := S m)
314 | open Vec
315 |
316 | /- This will emit some "dirty" placeholders. -/
317 | def v1: Vec N (S Z) := Cons Z Nil
318 | """
319 | )[1]
320 | assert isinstance(x, Data)
321 | self.assertEqual(2, len(x.ctors))
322 |
323 | vec_v, vec_ty = ir.from_data(x)
324 | self.assertEqual("λ (A: Type) ↦ λ (n: N) ↦ (Vec A n)", str(vec_v))
325 | self.assertEqual("(A: Type) → (n: N) → Type", str(vec_ty))
326 |
327 | nil_v, nil_ty = ir.from_ctor(x.ctors[0], x)
328 | self.assertEqual("λ {A: Type} ↦ Vec.Nil", str(nil_v))
329 | self.assertEqual("{A: Type} → (Vec A N.Z)", str(nil_ty))
330 |
331 | cons_v, cons_ty = ir.from_ctor(x.ctors[1], x)
332 | self.assertEqual(
333 | "λ {A: Type} ↦ λ {m: N} ↦ λ (a: A) ↦ λ (v: (Vec A m)) ↦ (Vec.Cons m a v)",
334 | str(cons_v),
335 | )
336 | self.assertEqual(
337 | "{A: Type} → {m: N} → (a: A) → (v: (Vec A m)) → (Vec A (N.S m))",
338 | str(cons_ty),
339 | )
340 |
341 | def test_check_program_ctor_eq(self):
342 | ast.check_string(
343 | """
344 | def Eq {T: Type} (a: T) (b: T): Type := (p: (v: T) -> Type) -> (pa: p a) -> p b
345 | def refl {T: Type} (a: T): Eq a a := fun p pa => pa
346 | inductive A where | AA open A
347 | example: Eq AA AA := refl AA
348 | """
349 | )
350 |
351 | def test_check_program_nomatch(self):
352 | _, _, e = ast.check_string(
353 | """
354 | inductive Bottom where open Bottom
355 | def elimBot {A: Type} (x: Bottom): A := nomatch x
356 | example (x: Bottom): Type := elimBot x
357 | """
358 | )
359 | assert isinstance(e, Example)
360 | assert isinstance(e.body, ir.Nomatch)
361 | self.assertEqual("nomatch", str(e.body))
362 |
363 | def test_check_program_nomatch_non_data_failed(self):
364 | with self.assertRaises(ast.TypeMismatchError) as e:
365 | ast.check_string("example := nomatch Type")
366 | want, got, loc = e.exception.args
367 | self.assertEqual("datatype", want)
368 | self.assertEqual("Type", got)
369 | self.assertEqual(19, loc)
370 |
371 | def test_check_program_nomatch_non_empty_failed(self):
372 | with self.assertRaises(ast.CaseMissError) as e:
373 | ast.check_string(
374 | """
375 | inductive A where | AA open A
376 | example := nomatch AA
377 | """
378 | )
379 | name, loc = e.exception.args
380 | self.assertEqual("AA", name)
381 | self.assertEqual(82, loc)
382 |
383 | def test_check_program_nomatch_dpm(self):
384 | ast.check_string(
385 | """
386 | inductive N where
387 | | Z
388 | | S (n: N)
389 | open N
390 |
391 | inductive T (n: N) where
392 | | MkT (n := Z)
393 | open T
394 |
395 | example (x: T (S Z)): Type := nomatch x
396 | """
397 | )
398 |
399 | def test_check_program_nomatch_eq_failed(self):
400 | text = """
401 | def Eq {T: Type} (a: T) (b: T): Type := (p: (v: T) -> Type) -> (pa: p a) -> p b
402 | def refl {T: Type} (a: T): Eq a a := fun p pa => pa
403 | inductive Bottom where open Bottom
404 | def a (x: Bottom): Type := nomatch x
405 | def b (x: Bottom): Type := nomatch x
406 | /- (a x) and (b x) are obviously not equal, even though both print as nomatch. -/
407 | example (x: Bottom): Eq (a x) (b x) := refl (a x)
408 | """
409 | with self.assertRaises(ast.TypeMismatchError) as e:
410 | ast.check_string(text)
411 | want, got, loc = e.exception.args
412 | self.assertEqual(
413 | "(p: (v: Type) → Type) → (pa: (p nomatch)) → (p nomatch)", want
414 | )
415 | self.assertEqual("(p: (v: Type) → Type) → (pa: (p nomatch)) → (p nomatch)", got)
416 | self.assertEqual(text.index("refl (a x)"), loc)
417 |
418 | def test_check_program_match(self):
419 | _, f, e = ast.check_string(
420 | """
421 | inductive V where
422 | | A (x: Type) (y: Type)
423 | | B (x: Type)
424 | open V
425 |
426 | def f (v: V): Type :=
427 | match v with
428 | | A x y => x
429 | | B x => x
430 |
431 | example := f (A Type Type)
432 | """
433 | )
434 | assert isinstance(f, Def)
435 | self.assertEqual(
436 | "match v with | A (x: Type) (y: Type) ↦ x | B (x: Type) ↦ x", str(f.body)
437 | )
438 | assert isinstance(e, Example)
439 | assert isinstance(e.body, ir.Type)
440 |
441 | def test_check_program_match_dpm(self):
442 | ast.check_string(
443 | """
444 | inductive N where
445 | | Z
446 | | S (n: N)
447 | open N
448 |
449 | inductive Vec (A: Type) (n: N) where
450 | | Nil (n := Z)
451 | | Cons {m: N} (a: A) (v: Vec A m) (n := S m)
452 | open Vec
453 |
454 | def v0: Vec N Z := Nil
455 |
456 | example :=
457 | match v0 with
458 | | Nil => Z
459 | """
460 | )
461 |
462 | def test_check_program_match_dpm_failed(self):
463 | text = """
464 | inductive N where
465 | | Z
466 | | S (n: N)
467 | open N
468 |
469 | inductive Vec (A: Type) (n: N) where
470 | | Nil (n := Z)
471 | | Cons {m: N} (a: A) (v: Vec A m) (n := S m)
472 | open Vec
473 |
474 | def v0: Vec N Z := Nil
475 |
476 | example :=
477 | match v0 with
478 | | Nil => Z
479 | | Cons a v => Z
480 | """
481 | want_loc = text.index("| Cons a v") + 2
482 | with self.assertRaises(ast.TypeMismatchError) as e:
483 | ast.check_string(text)
484 | want, got, loc = e.exception.args
485 | self.assertIn("N.Z", want)
486 | self.assertIn("N.S", got)
487 | self.assertEqual(want_loc, loc)
488 |
489 | def test_check_program_match_type_failed(self):
490 | with self.assertRaises(ast.TypeMismatchError) as e:
491 | ast.check_string(
492 | """
493 | inductive A where | AA open A
494 | example :=
495 | match Type with
496 | | AA => AA
497 | """
498 | )
499 | want, got, loc = e.exception.args
500 | self.assertEqual("datatype", want)
501 | self.assertEqual("Type", got)
502 | self.assertEqual(98, loc)
503 |
504 | def test_check_program_match_unknown_case_failed(self):
505 | with self.assertRaises(ast.UnknownCaseError) as e:
506 | ast.check_string(
507 | """
508 | inductive A where | AA open A
509 | inductive B where | BB open B
510 | example (x: A) :=
511 | match x with
512 | | BB => AA
513 | """
514 | )
515 | want, got, loc = e.exception.args
516 | self.assertEqual("A", want)
517 | self.assertEqual("BB", got)
518 | self.assertEqual(178, loc)
519 |
520 | def test_check_program_match_duplicate_case_failed(self):
521 | text = """
522 | inductive A where | AA open A
523 |
524 | example (x: A): Type :=
525 | match x with
526 | | AA => (a: Type) -> Type
527 | | AA => Type
528 | """
529 | with self.assertRaises(ast.DuplicateCaseError) as e:
530 | ast.check_string(text)
531 | name, loc = e.exception.args
532 | self.assertEqual("AA", name)
533 | self.assertEqual(text.index("AA => Type"), loc)
534 |
535 | def test_check_program_match_param_mismatch_failed(self):
536 | text = """
537 | inductive A where | AA open A
538 | example (x: A): Type :=
539 | match x with
540 | | AA a => AA
541 | """
542 | with self.assertRaises(ast.CaseParamMismatchError) as e:
543 | ast.check_string(text)
544 | want, got, loc = e.exception.args
545 | self.assertEqual(0, want)
546 | self.assertEqual(1, got)
547 | self.assertEqual(text.index("AA a"), loc)
548 |
549 | def test_check_program_match_miss_failed(self):
550 | text = """
551 | inductive A where | AA | BB open A
552 | example (x: A): Type :=
553 | match x with
554 | | AA => AA
555 | """
556 | with self.assertRaises(ast.CaseMissError) as e:
557 | ast.check_string(text)
558 | name, loc = e.exception.args
559 | self.assertEqual("BB", name)
560 | self.assertEqual(text.index("match x with"), loc)
561 |
562 | def test_check_program_match_inline(self):
563 | ast.check_string(
564 | """
565 | inductive A where | AA open A
566 | def f (x: A) :=
567 | match x with
568 | | AA => AA
569 | def g (x: A) := f x /- match expression not inlined yet -/
570 | """
571 | )
572 |
573 | def test_check_program_eq(self):
574 | ast.check_string(
575 | """
576 | inductive N where
577 | | Z
578 | | S (n: N)
579 | open N
580 |
581 | inductive Eq {T: Type} (a: T) (b: T) where
582 | | Refl (a := b)
583 | open Eq
584 |
585 | example: Eq (S Z) (S Z) := Refl (T := N)
586 | """
587 | )
588 |
589 | def test_check_program_eq_failed(self):
590 | text = """
591 | inductive N where
592 | | Z
593 | | S (n: N)
594 | open N
595 |
596 | inductive Eq {T: Type} (a: T) (b: T) where
597 | | Refl (a := b)
598 | open Eq
599 |
600 | example: Eq Z (S Z) := Refl (T := N)
601 | """
602 | with self.assertRaises(ast.TypeMismatchError) as e:
603 | ast.check_string(text)
604 | want, got, loc = e.exception.args
605 | self.assertEqual("(Eq N N.Z (N.S N.Z))", want)
606 | got = " ".join(["_" if "?m." in s else s for s in got[1:-1].split()])
607 | self.assertEqual("Eq N _ _", got)
608 | self.assertEqual(text.index("Refl (T := N)"), loc)
609 |
610 | def test_check_program_recurse(self):
611 | _, add, e = ast.check_string(
612 | """
613 | inductive N where
614 | | Z
615 | | S (n: N)
616 | open N
617 |
618 | def add (n: N) (m: N): N :=
619 | match n with
620 | | Z => m
621 | | S pred => S (add pred m)
622 |
623 | example := add (S Z) (S Z)
624 | """
625 | )
626 | assert isinstance(add, Def)
627 | self.assertEqual(
628 | "match n with | Z ↦ m | S (pred: N) ↦ (N.S ((add pred) m))", str(add.body)
629 | )
630 | assert isinstance(e, Example)
631 | self.assertEqual("(N.S (N.S N.Z))", str(e.body))
632 |
633 | def test_check_program_class(self):
634 | ast.check_string(
635 | """
636 | class C where open C
637 | example [p: C] := Type
638 | """
639 | )
640 |
641 | def test_check_program_class_param_failed(self):
642 | text = "example [p: Type] := Type"
643 | with self.assertRaises(ast.TypeMismatchError) as e:
644 | ast.check_string(text)
645 | want, got, loc = e.exception.args
646 | self.assertEqual("class", want)
647 | self.assertEqual("Type", got)
648 | self.assertEqual(text.index("Type]"), loc)
649 |
650 | def test_check_program_class_stuck(self):
651 | _, _, e = ast.check_string(
652 | """
653 | class Default (T: Type) where
654 | default: T
655 | open Default
656 | def f (U: Type) [p: Default U] := default U (inst := p)
657 | example (V: Type) [q: Default V] := f V (p := q)
658 | """
659 | )
660 | assert isinstance(e, Example)
661 | self.assertEqual("default", str(e.body))
662 |
663 | def test_check_program_class_failed(self):
664 | text = """
665 | class C where open C
666 | def f [p: C] := Type
667 | example := f
668 | """
669 | with self.assertRaises(ir.NoInstanceError) as e:
670 | ast.check_string(text)
671 | got, loc = e.exception.args
672 | self.assertEqual("C", got)
673 | self.assertEqual(text.index("C where"), loc)
674 |
675 | def test_check_program_instance(self):
676 | c, i, _, _ = ast.check_string(
677 | """
678 | class C where open C
679 | instance: C
680 | where
681 | def f [p: C] := Type
682 | example := f
683 | """
684 | )
685 | assert isinstance(c, Class)
686 | assert isinstance(i, Instance)
687 | self.assertEqual(c.instances[0], i.id)
688 |
689 | def test_check_program_instance_miss_failed(self):
690 | text = """
691 | class C where
692 | c: Type
693 | open C
694 | instance: C
695 | where
696 | """
697 | with self.assertRaises(ast.FieldMissError) as e:
698 | ast.check_string(text)
699 | name, loc = e.exception.args
700 | self.assertEqual("c", name)
701 | self.assertEqual(text.index("instance: C"), loc)
702 |
703 | def test_check_program_instance_unknown_failed(self):
704 | text = """
705 | class A where
706 | a: Type
707 | open A
708 | class B where
709 | b: Type
710 | open B
711 | instance: A
712 | where
713 | a := Type
714 | b := Type
715 | """
716 | with self.assertRaises(ast.UnknownFieldError) as e:
717 | ast.check_string(text)
718 | want, got, loc = e.exception.args
719 | self.assertEqual("A", want)
720 | self.assertEqual("b", got)
721 | self.assertEqual(text.index("b :="), loc)
722 |
723 | def test_check_program_instance_mismatch_failed(self):
724 | text = "instance: Type\nwhere"
725 | with self.assertRaises(ast.TypeMismatchError) as e:
726 | ast.check_string(text)
727 | want, got, loc = e.exception.args
728 | self.assertEqual("class", want)
729 | self.assertEqual("Type", got)
730 | self.assertEqual(text.index("Type"), loc)
731 |
732 | def test_check_program_field(self):
733 | _, _, _, e = ast.check_string(
734 | """
735 | inductive Void where open Void
736 |
737 | class C where
738 | c: Type
739 | open C
740 |
741 | instance: C
742 | where
743 | c := Void
744 |
745 | example: Type := c
746 | """
747 | )
748 | assert isinstance(e, Example)
749 | assert isinstance(e.body, ir.Data)
750 | self.assertEqual("Void", e.body.name.text)
751 |
752 | def test_check_program_field_parametric(self):
753 | _, _, _, d = ast.check_string(
754 | """
755 | class Default (T: Type) where
756 | default: T
757 | open Default
758 |
759 | inductive Data where
760 | | A
761 | | B
762 | open Data
763 |
764 | instance: Default Data
765 | where
766 | default := A
767 |
768 | def f := default Data
769 | """
770 | )
771 | assert isinstance(d, Def)
772 | self.assertEqual("Data.A", str(d.body))
773 |
774 | def test_check_program_class_add(self):
775 | _, _, _, _, f = ast.check_string(
776 | """
777 | inductive N where
778 | | Z
779 | | S (n: N)
780 | open N
781 |
782 | def addN (a: N) (b: N): N :=
783 | match a with
784 | | Z => b
785 | | S pred => S (addN pred b)
786 |
787 | class Add {T: Type} where
788 | add: (a: T) -> (b: T) -> T
789 | open Add
790 |
791 | instance: Add (T := N)
792 | where
793 | add := addN
794 |
795 | def f := add (S Z) (S Z)
796 | """
797 | )
798 | assert isinstance(f, Def)
799 | self.assertEqual("(N.S (N.S N.Z))", str(f.body))
800 |
--------------------------------------------------------------------------------
/src/TinyLean/tests/test_main.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from unittest import TestCase
3 |
4 | from .. import ast, ir
5 |
6 |
7 | def nat_to_int(v: ir.IR):
8 | n = 0
9 | while True:
10 | if isinstance(v, ir.Fn):
11 | v = v.body
12 | else:
13 | break
14 | while True:
15 | if isinstance(v, ir.Call):
16 | assert isinstance(v.callee, ir.Ref)
17 | assert v.callee.name.text == "S"
18 | v = v.arg
19 | n += 1
20 | else:
21 | assert isinstance(v, ir.Ref)
22 | assert v.name.text == "Z"
23 | break
24 | return n
25 |
26 |
27 | class TestMain(TestCase):
28 | def test_nat(self):
29 | _, _, _, _3, _6, _9 = ast.check_string(
30 | """
31 | def Nat: Type :=
32 | (T: Type) -> (S: (n: T) -> T) -> (Z: T) -> T
33 |
34 | def add (a: Nat) (b: Nat): Nat :=
35 | fun T S Z => (a T S) (b T S Z)
36 |
37 | def mul (a: Nat) (b: Nat): Nat :=
38 | fun T S Z => (a T) (b T S) Z
39 |
40 | def _3: Nat := fun T S Z => S (S (S Z))
41 |
42 | def _6: Nat := add _3 _3
43 |
44 | def _9: Nat := mul _3 _3
45 | """
46 | )
47 | self.assertEqual(3, nat_to_int(_3.body))
48 | self.assertEqual(6, nat_to_int(_6.body))
49 | self.assertEqual(9, nat_to_int(_9.body))
50 |
51 | def test_leibniz_equality(self):
52 | ast.check_string(
53 | """
54 | def Eq (T: Type) (a: T) (b: T): Type :=
55 | (p: (v: T) -> Type) -> (pa: p a) -> p b
56 |
57 | def refl (T: Type) (a: T): Eq T a a :=
58 | fun p pa => pa
59 |
60 | def sym (T: Type) (a: T) (b: T) (p: Eq T a b): Eq T b a :=
61 | (p (fun b => Eq T b a)) (refl T a)
62 |
63 | def A: Type := Type
64 |
65 | def B: Type := Type
66 |
67 | def lemma: Eq Type A B := refl Type A
68 |
69 | def theorem (p: Eq Type A B): Eq Type B A := sym Type A B lemma
70 | """
71 | )
72 |
73 | def test_leibniz_equality_failed(self):
74 | with self.assertRaises(ast.TypeMismatchError) as e:
75 | ast.check_string(
76 | """
77 | def Eq (T: Type) (a: T) (b: T): Type := (p: (v: T) -> Type) -> (pa: p a) -> p b
78 | def refl (T: Type) (a: T): Eq T a a := fun p => fun pa => pa
79 | def A: Type := (a: Type) -> Type
80 | def B: Type := (a: (b: Type) -> Type) -> Type
81 | def _: Eq Type A B := refl Type A
82 | /- the mismatch is reported at `refl Type A` on the line above -/
83 | """
84 | )
85 | want, got, loc = e.exception.args
86 | self.assertEqual(323, loc)
87 | self.assertEqual(
88 | "(p: (v: Type) → Type) → (pa: (p (a: Type) → Type)) → (p (a: (b: Type) → Type) → Type)",
89 | str(want),
90 | )
91 | self.assertEqual(
92 | "(p: (v: Type) → Type) → (pa: (p (a: Type) → Type)) → (p (a: Type) → Type)",
93 | str(got),
94 | )
95 |
96 | def test_markdown(self):
97 | results = ast.check_string(
98 | """\
99 | # Heading 1
100 |
101 | ```lean
102 | def Eq (T: Type) (a: T) (b: T): Type := (p: (v: T) -> Type) -> (pa: p a) -> p b
103 |
104 | def refl (T: Type) (a: T): Eq T a a := fun p => fun pa => pa
105 | ```
106 |
107 | ```lean
108 | def sym (T: Type) (a: T) (b: T) (p: Eq T a b): Eq T b a := (p (fun b => Eq T b a)) (refl T a)
109 | ```
110 |
111 | ```lean4
112 | def A: Type := Type
113 | ```
114 |
115 | ```python
116 | print("Hello, world!")
117 | ```
118 |
119 | ```
120 | Broken code.
121 | ```````
122 |
123 | Footer.
124 | """,
125 | True,
126 | )
127 | self.assertEqual(3, len(results))
128 | eq, refl, sym = results
129 | self.assertEqual("Eq", eq.name.text)
130 | self.assertEqual("refl", refl.name.text)
131 | self.assertEqual("sym", sym.name.text)
132 |
133 | def test_readme(self):
134 | p = Path(__file__).parent / ".." / ".." / ".." / ".github" / "README.md"
135 | with open(p, encoding="utf-8") as f:
136 | results = ast.check_string(f.read(), True)
137 | self.assertGreater(len(results), 1)
138 |
139 | def test_example(self):
140 | ast.check_string(
141 | """
142 | def T: Type := Type
143 | example: Type := T
144 | """
145 | )
146 |
147 | def test_leibniz_equality_implicit(self):
148 | ast.check_string(
149 | """
150 | def Eq {T: Type} (a: T) (b: T): Type :=
151 | (p: (v: T) -> Type) -> (pa: p a) -> p b
152 |
153 | def refl {T: Type} (a: T): Eq a a :=
154 | fun p pa => pa
155 |
156 | def sym {T: Type} (a: T) (b: T) (p: Eq a b): Eq b a :=
157 | (p (fun b => Eq b a)) (refl a)
158 |
159 | def A: Type := Type
160 |
161 | def B: Type := Type
162 |
163 | def lemma: Eq A B := refl A
164 |
165 | def theorem (p: Eq A B): Eq B A := sym A B lemma
166 | """
167 | )
168 |
169 | def test_operator_overloading(self):
170 | ast.check_string(
171 | """
172 | inductive N where
173 | | Z
174 | | S (n: N)
175 | open N
176 |
177 | def addN (a: N) (b: N): N :=
178 | match a with
179 | | Z => b
180 | | S pred => S (addN pred b)
181 |
182 | class Add {T: Type} where
183 | add: (a: T) -> (b: T) -> T
184 | open Add
185 |
186 | instance: Add (T := N)
187 | where
188 | add := addN
189 |
190 | def f := (S Z) + (S Z)
191 | """
192 | )
193 |
--------------------------------------------------------------------------------
/src/TinyLean/tests/test_parser.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from pyparsing import ParseException
4 |
5 | from . import parse
6 | from .. import (
7 | ast,
8 | Name,
9 | grammar,
10 | Param,
11 | Decl,
12 | Def,
13 | Example,
14 | Data,
15 | Ctor,
16 | Class,
17 | Instance,
18 | )
19 |
20 |
21 | class TestParser(TestCase):
22 | def test_fresh(self):
23 | self.assertNotEqual(Name("i").id, Name("j").id)
24 |
25 | def test_parse_name(self):
26 | x = parse(grammar.name, " hello")[0]
27 | assert isinstance(x, Name)
28 | self.assertEqual("hello", x.text)
29 |
30 | def test_parse_name_unbound(self):
31 | x = parse(grammar.name, "_")[0]
32 | self.assertTrue(x.is_unbound())
33 |
34 | def test_parse_type(self):
35 | x = parse(grammar.type_, " Type")[0]
36 | assert isinstance(x, ast.Type)
37 | self.assertEqual(2, x.loc)
38 |
39 | def test_parse_reference(self):
40 | x = parse(grammar.ref, " hello")[0]
41 | assert isinstance(x, ast.Ref)
42 | self.assertEqual(2, x.loc)
43 | self.assertEqual("hello", x.name.text)
44 |
45 | def test_parse_paren_expr(self):
46 | x = parse(grammar.p_expr, "(hello)")[0]
47 | assert isinstance(x, ast.Ref)
48 | self.assertEqual(1, x.loc)
49 | self.assertEqual("hello", x.name.text)
50 |
51 | def test_parse_implicit_param(self):
52 | x = parse(grammar.i_param, " {a: b}")[0]
53 | assert isinstance(x, Param)
54 | self.assertTrue(x.is_implicit)
55 | self.assertEqual("a", x.name.text)
56 | assert isinstance(x.type, ast.Ref)
57 | self.assertEqual(5, x.type.loc)
58 |
59 | def test_parse_explicit_param(self):
60 | x = parse(grammar.e_param, " (a : Type)")[0]
61 | assert isinstance(x, Param)
62 | self.assertFalse(x.is_implicit)
63 | self.assertEqual("a", x.name.text)
64 | assert isinstance(x.type, ast.Type)
65 | self.assertEqual(6, x.type.loc)
66 |
67 | def test_parse_call(self):
68 | x = parse(grammar.call, "a b")[0]
69 | assert isinstance(x, ast.Call)
70 | self.assertEqual(0, x.loc)
71 | assert isinstance(x.callee, ast.Ref)
72 | self.assertEqual(0, x.callee.loc)
73 | self.assertEqual("a", x.callee.name.text)
74 | self.assertEqual(2, x.arg.loc)
75 | assert isinstance(x.arg, ast.Ref)
76 | self.assertEqual("b", x.arg.name.text)
77 |
78 | def test_parse_call_paren(self):
79 | x = parse(grammar.call, "(a) b (Type)")[0]
80 | assert isinstance(x, ast.Call)
81 | self.assertEqual(0, x.loc)
82 | assert isinstance(x.callee, ast.Call)
83 | assert isinstance(x.callee.callee, ast.Ref)
84 | self.assertEqual(1, x.callee.callee.loc)
85 | self.assertEqual("a", x.callee.callee.name.text)
86 | assert isinstance(x.callee.arg, ast.Ref)
87 | self.assertEqual(4, x.callee.arg.loc)
88 | self.assertEqual("b", x.callee.arg.name.text)
89 | assert isinstance(x.arg, ast.Type)
90 | self.assertEqual(7, x.arg.loc)
91 |
92 | def test_parse_call_paren_function(self):
93 | x = parse(grammar.call, "(fun _ => Type) Type")[0]
94 | assert isinstance(x, ast.Call)
95 | self.assertEqual(0, x.loc)
96 | assert isinstance(x.callee, ast.Fn)
97 | self.assertEqual(1, x.callee.loc)
98 | self.assertTrue(x.callee.param.is_unbound())
99 | assert isinstance(x.callee.body, ast.Type)
100 | self.assertEqual(10, x.callee.body.loc)
101 | assert isinstance(x.arg, ast.Type)
102 | self.assertEqual(16, x.arg.loc)
103 |
104 | def test_parse_function_type(self):
105 | x = parse(grammar.fn_type, " (a : Type) -> a")[0]
106 | assert isinstance(x, ast.FnType)
107 | assert isinstance(x.param, Param)
108 | self.assertEqual("a", x.param.name.text)
109 | assert isinstance(x.param.type, ast.Type)
110 | self.assertEqual(7, x.param.type.loc)
111 | assert isinstance(x.ret, ast.Ref)
112 | self.assertEqual("a", x.ret.name.text)
113 | self.assertEqual(16, x.ret.loc)
114 |
115 | def test_parse_function_type_long(self):
116 | x = parse(grammar.fn_type, " {a : Type} -> (b: Type) -> a")[0]
117 | assert isinstance(x, ast.FnType)
118 | assert isinstance(x.param, Param)
119 | self.assertEqual("a", x.param.name.text)
120 | assert isinstance(x.param.type, ast.Type)
121 | self.assertEqual(6, x.param.type.loc)
122 | assert isinstance(x.ret, ast.FnType)
123 | assert isinstance(x.ret.param, Param)
124 | self.assertEqual("b", x.ret.param.name.text)
125 | assert isinstance(x.ret.param.type, ast.Type)
126 | self.assertEqual(19, x.ret.param.type.loc)
127 | assert isinstance(x.ret.ret, ast.Ref)
128 | self.assertEqual("a", x.ret.ret.name.text)
129 | self.assertEqual(28, x.ret.ret.loc)
130 |
131 | def test_parse_function(self):
132 | x = parse(grammar.fn, " fun a => a")[0]
133 | assert isinstance(x, ast.Fn)
134 | self.assertEqual(2, x.loc)
135 | assert isinstance(x.param, Name)
136 | self.assertEqual("a", x.param.text)
137 | assert isinstance(x.body, ast.Ref)
138 | self.assertEqual("a", x.body.name.text)
139 | self.assertEqual(11, x.body.loc)
140 |
141 | def test_parse_function_long(self):
142 | x = parse(grammar.fn, " fun a => fun b => a b")[0]
143 | assert isinstance(x, ast.Fn)
144 | self.assertEqual(3, x.loc)
145 | assert isinstance(x.param, Name)
146 | self.assertEqual("a", x.param.text)
147 | assert isinstance(x.body, ast.Fn)
148 | self.assertEqual(12, x.body.loc)
149 | assert isinstance(x.body.param, Name)
150 | self.assertEqual("b", x.body.param.text)
151 | assert isinstance(x.body.body, ast.Call)
152 | self.assertEqual(21, x.body.body.loc)
153 | assert isinstance(x.body.body.callee, ast.Ref)
154 | self.assertEqual("a", x.body.body.callee.name.text)
155 | assert isinstance(x.body.body.arg, ast.Ref)
156 | self.assertEqual("b", x.body.body.arg.name.text)
157 |
158 | def test_parse_function_multi(self):
159 | x = parse(grammar.fn, " fun c d => c d")[0]
160 | assert isinstance(x, ast.Fn)
161 | self.assertEqual(2, x.loc)
162 | assert isinstance(x.param, Name)
163 | self.assertEqual("c", x.param.text)
164 | assert isinstance(x.body, ast.Fn)
165 | self.assertEqual(2, x.body.loc)
166 | assert isinstance(x.body.param, Name)
167 | self.assertEqual("d", x.body.param.text)
168 |
169 | def test_parse_definition_constant(self):
170 | x = parse(grammar.def_, " def f : Type := Type")[0]
171 | assert isinstance(x, Def)
172 | self.assertEqual(6, x.loc)
173 | self.assertEqual("f", x.name.text)
174 | self.assertEqual(0, len(x.params))
175 | assert isinstance(x.ret, ast.Type)
176 | self.assertEqual(10, x.ret.loc)
177 | assert isinstance(x.body, ast.Type)
178 | self.assertEqual(18, x.body.loc)
179 |
180 | def test_parse_definition(self):
181 | x = parse(grammar.def_, " def f {a: Type} (b: Type): Type := a")[0]
182 | assert isinstance(x, Def)
183 | self.assertEqual(6, x.loc)
184 | self.assertEqual("f", x.name.text)
185 | assert isinstance(x.params, list)
186 | self.assertEqual(2, len(x.params))
187 | assert isinstance(x.params[0], ast.Param)
188 | self.assertTrue(x.params[0].is_implicit)
189 | self.assertEqual("a", x.params[0].name.text)
190 | assert isinstance(x.params[0].type, ast.Type)
191 | self.assertEqual(12, x.params[0].type.loc)
192 | assert isinstance(x.params[1], ast.Param)
193 | self.assertFalse(x.params[1].is_implicit)
194 | self.assertEqual("b", x.params[1].name.text)
195 | assert isinstance(x.params[1].type, ast.Type)
196 | self.assertEqual(22, x.params[1].type.loc)
197 |
198 | def test_parse_program(self):
199 | x = list(
200 | parse(
201 | grammar.program,
202 | """
203 | def a: Type := Type
204 | def b: Type := Type
205 | """,
206 | )
207 | )
208 | self.assertEqual(2, len(x))
209 | assert isinstance(x[0], Decl)
210 | self.assertEqual("a", x[0].name.text)
211 | assert isinstance(x[1], Decl)
212 | self.assertEqual("b", x[1].name.text)
213 |
214 | def test_parse_example(self):
215 | x = parse(grammar.example, " example: Type := Type")[0]
216 | assert isinstance(x, Example)
217 | self.assertEqual(2, x.loc)
218 | self.assertEqual(0, len(x.params))
219 | assert isinstance(x.ret, ast.Type)
220 | assert isinstance(x.body, ast.Type)
221 |
222 | def test_parse_placeholder(self):
223 | x = parse(grammar.fn, " fun _ => _")[0]
224 | assert isinstance(x, ast.Fn)
225 | self.assertTrue(x.param.is_unbound())
226 | assert isinstance(x.body, ast.Placeholder)
227 | self.assertEqual(10, x.body.loc)
228 |
229 | def test_parse_return_type(self):
230 | x = parse(grammar.return_type, ": Type")[0]
231 | assert isinstance(x, ast.Type)
232 | self.assertEqual(2, x.loc)
233 |
234 | def test_parse_return_placeholder(self):
235 | x = parse(grammar.return_type, "")[0]
236 | assert isinstance(x, ast.Placeholder)
237 | self.assertFalse(x.is_user)
238 |
239 | def test_parse_definition_no_return(self):
240 | x = parse(grammar.def_, "def a := Type")[0]
241 | assert isinstance(x, Def)
242 | assert isinstance(x.ret, ast.Placeholder)
243 | self.assertFalse(x.ret.is_user)
244 |
245 | def test_parse_call_implicit(self):
246 | x = parse(grammar.call, "a ( T := Nat )")[0]
247 | assert isinstance(x, ast.Call)
248 | assert isinstance(x.callee, ast.Ref)
249 | self.assertEqual("a", x.callee.name.text)
250 | self.assertEqual("T", x.implicit)
251 | assert isinstance(x.arg, ast.Ref)
252 | self.assertEqual("Nat", x.arg.name.text)
253 |
254 | def test_parse_call_explicit(self):
255 | x = parse(grammar.call, "a b")[0]
256 | assert isinstance(x, ast.Call)
257 | assert isinstance(x.callee, ast.Ref)
258 | self.assertEqual("a", x.callee.name.text)
259 | self.assertFalse(x.implicit)
260 | assert isinstance(x.arg, ast.Ref)
261 | self.assertEqual("b", x.arg.name.text)
262 |
263 | def test_parse_definition_call_implicit(self):
264 | x = parse(
265 | grammar.def_,
266 | """
267 | def f: Type := a (
268 | T := Nat
269 | ) b
270 | """,
271 | )[0]
272 | assert isinstance(x, Def)
273 | assert isinstance(x.body, ast.Call)
274 | assert isinstance(x.body.callee, ast.Call)
275 | assert isinstance(x.body.callee.callee, ast.Ref)
276 | self.assertEqual("a", x.body.callee.callee.name.text)
277 | assert isinstance(x.body.callee.arg, ast.Ref)
278 | self.assertEqual("Nat", x.body.callee.arg.name.text)
279 | self.assertEqual("T", x.body.callee.implicit)
280 | assert isinstance(x.body.arg, ast.Ref)
281 | self.assertEqual("b", x.body.arg.name.text)
282 |
283 | def test_parse_datatype_empty(self):
284 | x = parse(
285 | grammar.data,
286 | """
287 | inductive Void where
288 | open Void
289 | """,
290 | )[0]
291 | assert isinstance(x, Data)
292 | self.assertEqual("Void", x.name.text)
293 | self.assertEqual(0, len(x.params))
294 | self.assertEqual(0, len(x.ctors))
295 |
296 | def test_parse_datatype_empty_failed(self):
297 | with self.assertRaises(ParseException) as e:
298 | parse(grammar.data, "inductive Foo where open Bar")
299 | self.assertIn("open and datatype name mismatch", str(e.exception))
300 |
301 | def test_parse_datatype_ctors(self):
302 | x = parse(
303 | grammar.data,
304 | """
305 | inductive D {T: Type} (U: Type) where
306 | | A
307 | | B {X: Type} (Y: Type) (U := Type)
308 | | C (T := Type)
309 | open D
310 | """,
311 | )[0]
312 | assert isinstance(x, Data)
313 | self.assertEqual("D", x.name.text)
314 | self.assertEqual(2, len(x.params))
315 | self.assertEqual("T", x.params[0].name.text)
316 | self.assertEqual("U", x.params[1].name.text)
317 | self.assertEqual(3, len(x.ctors))
318 | a, b, c = x.ctors
319 | assert isinstance(a, Ctor)
320 | self.assertEqual("A", a.name.text)
321 | self.assertEqual(0, len(a.params))
322 | assert isinstance(b, Ctor)
323 | self.assertEqual("B", b.name.text)
324 | self.assertEqual(2, len(b.params))
325 | self.assertEqual(1, len(b.ty_args))
326 | self.assertEqual("U", b.ty_args[0][0].name.text)
327 | assert isinstance(c, Ctor)
328 | self.assertEqual("C", c.name.text)
329 | self.assertEqual(0, len(c.params))
330 | self.assertEqual(1, len(c.ty_args))
331 | self.assertEqual("T", c.ty_args[0][0].name.text)
332 |
333 | def test_parse_expr_nomatch(self):
334 | x = parse(grammar.nomatch, "nomatch x")[0]
335 | assert isinstance(x, ast.Nomatch)
336 | assert isinstance(x.arg, ast.Ref)
337 | self.assertEqual("x", x.arg.name.text)
338 |
339 | def test_parse_expr_case(self):
340 | x = parse(grammar.case, "| A a b => a")[0]
341 | assert isinstance(x, ast.Case)
342 | self.assertEqual(2, x.loc)
343 | self.assertEqual("A", x.ctor.name.text)
344 | self.assertEqual(2, len(x.params))
345 | self.assertEqual("a", x.params[0].text)
346 | self.assertEqual("b", x.params[1].text)
347 | assert isinstance(x.body, ast.Ref)
348 | self.assertEqual("a", x.body.name.text)
349 |
350 | def test_parse_expr_match(self):
351 | x = parse(
352 | grammar.match,
353 | """
354 | match Type with
355 | | A a => a
356 | | B b => b
357 | | _ => x /- not actually a default case -/
358 | """,
359 | )[0]
360 | assert isinstance(x, ast.Match)
361 | assert isinstance(x.arg, ast.Type)
362 | self.assertEqual(3, len(x.cases))
363 | self.assertTrue(x.cases[2].ctor.name.is_unbound())
364 |
365 | def test_parse_class_param(self):
366 | x = parse(grammar.c_param, "[p: GAdd Type]")[0]
367 | assert isinstance(x, Param)
368 | self.assertEqual("p", x.name.text)
369 | assert isinstance(x.type, ast.Call)
370 | assert isinstance(x.type.callee, ast.Ref)
371 | self.assertEqual("GAdd", x.type.callee.name.text)
372 | assert isinstance(x.type.arg, ast.Type)
373 | self.assertTrue(x.is_implicit)
374 |
375 | def test_parse_class_empty(self):
376 | x = parse(
377 | grammar.class_,
378 | """
379 | class A where
380 | open A
381 | """,
382 | )[0]
383 | assert isinstance(x, Class)
384 | self.assertEqual("A", x.name.text)
385 | self.assertEqual(0, len(x.params))
386 | self.assertEqual(0, len(x.fields))
387 |
388 | def test_parse_class_empty_failed(self):
389 | with self.assertRaises(ParseException) as e:
390 | parse(grammar.class_, "class A where open B")
391 | self.assertIn("open and class name mismatch", str(e.exception))
392 |
393 | def test_parse_class_fields(self):
394 | x = parse(
395 | grammar.class_,
396 | """
397 | class Op {T: Type} where
398 | add: (a: T) -> (b: T) -> T
399 | mul: (a: T) -> (b: T) -> T
400 | open Op
401 | """,
402 | )[0]
403 | assert isinstance(x, Class)
404 | self.assertEqual(1, len(x.params))
405 | self.assertEqual("T", x.params[0].name.text)
406 | self.assertEqual("Op", x.name.text)
407 | self.assertEqual(2, len(x.fields))
408 | self.assertEqual("add", x.fields[0].name.text)
409 | assert isinstance(x.fields[0].type, ast.FnType)
410 | self.assertEqual("mul", x.fields[1].name.text)
411 | assert isinstance(x.fields[1].type, ast.FnType)
412 |
413 | def test_parse_instance_empty(self):
414 | x = parse(grammar.inst, "instance : Monad A\nwhere")[0]
415 | assert isinstance(x, Instance)
416 | assert isinstance(x.type, ast.Call)
417 | assert isinstance(x.type.callee, ast.Ref)
418 | self.assertEqual("Monad", x.type.callee.name.text)
419 | assert isinstance(x.type.arg, ast.Ref)
420 | self.assertEqual("A", x.type.arg.name.text)
421 | self.assertEqual(0, len(x.fields))
422 |
423 | def test_parse_instance_fields(self):
424 | x = parse(
425 | grammar.inst,
426 | """
427 | instance : AddOp T
428 | where
429 | add := (a: T) -> (b: T) -> T
430 | """,
431 | )[0]
432 | assert isinstance(x, Instance)
433 | self.assertEqual(1, len(x.fields))
434 | f = x.fields[0]
435 | assert isinstance(f, tuple)
436 | n, v = f
437 | assert isinstance(n, ast.Ref)
438 | self.assertEqual("add", n.name.text)
439 | assert isinstance(v, ast.FnType)
440 |
441 | def test_parse_infix_op(self):
442 | x = parse(grammar.expr, "x + y")[0]
443 | assert isinstance(x, ast.Call)
444 | assert isinstance(x.callee, ast.Call)
445 | assert isinstance(x.callee.callee, ast.Ref)
446 | self.assertEqual("add", x.callee.callee.name.text)
447 | assert isinstance(x.callee.arg, ast.Ref)
448 | self.assertEqual("x", x.callee.arg.name.text)
449 | assert isinstance(x.arg, ast.Ref)
450 | self.assertEqual("y", x.arg.name.text)
451 |
--------------------------------------------------------------------------------
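The parser tests above expect `x + y` to parse as the nested call `add x y`, that is `Call(Call(Ref("add"), x), y)`. They also expect `a (T := Nat) b` to become left-nested calls in which the implicit parameter name is recorded on the inner call. The sketch below shows one way to get that shape by folding arguments left to right into curried calls. It makes some assumptions: the `Ref`/`Call` dataclasses, the `OPS` table, and the helpers `apply_args`/`desugar_infix` are all hypothetical stand-ins, not TinyLean's `ast` module or `grammar.py`.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union


@dataclass(frozen=True)
class Ref:
    name: str


@dataclass(frozen=True)
class Call:
    callee: "Expr"
    arg: "Expr"
    implicit: Optional[str] = None  # parameter name when written as (T := arg)


Expr = Union[Ref, Call]

# Hypothetical operator table: each infix operator desugars to a named function.
OPS = {"+": "add", "*": "mul"}


def apply_args(callee: Expr, args: List[Tuple[Optional[str], Expr]]) -> Expr:
    """Fold arguments left to right: f a b ==> Call(Call(f, a), b)."""
    for implicit, arg in args:
        callee = Call(callee, arg, implicit)
    return callee


def desugar_infix(lhs: Expr, op: str, rhs: Expr) -> Expr:
    """x + y ==> add x y, i.e. Call(Call(Ref("add"), x), y)."""
    return apply_args(Ref(OPS[op]), [(None, lhs), (None, rhs)])
```

For example, `desugar_infix(Ref("x"), "+", Ref("y"))` yields `Call(Call(Ref("add"), Ref("x")), Ref("y"))`, which matches the structure `test_parse_infix_op` checks. `apply_args(Ref("a"), [("T", Ref("Nat")), (None, Ref("b"))])` puts the implicit name `"T"` on the inner call, as `test_parse_definition_call_implicit` expects. An operator therefore needs no dedicated AST node. It is ordinary application of a named function, so later passes handle it like any other call.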
/src/TinyLean/tests/test_resolver.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from . import resolve_expr, resolve, resolve_md
4 | from .. import ast, Data
5 |
6 |
7 | class TestNameResolver(TestCase):
8 | def test_resolve_expr_function(self):
9 | x = resolve_expr("fun a => fun b => a b")
10 | assert isinstance(x, ast.Fn)
11 | assert isinstance(x.body, ast.Fn)
12 | assert isinstance(x.body.body, ast.Call)
13 | assert isinstance(x.body.body.callee, ast.Ref)
14 | assert isinstance(x.body.body.arg, ast.Ref)
15 | self.assertEqual(x.param.id, x.body.body.callee.name.id)
16 | self.assertEqual(x.body.param.id, x.body.body.arg.name.id)
17 |
18 | def test_resolve_expr_function_shadowed(self):
19 | x = resolve_expr("fun a => fun a => a")
20 | assert isinstance(x, ast.Fn)
21 | assert isinstance(x.body, ast.Fn)
22 | assert isinstance(x.body.body, ast.Ref)
23 | self.assertNotEqual(x.param.id, x.body.body.name.id)
24 | self.assertEqual(x.body.param.id, x.body.body.name.id)
25 |
26 | def test_resolve_expr_function_failed(self):
27 | with self.assertRaises(ast.UndefinedVariableError) as e:
28 | resolve_expr("fun a => b")
29 | n, loc = e.exception.args
30 | self.assertEqual(9, loc)
31 | self.assertEqual("b", n)
32 |
33 | def test_resolve_expr_function_type(self):
34 | x = resolve_expr("{a: Type} -> (b: Type) -> a")
35 | assert isinstance(x, ast.FnType)
36 | assert isinstance(x.ret, ast.FnType)
37 | assert isinstance(x.ret.ret, ast.Ref)
38 | self.assertEqual(x.param.name.id, x.ret.ret.name.id)
39 | self.assertNotEqual(x.ret.param.name.id, x.ret.ret.name.id)
40 |
41 | def test_resolve_expr_function_type_failed(self):
42 | with self.assertRaises(ast.UndefinedVariableError) as e:
43 | resolve_expr("{a: Type} -> (b: Type) -> c")
44 | n, loc = e.exception.args
45 | self.assertEqual(26, loc)
46 | self.assertEqual("c", n)
47 |
48 | def test_resolve_program(self):
49 | resolve(
50 | """
51 | def f0 (a: Type): Type := a
52 | def f1 (a: Type): Type := f0 a
53 | """
54 | )
55 |
56 | def test_resolve_program_failed(self):
57 | with self.assertRaises(ast.UndefinedVariableError) as e:
58 | resolve("def f (a: Type) (b: c): Type := Type")
59 | n, loc = e.exception.args
60 | self.assertEqual(20, loc)
61 | self.assertEqual("c", n)
62 |
63 | def test_resolve_program_duplicate(self):
64 | with self.assertRaises(ast.DuplicateVariableError) as e:
65 | resolve(
66 | """
67 | def f0: Type := Type
68 | def f0: Type := Type
69 | """
70 | )
71 | n, loc = e.exception.args
72 | self.assertEqual(58, loc)
73 | self.assertEqual("f0", n)
74 |
75 | def test_resolve_md(self):
76 | resolve_md(
77 | """\
78 | # Heading
79 |
80 | ```lean
81 | def a := Type
82 | ```
83 |
84 | Some text.
85 |
86 | ```lean
87 | def b := a
88 | ```
89 |
90 | Footer.
91 | """
92 | )
93 |
94 | def test_resolve_expr_placeholder(self):
95 | resolve_expr("{a: Type} -> (b: Type) -> _")
96 |
97 | def test_resolve_datatype_empty(self):
98 | resolve("inductive Void where open Void")
99 |
100 | def test_resolve_datatype_nat(self):
101 | x = resolve(
102 | """
103 | inductive N where
104 | | Z
105 | | S (n: N)
106 | open N
107 | """
108 | )[0]
109 | assert isinstance(x, Data)
110 | a = x.name.id
111 | b_typ = x.ctors[1].params[0].type
112 | assert isinstance(b_typ, ast.Ref)
113 | b = b_typ.name.id
114 | self.assertEqual(a, b)
115 |
116 | def test_resolve_datatype_maybe(self):
117 | x = resolve(
118 | """
119 | inductive Maybe (A: Type) where
120 | | Nothing
121 | | Just (a: A)
122 | open Maybe
123 | """
124 | )[0]
125 | assert isinstance(x, Data)
126 | a = x.params[0].name.id
127 | b_typ = x.ctors[1].params[0].type
128 | assert isinstance(b_typ, ast.Ref)
129 | b = b_typ.name.id
130 | self.assertEqual(a, b)
131 |
132 | def test_resolve_datatype_vec(self):
133 | resolve(
134 | """
135 | inductive N where
136 | | Z
137 | | S (n: N)
138 | open N
139 |
140 | inductive Vec (A : Type) (n : N) where
141 | | Nil (n := Z)
142 | | Cons {m: N} (a: A) (v: Vec A m) (n := S m)
143 | open Vec
144 | """
145 | )
146 |
147 | def test_resolve_datatype_duplicate(self):
148 | with self.assertRaises(ast.DuplicateVariableError) as e:
149 | resolve("inductive A where | A open A")
150 | name, loc = e.exception.args
151 | self.assertEqual("A", name)
152 | self.assertEqual(20, loc)
153 |
154 | def test_resolve_match(self):
155 | resolve(
156 | """
157 | inductive A where | AA (T: Type) open A
158 | example (x: A) :=
159 | match x with
160 | | A t => t
161 | """
162 | )
163 |
164 | def test_resolve_match_failed(self):
165 | with self.assertRaises(ast.UndefinedVariableError) as e:
166 | resolve(
167 | """
168 | inductive A where | AA open A
169 | example (x: A) :=
170 | match x with
171 | | A => b
172 | """
173 | )
174 | name, loc = e.exception.args
175 | self.assertEqual("b", name)
176 | self.assertEqual(137, loc)
177 |
178 | def test_resolve_class(self):
179 | resolve(
180 | """
181 | class A (T: Type) where
182 | a: T
183 | open A
184 | example := a
185 | """
186 | )
187 |
188 | def test_resolve_instance(self):
189 | resolve(
190 | """
191 | inductive Void where open Void
192 |
193 | class A (T: Type) where
194 | a: T
195 | open A
196 |
197 | instance: A Void
198 | where
199 | a := Type
200 | """
201 | )
202 |
203 | def test_resolve_instance_failed(self):
204 | text = """
205 | class C where
206 | c: Type
207 | open C
208 | instance: C
209 | where
210 | c := Type
211 | c := Type
212 | """
213 | with self.assertRaises(ast.DuplicateVariableError) as e:
214 | resolve(text)
215 | name, loc = e.exception.args
216 | self.assertEqual("c", name)
217 | self.assertEqual(text.rindex("c :="), loc)
218 |
--------------------------------------------------------------------------------
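The resolver tests above check four properties of name resolution:

- Every binder gets a fresh unique id.
- A reference takes the id of the innermost binder with the same name, so in `fun a => fun a => a` the body refers to the inner `a`.
- An unbound name raises an error whose args are the name and its 0-based character offset (for example `("b", 9)` for `fun a => b`).
- Defining the same global twice raises a duplicate error.

The sketch below implements those properties with a scope dictionary that is copied on each binder. It assumes a hypothetical tiny term language: the `Ident`, `Fn`, and `Ref` classes, the error classes, and the `resolve`/`resolve_globals` helpers are illustrative stand-ins, not TinyLean's actual `ast` classes or resolver.

```python
import itertools
from dataclasses import dataclass
from typing import Dict, List


class UndefinedVariableError(Exception):
    """Raised with (name, loc) when a reference has no binder in scope."""


class DuplicateVariableError(Exception):
    """Raised with (name, loc) when a global name is defined twice."""


@dataclass
class Ident:
    text: str
    loc: int  # 0-based character offset in the source text
    id: int = -1  # filled in by the resolver


@dataclass
class Fn:
    param: Ident
    body: object


@dataclass
class Ref:
    name: Ident


_ids = itertools.count()


def resolve(expr, scope: Dict[str, int]):
    """Give each binder a fresh id and point each Ref at the innermost binder."""
    if isinstance(expr, Fn):
        expr.param.id = next(_ids)
        # A new dict keeps the binding local to the body; a same-named
        # inner binder overwrites the outer entry, i.e. shadows it.
        resolve(expr.body, {**scope, expr.param.text: expr.param.id})
    elif isinstance(expr, Ref):
        if expr.name.text not in scope:
            raise UndefinedVariableError(expr.name.text, expr.name.loc)
        expr.name.id = scope[expr.name.text]
    return expr


def resolve_globals(names: List[Ident]) -> Dict[str, int]:
    """Register top-level names, rejecting duplicates at the later definition."""
    table: Dict[str, int] = {}
    for n in names:
        if n.text in table:
            raise DuplicateVariableError(n.text, n.loc)
        n.id = next(_ids)
        table[n.text] = n.id
    return table
```

For `fun a => fun a => a`, build `Fn(Ident("a", 4), Fn(Ident("a", 13), Ref(Ident("a", 18))))`. After `resolve(..., {})`, the reference's id equals the inner binder's id and differs from the outer one, which is the same shape that `test_resolve_expr_function_shadowed` asserts. The duplicate error reports the location of the *second* definition, which mirrors the instance test's use of `text.rindex("c :=")`.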