├── lean-toolchain ├── .mailmap ├── archive ├── data │ ├── smallV.mtx │ ├── smallM.mtx │ └── smaller.mtx ├── leanpkg.toml ├── gen_taco_kernels.sh └── src │ ├── verification │ ├── finsupp_lemmas.lean │ ├── semantics │ │ ├── replicate.lean │ │ ├── zero.lean │ │ ├── examples.lean │ │ ├── README.md │ │ ├── contract.lean │ │ ├── split.lean │ │ └── dense.lean │ ├── code_generation │ │ └── frames.lean │ └── test.lean │ └── test.lean ├── graphs ├── taco_scaling.pdf ├── tpch_scaling.pdf ├── wcoj_scaling.pdf ├── filtered_spmv.pdf ├── tpch_q5_scaling.pdf ├── tpch_q9_scaling.pdf ├── pyproject.toml ├── Dockerfile └── run.sh ├── bench ├── filtered-spmv-duckdb.sql ├── filtered-spmv-sqlite.sql ├── wcoj-duckdb.sql ├── wcoj-sqlite.sql ├── wcoj-sqlite-import.sql ├── filtered-spmv-sqlite-prep.sql ├── filtered-spmv-duckdb-prep.sql ├── wcoj-duckdb-prep.sql ├── tpch-q5-sqlite.sql ├── wcoj-sqlite-prep.sql ├── tpch-q5-duckdb.sql ├── tpch-q9-duckdb.sql ├── tpch-q9-sqlite.sql ├── wcoj-datagen.py ├── matmul-datagen.py ├── filtered-spmv-datagen.py ├── tpch-q5-duckdb-export-to-csv.sql ├── tpch-q9-duckdb-export-to-csv.sql ├── tpch-q5-sqlite-prep.sql ├── tpch-q9-sqlite-prep.sql ├── tpch-q9-duckdb-prep.sql ├── tpch-q5-duckdb-prep.sql ├── tpch-q5-duckdb-prep-foreign-key.sql ├── tpch-q9-duckdb-prep-foreign-key.sql └── taco-datagen.py ├── Etch ├── StreamFusion │ ├── Proofs │ │ ├── Proofs.lean │ │ ├── NestedEval.lean │ │ └── Imap.lean │ ├── ReuseTest.lean │ ├── Benchmark-PeopleMovieDistance.lean │ ├── Benchmark-CountEmployeesOfSmallCompanies.lean │ ├── Main.lean │ ├── SequentialStream.lean │ ├── ExpandSeq.lean │ ├── TestUtil.lean │ └── Tutorial.lean ├── Benchmark │ ├── OldSQL.lean │ ├── WCOJ.lean │ ├── TPCHq5.lean │ ├── SQL.lean │ ├── Basic.lean │ └── TPCHq9.lean ├── Benchmark.lean ├── Verification │ └── FinsuppLemmas.lean ├── Basic.lean ├── Mul.lean ├── Compile.lean ├── Add.lean ├── InductiveStreamTest.lean ├── Util │ └── Labels.lean ├── InductiveStreamDeriving.lean ├── KRelation.lean ├── C.lean ├── 
Op.lean └── LVal.lean ├── taco ├── sum_B_csr.c ├── spmv.c ├── sum_mul2_csr.c ├── wcoj.c ├── sum_mul2.c ├── sum_mul2_inner.c ├── sum_mul2_inner_ss.c ├── mul2_inner.c ├── sum_ttm.c ├── sum_mttkrp.c ├── mttkrp.c ├── inner2ss.c ├── sum_add2.c ├── sum_inner3.c └── add2.c ├── .gitignore ├── genplots.py ├── taco_kernels.c ├── lakefile.lean ├── impls_readable.h ├── bench-wcoj.cpp ├── lake-manifest.json ├── bench-matmul.cpp ├── bench-duckdb.cpp ├── bench-filtered-spmv.cpp ├── operators.h ├── common.h ├── readme.md └── bench-sqlite.cpp /lean-toolchain: -------------------------------------------------------------------------------- 1 | leanprover/lean4:v4.7.0 2 | -------------------------------------------------------------------------------- /.mailmap: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /archive/data/smallV.mtx: -------------------------------------------------------------------------------- 1 | 2597 2 | 2.0 3 | 3.0 4 | 6.0 5 | 10.0 -------------------------------------------------------------------------------- /graphs/taco_scaling.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovach/etch/HEAD/graphs/taco_scaling.pdf -------------------------------------------------------------------------------- /graphs/tpch_scaling.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovach/etch/HEAD/graphs/tpch_scaling.pdf -------------------------------------------------------------------------------- /graphs/wcoj_scaling.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovach/etch/HEAD/graphs/wcoj_scaling.pdf -------------------------------------------------------------------------------- /graphs/filtered_spmv.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovach/etch/HEAD/graphs/filtered_spmv.pdf -------------------------------------------------------------------------------- /graphs/tpch_q5_scaling.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovach/etch/HEAD/graphs/tpch_q5_scaling.pdf -------------------------------------------------------------------------------- /graphs/tpch_q9_scaling.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovach/etch/HEAD/graphs/tpch_q9_scaling.pdf -------------------------------------------------------------------------------- /bench/filtered-spmv-duckdb.sql: -------------------------------------------------------------------------------- 1 | SELECT SUM(A.v * V.v) 2 | FROM A, V 3 | WHERE V.v >= 0.8 AND A.i == V.i; 4 | -------------------------------------------------------------------------------- /bench/filtered-spmv-sqlite.sql: -------------------------------------------------------------------------------- 1 | SELECT SUM(A.v * V.v) 2 | FROM A, V 3 | WHERE V.v >= 0.8 AND A.i == V.i; 4 | -------------------------------------------------------------------------------- /archive/data/smallM.mtx: -------------------------------------------------------------------------------- 1 | 9 11 76367 2 | 1 7 4 3 | 3 7 2 4 | 9 7 2 5 | 1 8 0 6 | 2 8 3 7 | 10 8 3 8 | 6 9 7 9 | 7 9 5 10 | -------------------------------------------------------------------------------- /Etch/StreamFusion/Proofs/Proofs.lean: -------------------------------------------------------------------------------- 1 | import Etch.StreamFusion.Proofs.StreamMul 2 | import Etch.StreamFusion.Proofs.OfStream 3 | import Etch.StreamFusion.Proofs.Imap 4 | -------------------------------------------------------------------------------- /bench/wcoj-duckdb.sql: 
-------------------------------------------------------------------------------- 1 | -- for a description of this join problem see https://arxiv.org/pdf/1310.3314.pdf, Figure 2 2 | 3 | SELECT COUNT(*) 4 | FROM r, s, t 5 | WHERE r.a = t.a AND r.b = s.b AND s.c = t.c; 6 | -------------------------------------------------------------------------------- /bench/wcoj-sqlite.sql: -------------------------------------------------------------------------------- 1 | -- for a description of this join problem see https://arxiv.org/pdf/1310.3314.pdf, Figure 2 2 | 3 | SELECT COUNT(*) 4 | FROM r, s, t 5 | WHERE r.a = t.a AND r.b = s.b AND s.c = t.c; 6 | -------------------------------------------------------------------------------- /taco/sum_B_csr.c: -------------------------------------------------------------------------------- 1 | 2 | double out_val = 0.0; 3 | 4 | for (int32_t i = 0; i < B1_dimension; i++) { 5 | for (int32_t jB = B2_pos[i]; jB < B2_pos[(i + 1)]; jB++) { 6 | out_val += B_vals[jB]; 7 | } 8 | } 9 | 10 | return out_val; 11 | -------------------------------------------------------------------------------- /archive/leanpkg.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "etch" 3 | version = "0.1" 4 | lean_version = "leanprover-community/lean:3.50.3" 5 | path = "src" 6 | 7 | [dependencies] 8 | mathlib = {git = "https://github.com/leanprover-community/mathlib", rev = "cc5dd6244981976cc9da7afc4eee5682b037a013"} 9 | -------------------------------------------------------------------------------- /taco/spmv.c: -------------------------------------------------------------------------------- 1 | 2 | double out_val = 0.0; 3 | 4 | for (int32_t iA = A1_pos[0]; iA < A1_pos[1]; iA++) { 5 | for (int32_t jA = A2_pos[iA]; jA < A2_pos[(iA + 1)]; jA++) { 6 | int32_t j = A2_crd[jA]; 7 | out_val += A_vals[jA] * V_vals[j]; 8 | } 9 | } 10 | 11 | return out_val; 12 | 
-------------------------------------------------------------------------------- /taco/sum_mul2_csr.c: -------------------------------------------------------------------------------- 1 | 2 | double out_val = 0.0; 3 | 4 | for (int32_t iA = A1_pos[0]; iA < A1_pos[1]; iA++) { 5 | for (int32_t jA = A2_pos[iA]; jA < A2_pos[(iA + 1)]; jA++) { 6 | int32_t j = A2_crd[jA]; 7 | for (int32_t kB = B2_pos[j]; kB < B2_pos[(j + 1)]; kB++) { 8 | out_val += A_vals[jA] * B_vals[kB]; 9 | } 10 | } 11 | } 12 | 13 | return out_val; 14 | -------------------------------------------------------------------------------- /bench/wcoj-sqlite-import.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE R (A INTEGER NOT NULL, 2 | B INTEGER NOT NULL, 3 | PRIMARY KEY (A, B)); 4 | CREATE TABLE S (B INTEGER NOT NULL, 5 | C INTEGER NOT NULL, 6 | PRIMARY KEY (B, C)); 7 | CREATE TABLE T (A INTEGER NOT NULL, 8 | C INTEGER NOT NULL, 9 | PRIMARY KEY (A, C)); 10 | 11 | .mode csv 12 | .import wcoj-csv/r.csv R 13 | .import wcoj-csv/s.csv S 14 | .import wcoj-csv/t.csv T 15 | -------------------------------------------------------------------------------- /graphs/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "graphs" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Timothy Gu "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | # IMPORTANT: run: poetry export -f requirements.txt --output requirements.txt 10 | # when changing dependencies 11 | python = "^3.10" 12 | matplotlib = "^3.7.1" 13 | 14 | 15 | [build-system] 16 | requires = ["poetry-core"] 17 | build-backend = "poetry.core.masonry.api" 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.olean 2 | *~ 3 | junk/ 4 | /_target 5 | /leanpkg.path 6 | out*.cpp 7 | *.asm 8 | TAGS 
9 | 10 | .lake/ 11 | /lake-packages 12 | 13 | *.o 14 | .gdb_history 15 | bench-duckdb 16 | bench-filtered-spmv 17 | bench-matmul 18 | bench-sqlite 19 | bench-taco 20 | bench-tpch-q[0-9] 21 | bench-wcoj 22 | 23 | # data files (see Makefile) 24 | sqlite3.c 25 | sqlite3.h 26 | sqlite-amalgamation-* 27 | sqlite3/ 28 | libduckdb-* 29 | duckdb_cli-* 30 | duckdb/ 31 | tmp-*.duckdb 32 | data/ 33 | gen*.c 34 | 35 | *.svg 36 | *.csv 37 | *.pdf 38 | -------------------------------------------------------------------------------- /bench/filtered-spmv-sqlite-prep.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE A (I INTEGER NOT NULL, 2 | J INTEGER NOT NULL, 3 | V REAL NOT NULL, 4 | PRIMARY KEY (I, J)); 5 | CREATE TABLE V (I INTEGER NOT NULL, 6 | V REAL NOT NULL, 7 | PRIMARY KEY (I)); 8 | 9 | ATTACH DATABASE 'data/filtered-spmv-2000000.db' AS t; 10 | 11 | INSERT INTO A 12 | SELECT * 13 | FROM t.A 14 | ORDER BY I, J; 15 | 16 | INSERT INTO V 17 | SELECT * 18 | FROM t.V 19 | ORDER BY I; 20 | 21 | DETACH DATABASE t; 22 | -------------------------------------------------------------------------------- /Etch/StreamFusion/ReuseTest.lean: -------------------------------------------------------------------------------- 1 | inductive L1 | nil | cons : Nat → L1 → L1 2 | inductive L2 | nil | cons : Nat → L2 → L2 3 | 4 | def List.toL1 : List Nat → L1 5 | | .nil => .nil 6 | | .cons x xs => .cons x xs.toL1 7 | 8 | set_option trace.compiler.ir true 9 | def map1 : L1 → L1 := 10 | let rec go1 11 | | x, .nil => x 12 | | x, .cons y xs => go1 (.cons y x) xs 13 | go1 .nil 14 | 15 | def map2 : L1 → L2 := 16 | let rec go2 17 | | x, .nil => x 18 | | x, .cons y xs => go2 (.cons y x) xs 19 | go2 .nil 20 | 21 | def main : IO Unit := pure () 22 | -------------------------------------------------------------------------------- /bench/filtered-spmv-duckdb-prep.sql: -------------------------------------------------------------------------------- 1 | INSTALL 
sqlite; 2 | LOAD sqlite; 3 | 4 | CREATE TABLE A (I INTEGER NOT NULL, 5 | J INTEGER NOT NULL, 6 | V REAL NOT NULL, 7 | PRIMARY KEY (I, J)); 8 | CREATE TABLE V (I INTEGER NOT NULL, 9 | V REAL NOT NULL, 10 | PRIMARY KEY (I)); 11 | 12 | INSERT INTO A 13 | SELECT * 14 | FROM sqlite_scan('data/filtered-spmv-2000000.db', 'A') 15 | ORDER BY I, J; 16 | 17 | INSERT INTO V 18 | SELECT * 19 | FROM sqlite_scan('data/filtered-spmv-2000000.db', 'V') 20 | ORDER BY I; 21 | 22 | PRAGMA threads=1; 23 | -------------------------------------------------------------------------------- /bench/wcoj-duckdb-prep.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE R (A INTEGER NOT NULL, 2 | B INTEGER NOT NULL, 3 | PRIMARY KEY (A, B)); 4 | CREATE TABLE S (B INTEGER NOT NULL, 5 | C INTEGER NOT NULL, 6 | PRIMARY KEY (B, C)); 7 | CREATE TABLE T (A INTEGER NOT NULL, 8 | C INTEGER NOT NULL, 9 | PRIMARY KEY (A, C)); 10 | 11 | COPY R FROM 'wcoj-csv/r.csv' (HEADER, DELIMITER ','); 12 | COPY S FROM 'wcoj-csv/s.csv' (HEADER, DELIMITER ','); 13 | COPY T FROM 'wcoj-csv/t.csv' (HEADER, DELIMITER ','); 14 | 15 | PRAGMA database_size; 16 | 17 | PRAGMA threads=1; 18 | -------------------------------------------------------------------------------- /taco/wcoj.c: -------------------------------------------------------------------------------- 1 | 2 | double out_val = 0.0; 3 | 4 | int32_t iA = A1_pos[0]; 5 | int32_t pA1_end = A1_pos[1]; 6 | int32_t iB = B1_pos[0]; 7 | int32_t pB1_end = B1_pos[1]; 8 | 9 | while (iA < pA1_end && iB < pB1_end) { 10 | int32_t iA0 = A1_crd[iA]; 11 | int32_t iB0 = B1_crd[iB]; 12 | int32_t i = TACO_MIN(iA0,iB0); 13 | if (iA0 == i && iB0 == i) { 14 | for (int32_t jA = A2_pos[iA]; jA < A2_pos[(iA + 1)]; jA++) { 15 | for (int32_t kB = B2_pos[iB]; kB < B2_pos[(iB + 1)]; kB++) { 16 | out_val += A_vals[jA] * B_vals[kB]; 17 | } 18 | } 19 | } 20 | iA += (int32_t)(iA0 == i); 21 | iB += (int32_t)(iB0 == i); 22 | } 23 | 24 | return out_val; 25 | 
26 | -------------------------------------------------------------------------------- /bench/tpch-q5-sqlite.sql: -------------------------------------------------------------------------------- 1 | -- https://github.com/duckdb/duckdb/blob/88b1bfa74d2b79a51ffc4bab18ddeb6a034652f1/extension/tpch/dbgen/queries/q05.sql 2 | -- Except ORDER BY 3 | 4 | SELECT 5 | n_name, 6 | sum(l_extendedprice * (1 - l_discount)) AS revenue 7 | FROM 8 | customer, 9 | orders, 10 | lineitem, 11 | supplier, 12 | nation, 13 | region 14 | WHERE 15 | c_custkey = o_custkey 16 | AND l_orderkey = o_orderkey 17 | AND l_suppkey = s_suppkey 18 | AND c_nationkey = s_nationkey 19 | AND s_nationkey = n_nationkey 20 | AND n_regionkey = r_regionkey 21 | AND r_name = 'ASIA' 22 | AND o_orderdate >= '1994-01-01' 23 | AND o_orderdate < '1995-01-01' 24 | GROUP BY 25 | n_name; 26 | -------------------------------------------------------------------------------- /bench/wcoj-sqlite-prep.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE R (A INTEGER NOT NULL, 2 | B INTEGER NOT NULL, 3 | PRIMARY KEY (A, B)); 4 | CREATE TABLE S (B INTEGER NOT NULL, 5 | C INTEGER NOT NULL, 6 | PRIMARY KEY (B, C)); 7 | CREATE TABLE T (A INTEGER NOT NULL, 8 | C INTEGER NOT NULL, 9 | PRIMARY KEY (A, C)); 10 | 11 | ATTACH DATABASE 'wcoj.db' AS orig; 12 | 13 | INSERT INTO R 14 | SELECT a, b 15 | FROM orig.R; 16 | 17 | INSERT INTO S 18 | SELECT b, c 19 | FROM orig.S; 20 | 21 | INSERT INTO T 22 | SELECT a, c 23 | FROM orig.T; 24 | 25 | DETACH DATABASE orig; 26 | 27 | SELECT page_count * page_size as size 28 | FROM pragma_page_count(), pragma_page_size(); 29 | -------------------------------------------------------------------------------- /taco/sum_mul2.c: -------------------------------------------------------------------------------- 1 | 2 | double out_val = 0.0; 3 | 4 | for (int32_t iA = A1_pos[0]; iA < A1_pos[1]; iA++) { 5 | int32_t jA = A2_pos[iA]; 6 | int32_t pA2_end = A2_pos[(iA + 
1)]; 7 | int32_t jB = B1_pos[0]; 8 | int32_t pB1_end = B1_pos[1]; 9 | 10 | while (jA < pA2_end && jB < pB1_end) { 11 | int32_t jA0 = A2_crd[jA]; 12 | int32_t jB0 = B1_crd[jB]; 13 | int32_t j = TACO_MIN(jA0,jB0); 14 | if (jA0 == j && jB0 == j) { 15 | for (int32_t kB = B2_pos[jB]; kB < B2_pos[(jB + 1)]; kB++) { 16 | out_val += A_vals[jA] * B_vals[kB]; 17 | } 18 | } 19 | jA += (int32_t)(jA0 == j); 20 | jB += (int32_t)(jB0 == j); 21 | } 22 | } 23 | 24 | return out_val; 25 | -------------------------------------------------------------------------------- /taco/sum_mul2_inner.c: -------------------------------------------------------------------------------- 1 | 2 | double out_val = 0.0; 3 | 4 | for (int32_t i = 0; i < A1_dimension; i++) { 5 | for (int32_t j = 0; j < B1_dimension; j++) { 6 | int32_t kA = A2_pos[i]; 7 | int32_t pA2_end = A2_pos[(i + 1)]; 8 | int32_t kB = B2_pos[j]; 9 | int32_t pB2_end = B2_pos[(j + 1)]; 10 | 11 | while (kA < pA2_end && kB < pB2_end) { 12 | int32_t kA0 = A2_crd[kA]; 13 | int32_t kB0 = B2_crd[kB]; 14 | int32_t k = TACO_MIN(kA0,kB0); 15 | if (kA0 == k && kB0 == k) { 16 | out_val += A_vals[kA] * B_vals[kB]; 17 | } 18 | kA += (int32_t)(kA0 == k); 19 | kB += (int32_t)(kB0 == k); 20 | } 21 | } 22 | } 23 | 24 | return out_val; 25 | -------------------------------------------------------------------------------- /bench/tpch-q5-duckdb.sql: -------------------------------------------------------------------------------- 1 | -- https://github.com/duckdb/duckdb/blob/88b1bfa74d2b79a51ffc4bab18ddeb6a034652f1/extension/tpch/dbgen/queries/q05.sql 2 | -- Except ORDER BY 3 | 4 | SELECT 5 | n_name, 6 | sum(l_extendedprice * (1 - l_discount)) AS revenue 7 | FROM 8 | customer, 9 | orders, 10 | lineitem, 11 | supplier, 12 | nation, 13 | region 14 | WHERE 15 | c_custkey = o_custkey 16 | AND l_orderkey = o_orderkey 17 | AND l_suppkey = s_suppkey 18 | AND c_nationkey = s_nationkey 19 | AND s_nationkey = n_nationkey 20 | AND n_regionkey = r_regionkey 21 | AND 
r_name = 'ASIA' 22 | AND o_orderdate >= CAST('1994-01-01' AS date) 23 | AND o_orderdate < CAST('1995-01-01' AS date) 24 | GROUP BY 25 | n_name; 26 | -------------------------------------------------------------------------------- /taco/sum_mul2_inner_ss.c: -------------------------------------------------------------------------------- 1 | 2 | double out_val = 0.0; 3 | 4 | for (int32_t iA = A1_pos[0]; iA < A1_pos[1]; iA++) { 5 | for (int32_t jB = B1_pos[0]; jB < B1_pos[1]; jB++) { 6 | int32_t kA = A2_pos[iA]; 7 | int32_t pA2_end = A2_pos[(iA + 1)]; 8 | int32_t kB = B2_pos[jB]; 9 | int32_t pB2_end = B2_pos[(jB + 1)]; 10 | 11 | while (kA < pA2_end && kB < pB2_end) { 12 | int32_t kA0 = A2_crd[kA]; 13 | int32_t kB0 = B2_crd[kB]; 14 | int32_t k = TACO_MIN(kA0,kB0); 15 | if (kA0 == k && kB0 == k) { 16 | out_val += A_vals[kA] * B_vals[kB]; 17 | } 18 | kA += (int32_t)(kA0 == k); 19 | kB += (int32_t)(kB0 == k); 20 | } 21 | } 22 | } 23 | 24 | return out_val; 25 | 26 | -------------------------------------------------------------------------------- /archive/gen_taco_kernels.sh: -------------------------------------------------------------------------------- 1 | #$cmd -prefix="ttv_" -print-nocolor -print-evaluate "out = A(i,j,k)*V(k)" -f=A:sss -f=V:s -s="reorder(i,j,k)" > ttv.cpp 2 | cmd=/home/scott/Dropbox/2022/taco/build/bin/taco 3 | # $cmd -prefix="ttv_" -print-nocolor "out = C(i,j,k)*V(k)" -f=C:sss -f=V:s -s="reorder(i,j,k)" > ttv.cpp 4 | # $cmd -prefix="ttm_" -print-nocolor "out = C(i,j,l)*A(k,l)" -f=C:sss -f=A:ss -s="reorder(i,j,k,l)" > ttm.cpp 5 | # $cmd -prefix="mttkrp_" -print-nocolor "out = C(i,j,k)*A(j,l)*B(k,l)" -f=C:sss -f=A:ss -f=B:ss -s="reorder(i,j,k,l)" > mttkrp.cpp 6 | # $cmd -prefix="inner3_" -print-nocolor "out = C(i,j,k)*D(i,j,k)" -f=C:sss -f=D:sss -s="reorder(i,j,k)" > inner3.cpp 7 | # $cmd -prefix="mmul2_" -print-nocolor "out = A(i,k)*B(j,k)" -f=A:ss -f=B:ss -s="reorder(i,j,k)" > mmul2.cpp 8 | 
-------------------------------------------------------------------------------- /genplots.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import seaborn as sns, matplotlib.pyplot as plt 3 | import pandas as pd 4 | 5 | sns.set(style="whitegrid") 6 | sns.set(font_scale=2) 7 | 8 | prefix = sys.argv[1] if len(sys.argv) > 1 else "" 9 | 10 | n = 5 11 | titles = [ f'title {i}' for i in range(1,n+1) ] 12 | 13 | for (i,title) in enumerate(titles): 14 | plt.figure() 15 | name = "{prefix}plot{number}".format(prefix=prefix,number=i+1) 16 | tips = pd.read_csv(name + ".csv") 17 | gfg = sns.barplot(x="test", y="time", data=tips, capsize=.1, errorbar="sd") 18 | plt.xticks(rotation=22, fontsize=16) 19 | sns.swarmplot(x="test", y="time", data=tips, color="0", alpha=.35) 20 | gfg.set(xlabel ="", ylabel = "execution time", title = "") 21 | 22 | plt.savefig(name + ".pdf", bbox_inches='tight') 23 | -------------------------------------------------------------------------------- /taco/mul2_inner.c: -------------------------------------------------------------------------------- 1 | 2 | for (int32_t i = 0; i < A1_dimension; i++) { 3 | for (int32_t j = 0; j < B1_dimension; j++) { 4 | int32_t jout = i * out2_dimension + j; 5 | double tkout_val = 0.0; 6 | int32_t kA = A2_pos[i]; 7 | int32_t pA2_end = A2_pos[(i + 1)]; 8 | int32_t kB = B2_pos[j]; 9 | int32_t pB2_end = B2_pos[(j + 1)]; 10 | 11 | while (kA < pA2_end && kB < pB2_end) { 12 | int32_t kA0 = A2_crd[kA]; 13 | int32_t kB0 = B2_crd[kB]; 14 | int32_t k = TACO_MIN(kA0,kB0); 15 | if (kA0 == k && kB0 == k) { 16 | tkout_val += A_vals[kA] * B_vals[kB]; 17 | } 18 | kA += (int32_t)(kA0 == k); 19 | kB += (int32_t)(kB0 == k); 20 | } 21 | out_vals[jout] = tkout_val; 22 | } 23 | } 24 | 25 | return 0; 26 | 27 | -------------------------------------------------------------------------------- /archive/src/verification/finsupp_lemmas.lean: 
-------------------------------------------------------------------------------- 1 | import data.finsupp.basic 2 | import data.finsupp.indicator 3 | 4 | variables {α β γ : Type*} [add_comm_monoid β] 5 | open_locale classical 6 | noncomputable theory 7 | 8 | namespace finsupp 9 | 10 | def sum_range : (α →₀ β) →+ β := 11 | { to_fun := λ f, (f.map_domain default) (), 12 | map_zero' := rfl, 13 | map_add' := by simp [map_domain_add] } 14 | 15 | variables (f g : α →₀ β) 16 | lemma sum_range_eq_sum : f.sum_range = f.sum (λ _ v, v) := 17 | by simp [sum_range, map_domain] 18 | 19 | @[simp] lemma sum_range_single (x : α) (y : β) : (finsupp.single x y).sum_range = y := 20 | by simp [sum_range] 21 | 22 | @[simp] lemma map_domain_sum_range (h : α → γ) : 23 | (f.map_domain h).sum_range = f.sum_range := 24 | by simp [sum_range, ← finsupp.map_domain_comp] 25 | 26 | end finsupp 27 | -------------------------------------------------------------------------------- /taco/sum_ttm.c: -------------------------------------------------------------------------------- 1 | 2 | double out_val = 0.0; 3 | 4 | for (int32_t i = 0; i < C1_dimension; i++) { 5 | for (int32_t jC = C2_pos[i]; jC < C2_pos[(i + 1)]; jC++) { 6 | for (int32_t k = 0; k < B1_dimension; k++) { 7 | int32_t lC = C3_pos[jC]; 8 | int32_t pC3_end = C3_pos[(jC + 1)]; 9 | int32_t lB = B2_pos[k]; 10 | int32_t pB2_end = B2_pos[(k + 1)]; 11 | 12 | while (lC < pC3_end && lB < pB2_end) { 13 | int32_t lC0 = C3_crd[lC]; 14 | int32_t lB0 = B2_crd[lB]; 15 | int32_t l = TACO_MIN(lC0,lB0); 16 | if (lC0 == l && lB0 == l) { 17 | out_val += C_vals[lC] * B_vals[lB]; 18 | } 19 | lC += (int32_t)(lC0 == l); 20 | lB += (int32_t)(lB0 == l); 21 | } 22 | } 23 | } 24 | } 25 | 26 | return out_val; 27 | -------------------------------------------------------------------------------- /Etch/Benchmark/OldSQL.lean: -------------------------------------------------------------------------------- 1 | import Etch.Benchmark.Basic 2 | import Etch.Benchmark.SQL 3 | 
import Etch.ShapeInference 4 | import Etch.Stream 5 | 6 | namespace Etch.Benchmark.OldSQL 7 | 8 | def FSQLCallback : (E ℕ × E ℕ × E R) := 9 | (.call Op.atoi ![.access "argv" 0], 10 | .call Op.atoi ![.access "argv" 1], 11 | 1) 12 | 13 | def l_ssF : lvl ℕ (lvl ℕ (MemLoc R)) := dcsr "ssF" 14 | 15 | abbrev cause := (0, ℕ) 16 | abbrev year := (1, ℕ) 17 | abbrev objid := (2, ℕ) 18 | 19 | def fires : year ↠ₛ objid ↠ₛ E R := (SQL.ss "ssF" : ℕ →ₛ ℕ →ₛ E R) 20 | def range_06_08 : year ↠ₛ E R := (S.predRange 2006 2008 : ℕ →ₛ E R) 21 | def countRange := ∑ year, objid: range_06_08 * fires 22 | 23 | def funcs : List (String × String) := [ 24 | ("gen_query_fires.c", go l_ssF FSQLCallback), 25 | ("count_range", compileFun R "count_range" countRange) ] 26 | 27 | end Etch.Benchmark.OldSQL -------------------------------------------------------------------------------- /Etch/StreamFusion/Benchmark-PeopleMovieDistance.lean: -------------------------------------------------------------------------------- 1 | import Etch.StreamFusion.Stream 2 | import Etch.StreamFusion.Expand 3 | import Etch.StreamFusion.TestUtil 4 | namespace Etch.Verification.SStream 5 | 6 | variable {I J K α β : Type} 7 | 8 | abbrev PID := ℕ -- Person ID 9 | abbrev MID := ℕ -- Movie ID 10 | 11 | variable {I : Type} 12 | [LinearOrder I] 13 | 14 | abbrev pid : LabelIdx := LabelIdx.nth 0 15 | abbrev mid : LabelIdx := LabelIdx.nth 1 16 | abbrev i : LabelIdx := LabelIdx.nth 2 17 | 18 | @[inline] 19 | def peopleMovieDistance 20 | [ToStream P (PID →ₛ I →ₛ Float)] 21 | [ToStream M (MID →ₛ I →ₛ Float)] 22 | [ToStream R (PID →ₛ MID →ₛ Bool)] 23 | (personStream : P) 24 | (movieStream : M) 25 | (requestStream : R) 26 | : ℕ := 27 | let result := Σ i => requestStream(pid,mid) * personStream(pid,i) * movieStream(mid,i) 28 | 42 29 | 30 | end Etch.Verification.SStream 31 | -------------------------------------------------------------------------------- /Etch/Benchmark.lean: 
-------------------------------------------------------------------------------- 1 | import Etch.Basic 2 | import Etch.Stream 3 | import Etch.LVal 4 | import Etch.Add 5 | import Etch.Mul 6 | import Etch.Compile 7 | import Etch.ShapeInference 8 | import Etch.Benchmark.OldSQL 9 | import Etch.Benchmark.TACO 10 | import Etch.Benchmark.TPCHq5 11 | import Etch.Benchmark.TPCHq9 12 | import Etch.Benchmark.WCOJ 13 | 14 | open Etch.Benchmark 15 | 16 | private def files : List (String × List (String × String)) := [ 17 | ("old_sql", OldSQL.funcs), 18 | ("taco", TACO.funcs), 19 | ("matmul", TACO.funcsMatmul), 20 | ("filtered_spmv", TACO.funcsFilterSpMV), 21 | ("tpch_q5", TPCHq5.funcs), 22 | ("tpch_q9", TPCHq9.funcs), 23 | ("wcoj", WCOJ.funcs) ] 24 | 25 | def main : IO Unit := do 26 | for (f, ops) in files do 27 | let mut file := "" 28 | for x in ops do 29 | file := file.append (x.2 ++ "\n") 30 | IO.FS.writeFile s!"gen_{f}.c" file 31 | 32 | #eval main 33 | -------------------------------------------------------------------------------- /bench/tpch-q9-duckdb.sql: -------------------------------------------------------------------------------- 1 | -- https://github.com/duckdb/duckdb/blob/88b1bfa74d2b79a51ffc4bab18ddeb6a034652f1/extension/tpch/dbgen/queries/q09.sql 2 | -- Except ORDER BY 3 | 4 | SELECT 5 | nation, 6 | o_year, 7 | sum(amount) AS sum_profit 8 | FROM ( 9 | SELECT 10 | n_name AS nation, 11 | extract(year FROM o_orderdate) AS o_year, 12 | l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity AS amount 13 | FROM 14 | part, 15 | supplier, 16 | lineitem, 17 | partsupp, 18 | orders, 19 | nation 20 | WHERE 21 | s_suppkey = l_suppkey 22 | AND ps_suppkey = l_suppkey 23 | AND ps_partkey = l_partkey 24 | AND p_partkey = l_partkey 25 | AND o_orderkey = l_orderkey 26 | AND s_nationkey = n_nationkey 27 | AND p_name LIKE '%green%') AS profit 28 | GROUP BY 29 | nation, 30 | o_year; 31 | -------------------------------------------------------------------------------- 
/archive/data/smaller.mtx: -------------------------------------------------------------------------------- 1 | %%MatrixMarket matrix coordinate real general 2 | %------------------------------------------------------------------------------- 3 | % UF Sparse Matrix Collection, Tim Davis 4 | % http://www.cise.ufl.edu/research/sparse/matrices/DRIVCAV/cavity11 5 | % name: DRIVCAV/cavity11 6 | % [Driven Cavity 15 x 15, Reynolds number: 200] 7 | % id: 390 8 | % date: 1996 9 | % author: A. Chapman 10 | % ed: A. Baggag, Y. Saad 11 | % fields: title A Zeros b x name id kind notes date author ed 12 | % kind: subsequent computational fluid dynamics problem 13 | %------------------------------------------------------------------------------- 14 | % notes: 15 | % next: DRIVCAV/cavity12 first: DRIVCAV/cavity10 16 | % pattern is the same as the transpose of DRIVCAV/cavity10 17 | %------------------------------------------------------------------------------- 18 | 4 4 4 19 | 1 1 1.0 20 | 1 2 2 21 | 1 3 6 22 | 2 1 1 23 | 2 2 2.0 24 | 3 3 7 -------------------------------------------------------------------------------- /taco/sum_mttkrp.c: -------------------------------------------------------------------------------- 1 | 2 | double out_val = 0.0; 3 | 4 | for (int32_t i = 0; i < C1_dimension; i++) { 5 | for (int32_t jC = C2_pos[i]; jC < C2_pos[(i + 1)]; jC++) { 6 | int32_t j = C2_crd[jC]; 7 | for (int32_t kC = C3_pos[jC]; kC < C3_pos[(jC + 1)]; kC++) { 8 | int32_t k = C3_crd[kC]; 9 | int32_t lA = A2_pos[j]; 10 | int32_t pA2_end = A2_pos[(j + 1)]; 11 | int32_t lB = B2_pos[k]; 12 | int32_t pB2_end = B2_pos[(k + 1)]; 13 | 14 | while (lA < pA2_end && lB < pB2_end) { 15 | int32_t lA0 = A2_crd[lA]; 16 | int32_t lB0 = B2_crd[lB]; 17 | int32_t l = TACO_MIN(lA0,lB0); 18 | if (lA0 == l && lB0 == l) { 19 | out_val += (C_vals[kC] * A_vals[lA]) * B_vals[lB]; 20 | } 21 | lA += (int32_t)(lA0 == l); 22 | lB += (int32_t)(lB0 == l); 23 | } 24 | } 25 | } 26 | } 27 | 28 | return out_val; 29 | 
-------------------------------------------------------------------------------- /taco/mttkrp.c: -------------------------------------------------------------------------------- 1 | 2 | double out_val = 0.0; 3 | 4 | for (int32_t iC = C1_pos[0]; iC < C1_pos[1]; iC++) { 5 | for (int32_t jC = C2_pos[iC]; jC < C2_pos[(iC + 1)]; jC++) { 6 | int32_t j = C2_crd[jC]; 7 | for (int32_t kC = C3_pos[jC]; kC < C3_pos[(jC + 1)]; kC++) { 8 | int32_t k = C3_crd[kC]; 9 | int32_t lA = A2_pos[j]; 10 | int32_t pA2_end = A2_pos[(j + 1)]; 11 | int32_t lB = B2_pos[k]; 12 | int32_t pB2_end = B2_pos[(k + 1)]; 13 | 14 | while (lA < pA2_end && lB < pB2_end) { 15 | int32_t lA0 = A2_crd[lA]; 16 | int32_t lB0 = B2_crd[lB]; 17 | int32_t l = TACO_MIN(lA0,lB0); 18 | if (lA0 == l && lB0 == l) { 19 | out_val += (C_vals[kC] * A_vals[lA]) * B_vals[lB]; 20 | } 21 | lA += (int32_t)(lA0 == l); 22 | lB += (int32_t)(lB0 == l); 23 | } 24 | } 25 | } 26 | } 27 | 28 | return out_val; 29 | 30 | -------------------------------------------------------------------------------- /bench/tpch-q9-sqlite.sql: -------------------------------------------------------------------------------- 1 | -- https://github.com/duckdb/duckdb/blob/88b1bfa74d2b79a51ffc4bab18ddeb6a034652f1/extension/tpch/dbgen/queries/q09.sql 2 | -- Except ORDER BY 3 | 4 | SELECT 5 | nation, 6 | o_year, 7 | sum(amount) AS sum_profit 8 | FROM ( 9 | SELECT 10 | n_name AS nation, 11 | -- extract(year FROM o_orderdate) AS o_year, 12 | strftime('%Y', o_orderdate) AS o_year, 13 | -- substr(o_orderdate, 1, 4) AS o_year, 14 | l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity AS amount 15 | FROM 16 | part, 17 | supplier, 18 | lineitem, 19 | partsupp, 20 | orders, 21 | nation 22 | WHERE 23 | s_suppkey = l_suppkey 24 | AND ps_suppkey = l_suppkey 25 | AND ps_partkey = l_partkey 26 | AND p_partkey = l_partkey 27 | AND o_orderkey = l_orderkey 28 | AND s_nationkey = n_nationkey 29 | AND p_name LIKE '%green%') AS profit 30 | GROUP BY 31 | nation, 32 | 
o_year; 33 | -------------------------------------------------------------------------------- /taco/inner2ss.c: -------------------------------------------------------------------------------- 1 | 2 | double out_val = 0.0; 3 | 4 | int32_t iA = A1_pos[0]; 5 | int32_t pA1_end = A1_pos[1]; 6 | int32_t iB = B1_pos[0]; 7 | int32_t pB1_end = B1_pos[1]; 8 | 9 | while (iA < pA1_end && iB < pB1_end) { 10 | int32_t iA0 = A1_crd[iA]; 11 | int32_t iB0 = B1_crd[iB]; 12 | int32_t i = TACO_MIN(iA0,iB0); 13 | 14 | if (iA0 == i && iB0 == i) { 15 | int32_t jA = A2_pos[iA]; 16 | int32_t pA2_end = A2_pos[(iA + 1)]; 17 | int32_t jB = B2_pos[iB]; 18 | int32_t pB2_end = B2_pos[(iB + 1)]; 19 | 20 | while (jA < pA2_end && jB < pB2_end) { 21 | int32_t jA0 = A2_crd[jA]; 22 | int32_t jB0 = B2_crd[jB]; 23 | int32_t j = TACO_MIN(jA0,jB0); 24 | if (jA0 == j && jB0 == j) { 25 | out_val += A_vals[jA] * B_vals[jB]; 26 | } 27 | jA += (int32_t)(jA0 == j); 28 | jB += (int32_t)(jB0 == j); 29 | } 30 | } 31 | iA += (int32_t)(iA0 == i); 32 | iB += (int32_t)(iB0 == i); 33 | } 34 | 35 | return out_val; 36 | 37 | -------------------------------------------------------------------------------- /taco/sum_add2.c: -------------------------------------------------------------------------------- 1 | 2 | double out_val = 0; 3 | 4 | for (int32_t i = 0; i < B1_dimension; i++) { 5 | 6 | int32_t jA = A2_pos[i]; 7 | int32_t pA2_end = A2_pos[(i + 1)]; 8 | int32_t jB = B2_pos[i]; 9 | int32_t pB2_end = B2_pos[(i + 1)]; 10 | 11 | while (jA < pA2_end && jB < pB2_end) { 12 | int32_t jA0 = A2_crd[jA]; 13 | int32_t jB0 = B2_crd[jB]; 14 | int32_t j = TACO_MIN(jA0,jB0); 15 | if (jA0 == j && jB0 == j) { 16 | out_val += A_vals[jA] + B_vals[jB]; 17 | } 18 | else if (jA0 == j) { 19 | out_val += A_vals[jA]; 20 | } 21 | else { 22 | out_val += B_vals[jB]; 23 | } 24 | jA += (int32_t)(jA0 == j); 25 | jB += (int32_t)(jB0 == j); 26 | } 27 | while (jA < pA2_end) { 28 | int32_t j = A2_crd[jA]; 29 | out_val += A_vals[jA]; 30 | jA++; 31 | } 32 
| while (jB < pB2_end) { 33 | int32_t j = B2_crd[jB]; 34 | out_val += B_vals[jB]; 35 | jB++; 36 | } 37 | 38 | } 39 | 40 | return out_val; 41 | -------------------------------------------------------------------------------- /Etch/Verification/FinsuppLemmas.lean: -------------------------------------------------------------------------------- 1 | import Mathlib.Data.Finsupp.Basic 2 | import Mathlib.Data.Finsupp.Indicator 3 | import Mathlib.Data.Finsupp.Notation 4 | 5 | variable {α β γ : Type _} [AddCommMonoid β] 6 | 7 | open Classical 8 | 9 | noncomputable section 10 | 11 | namespace Finsupp 12 | 13 | def sumRange : (α →₀ β) →+ β where 14 | toFun f := (f.mapDomain default) () 15 | map_zero' := rfl 16 | map_add' := by simp [mapDomain_add] 17 | #align finsupp.sum_range Finsupp.sumRange 18 | 19 | variable (f g : α →₀ β) 20 | 21 | theorem sumRange_eq_sum : sumRange f = f.sum fun _ v => v := by simp [sumRange, mapDomain] 22 | #align finsupp.sum_range_eq_sum Finsupp.sumRange_eq_sum 23 | 24 | @[simp] 25 | theorem sumRange_single (x : α) (y : β) : sumRange (single x y) = y := by simp [sumRange] 26 | #align finsupp.sum_range_single Finsupp.sumRange_single 27 | 28 | @[simp] 29 | theorem mapDomain_sumRange (h : α → γ) : sumRange (f.mapDomain h) = sumRange f := by 30 | simp [sumRange, ← mapDomain_comp] 31 | #align finsupp.map_domain_sum_range Finsupp.mapDomain_sumRange 32 | 33 | end Finsupp 34 | -------------------------------------------------------------------------------- /archive/src/verification/semantics/replicate.lean: -------------------------------------------------------------------------------- 1 | import verification.semantics.dense 2 | 3 | /-! 4 | # Stream replicate 5 | 6 | This file defines replicate, which is a specialization of the dense vector 7 | stream to a constant function. 8 | 9 | **Note: the current version of the paper uses the terminology *expand*, which was from a previous version of the paper. 
10 | The terms *replicate* and *expand* should be considered to be synonymous**. 11 | 12 | -/ 13 | 14 | variables {α β : Type*} 15 | 16 | @[derive is_bounded] 17 | def Stream.replicate (n : ℕ) (v : α) : Stream (fin n) α := 18 | Stream.denseVec (λ _, v) 19 | 20 | @[simp] lemma Stream.replicate_map (f : α → β) (n : ℕ) (v : α) : 21 | (Stream.replicate n v).map f = Stream.replicate n (f v) := rfl 22 | 23 | variables [add_zero_class α] 24 | 25 | instance (n : ℕ) (v : α) : is_strict_lawful (Stream.replicate n v) := 26 | by { dunfold Stream.replicate, apply_instance, } 27 | 28 | @[simp] lemma Stream.replicate_eval (n : ℕ) (v : α) (j : fin n) : 29 | (Stream.replicate n v).eval (0 : fin (n + 1)) j = v := 30 | by { dunfold Stream.replicate, simp, } 31 | -------------------------------------------------------------------------------- /bench/wcoj-datagen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | import sys 4 | 5 | # for a description of this join problem see https://arxiv.org/pdf/1310.3314.pdf, Figure 2 6 | def gen(n: int): 7 | all_ones = np.ones((n, 1), dtype=int) 8 | all_nums = np.array(range(1, n+1)).reshape((-1, 1)) 9 | num_one = np.hstack([all_ones, all_nums]) 10 | one_num = np.hstack([all_nums, all_ones]) 11 | 12 | # interweave the two arrays 13 | v = np.empty((2 * n - 1, 2), dtype=int) 14 | v[0::2] = num_one 15 | v[1::2] = one_num[1:] # skip duplicate 1,1 16 | return v 17 | 18 | 19 | def export(outdir: Path, tbl: str, arr: np.ndarray): 20 | outfile = outdir / f"{tbl}.csv" 21 | header = { 22 | "r": "a,b", 23 | "s": "b,c", 24 | "t": "a,c", 25 | } 26 | np.savetxt(outfile, arr, delimiter=",", header=header[tbl], comments='', fmt="%d") 27 | 28 | 29 | def main(outdir: Path, n: int): 30 | out = gen(n) 31 | export(outdir, "r", out) 32 | export(outdir, "s", out) 33 | export(outdir, "t", out) 34 | 35 | 36 | main(Path(sys.argv[1]), int(sys.argv[2])) 37 | 
-------------------------------------------------------------------------------- /taco/sum_inner3.c: -------------------------------------------------------------------------------- 1 | 2 | double out_val = 0.0; 3 | 4 | for (int32_t i = 0; i < D1_dimension; i++) { 5 | int32_t jC = C2_pos[i]; 6 | int32_t pC2_end = C2_pos[(i + 1)]; 7 | int32_t jD = D2_pos[i]; 8 | int32_t pD2_end = D2_pos[(i + 1)]; 9 | 10 | while (jC < pC2_end && jD < pD2_end) { 11 | int32_t jC0 = C2_crd[jC]; 12 | int32_t jD0 = D2_crd[jD]; 13 | int32_t j = TACO_MIN(jC0,jD0); 14 | if (jC0 == j && jD0 == j) { 15 | int32_t kC = C3_pos[jC]; 16 | int32_t pC3_end = C3_pos[(jC + 1)]; 17 | int32_t kD = D3_pos[jD]; 18 | int32_t pD3_end = D3_pos[(jD + 1)]; 19 | 20 | while (kC < pC3_end && kD < pD3_end) { 21 | int32_t kC0 = C3_crd[kC]; 22 | int32_t kD0 = D3_crd[kD]; 23 | int32_t k = TACO_MIN(kC0,kD0); 24 | if (kC0 == k && kD0 == k) { 25 | out_val += C_vals[kC] * D_vals[kD]; 26 | } 27 | kC += (int32_t)(kC0 == k); 28 | kD += (int32_t)(kD0 == k); 29 | } 30 | } 31 | jC += (int32_t)(jC0 == j); 32 | jD += (int32_t)(jD0 == j); 33 | } 34 | } 35 | 36 | return out_val; 37 | -------------------------------------------------------------------------------- /taco_kernels.c: -------------------------------------------------------------------------------- 1 | 2 | double taco_sum_mul2() { 3 | load_ssA(); 4 | load_ssB(); 5 | #include "taco/sum_mul2.c" 6 | } 7 | double taco_sum_add2() { 8 | load_ssA(); 9 | load_ssB(); 10 | #include "taco/sum_add2.c" 11 | } 12 | double taco_sum_mul2_csr() { 13 | load_ssA(); 14 | load_dsB(); 15 | #include "taco/sum_mul2_csr.c" 16 | } 17 | double taco_inner2ss() { 18 | load_ssA(); 19 | load_ssB(); 20 | #include "taco/inner2ss.c" 21 | } 22 | 23 | double taco_wcoj() { 24 | load_ssR(); 25 | load_ssT(); 26 | #include "taco/wcoj.c" 27 | } 28 | 29 | double taco_mttkrp() { 30 | load_dsA(); 31 | load_dsB(); 32 | load_sssC(); 33 | //printf("TODO\n"); 34 | #include "taco/mttkrp.c" 35 | return 0; 36 | } 37 | 
double taco_sum_mul2_inner() { 38 | load_ssA(); 39 | load_ssB(); 40 | load_dsA(); 41 | load_dsB(); 42 | #include "taco/sum_mul2_inner.c" 43 | } 44 | double taco_sum_mul2_inner_ss() { 45 | load_ssA(); 46 | load_ssB(); 47 | #include "taco/sum_mul2_inner_ss.c" 48 | } 49 | double taco_spmv() { 50 | load_ssA(); 51 | load_dV(); 52 | #include "taco/spmv.c" 53 | } 54 | double taco_filter_spmv() { 55 | //load_sV(); 56 | //load_dsA(); 57 | //#include "taco/spmv.c" 58 | return 0.0; 59 | } 60 | /* here end */ 61 | 62 | -------------------------------------------------------------------------------- /Etch/StreamFusion/Benchmark-CountEmployeesOfSmallCompanies.lean: -------------------------------------------------------------------------------- 1 | import Std.Data.HashMap 2 | 3 | import Etch.StreamFusion.Stream 4 | import Etch.StreamFusion.Expand 5 | import Etch.StreamFusion.Multiply 6 | import Etch.StreamFusion.TestUtil 7 | namespace Etch.Verification.SStream 8 | 9 | open Std (HashMap) 10 | 11 | abbrev Id := ℕ 12 | 13 | open ToStream 14 | open SStream 15 | 16 | def_index_enum_group 17 | eid, ename, 18 | cid, cname, state, 19 | companySize 20 | 21 | -- yields employee Ids who work for companies based in CA with at most 50 employees 22 | def employeesOfSmallCompanies 23 | (employee : (Id →ₛ String →ₛ Id →ₛ Bool)) 24 | (company : (Id →ₛ String →ₛ String →ₛ Bool)) := 25 | -- label columns 26 | let employee := employee(eid,ename,cid) 27 | let company := company(cid,cname,state) 28 | -- convert `Bool` entries to 0/1 29 | let company := Bool.toNat $$[state] company 30 | -- count employees per company in CA 31 | let counts := memo HashMap Id ℕ from 32 | select cid => employee * I(state = "CA") * company 33 | let counts := (fun id => singleton (counts id))(cid, companySize) 34 | --let counts := (counts.map singleton)(cid, companySize) 35 | let small := I(companySize ≤ 50) 36 | -- get result of shape eid~Id →ₛ Bool 37 | select eid => small * counts * employee 38 | 39 | end 
Etch.Verification.SStream 40 | -------------------------------------------------------------------------------- /Etch/StreamFusion/Main.lean: -------------------------------------------------------------------------------- 1 | import Std.Data.RBMap 2 | import Std.Data.HashMap 3 | 4 | import Etch.StreamFusion.Basic 5 | import Etch.StreamFusion.Stream 6 | import Etch.StreamFusion.Multiply 7 | import Etch.StreamFusion.Expand 8 | import Etch.StreamFusion.Traversals 9 | import Etch.StreamFusion.TestUtil 10 | 11 | open Std (RBMap RBSet HashMap) 12 | open Etch.Verification 13 | open SStream RB 14 | open ToStream 15 | 16 | namespace test 17 | 18 | 19 | def vecMul_rb (num : Nat) : IO Unit := do 20 | IO.println "-----------" 21 | let v := vecStream num 22 | let s := v * (v.map fun _ => 1) 23 | time "vec mul rb" fun _ => 24 | for _ in [0:10] do 25 | let x : RBMap ℕ ℕ Ord.compare := eval s 26 | IO.println s!"{x.1.size}" 27 | pure () 28 | 29 | def vecMul_hash (num : Nat) : IO Unit := do 30 | IO.println "-----------" 31 | let v := vecStream num 32 | let v' := vecStream num |>.map fun _ => 1 33 | let s := v * v' 34 | time "vec mul hash" fun _ => 35 | for _ in [0:10] do 36 | let x : HashMap ℕ ℕ := eval s 37 | IO.println s!"{x.1.size}" 38 | pure () 39 | 40 | def_index_enum_group i,j 41 | 42 | end test 43 | 44 | def tests (args : List String) : IO Unit := do 45 | let num := (args[0]!).toNat?.getD 1000 46 | IO.println s!"test of size {num}" 47 | IO.println "starting" 48 | 49 | pure () 50 | 51 | def main := tests 52 | -------------------------------------------------------------------------------- /lakefile.lean: -------------------------------------------------------------------------------- 1 | import Lake 2 | open Lake DSL 3 | 4 | package etch 5 | 6 | lean_lib Etch where defaultFacets := #[LeanLib.sharedFacet] 7 | 8 | @[default_target] 9 | lean_exe bench { 10 | root := `Etch.Benchmark 11 | } 12 | 13 | @[default_target] 14 | lean_exe proofs { 15 | root := 
`Etch.StreamFusion.Proofs.Proofs 16 | } 17 | 18 | @[default_target] 19 | lean_exe fusion { 20 | root := `Etch.StreamFusion.Main 21 | moreLeancArgs := #["-fno-omit-frame-pointer", "-g"] 22 | } 23 | 24 | lean_exe fusion_mat { 25 | root := `Etch.StreamFusion.MainMatrix 26 | moreLeancArgs := #["-fno-omit-frame-pointer", "-g"] 27 | } 28 | 29 | @[default_target] 30 | lean_exe eg { 31 | root := `Etch.StreamFusion.Examples.Benchmarks 32 | moreLeancArgs := #["-fno-omit-frame-pointer", "-g"] 33 | } 34 | 35 | lean_exe tutorial { 36 | root := `Etch.StreamFusion.Tutorial 37 | moreLeancArgs := #["-fno-omit-frame-pointer", "-g"] 38 | } 39 | 40 | lean_exe reuse { 41 | root := `Etch.StreamFusion.ReuseTest 42 | moreLeancArgs := #["-fno-omit-frame-pointer", "-g"] 43 | } 44 | 45 | lean_exe seq { 46 | root := `Etch.StreamFusion.Sequence 47 | moreLeancArgs := #["-fno-omit-frame-pointer", "-g"] 48 | } 49 | 50 | @[default_target] 51 | lean_lib Etch.Verification.Semantics.Example 52 | 53 | require mathlib from git 54 | "https://github.com/leanprover-community/mathlib4/"@"3897434e80c1e66658416557947b9b9604e336a7" 55 | -------------------------------------------------------------------------------- /archive/src/verification/semantics/zero.lean: -------------------------------------------------------------------------------- 1 | import verification.semantics.skip_stream 2 | 3 | /-! 4 | 5 | # Zero stream 6 | 7 | In this file, we define the zero stream, which immediately terminates 8 | producing no output. This stream is important for defining 9 | a nested version of `add`, since `add` requires the value type to 10 | have a `0`. 11 | 12 | ## Main results 13 | All the results in this file are trivial (i.e. follow from `false.elim : false → C`) 14 | because the stream itself does not produce anything, and has no valid states. 
15 | 16 | -/ 17 | 18 | def Stream.zero (ι : Type) (α : Type*) : Stream ι α := 19 | { σ := unit, 20 | valid := λ _, false, 21 | ready := λ _, false, 22 | skip := λ _, false.elim, 23 | index := λ _, false.elim, 24 | value := λ _, false.elim } 25 | 26 | variables {ι : Type} [linear_order ι] {α β : Type*} 27 | 28 | instance : is_bounded (Stream.zero ι α) := 29 | ⟨⟨empty_relation, empty_wf, λ q, false.drec _⟩⟩ 30 | 31 | @[simp] lemma Stream.zero_map (f : α → β) : 32 | (Stream.zero ι α).map f = Stream.zero ι β := 33 | by { ext; solve_refl, exfalso, assumption, } 34 | 35 | variables [add_zero_class α] 36 | 37 | @[simp] lemma Stream.zero_eval : 38 | (Stream.zero ι α).eval = 0 := 39 | by { ext q i, rw Stream.eval_invalid, { simp, }, exact not_false, } 40 | 41 | instance : is_strict_lawful (Stream.zero ι α) := 42 | { mono := λ q, false.drec _, 43 | skip_spec := λ q, false.drec _, 44 | strict_mono := ⟨λ q, false.drec _, λ q, false.drec _⟩ } 45 | -------------------------------------------------------------------------------- /Etch/Benchmark/WCOJ.lean: -------------------------------------------------------------------------------- 1 | import Etch.Benchmark.Basic 2 | import Etch.Benchmark.SQL 3 | import Etch.LVal 4 | import Etch.Mul 5 | import Etch.ShapeInference 6 | import Etch.Stream 7 | 8 | namespace Etch.Benchmark.WCOJ 9 | 10 | -- For data loading 11 | 12 | def SQLCallback : (E ℕ × E ℕ × E ℕ) := 13 | (.call Op.atoi ![.access "argv" 0], 14 | .call Op.atoi ![.access "argv" 1], 15 | 1) 16 | 17 | def load_ss (l : String) : lvl ℕ (lvl ℕ (Dump ℕ)) := 18 | (interval_vl $ l ++ "1_pos").value 0 |> 19 | (with_values (sparse_il (l ++ "1_crd" : ArrayVar ℕ)) (interval_vl $ l ++ "2_pos")) ⊚ 20 | (without_values (sparse_il (l ++ "2_crd" : ArrayVar ℕ))) 21 | def l_dsR : lvl ℕ (lvl ℕ (Dump ℕ)) := load_ss "dsR" 22 | def l_dsS : lvl ℕ (lvl ℕ (Dump ℕ)) := load_ss "dsS" 23 | def l_dsT : lvl ℕ (lvl ℕ (Dump ℕ)) := load_ss "dsT" 24 | 25 | abbrev a := (0, ℕ) 26 | abbrev b := (1, ℕ) 27 | abbrev c 
:= (2, ℕ) 28 | 29 | def r : a ↠ₛ b ↠ₛ E ℕ := (SQL.ss "dsR" : ℕ →ₛ ℕ →ₛ E ℕ) 30 | def s : b ↠ₛ c ↠ₛ E ℕ := (SQL.ss "dsS" : ℕ →ₛ ℕ →ₛ E ℕ) 31 | def t : a ↠ₛ c ↠ₛ E ℕ := (SQL.ss "dsT" : ℕ →ₛ ℕ →ₛ E ℕ) 32 | def out := ∑ a, b, c: r * s * t 33 | 34 | def funcs : List (String × String) := [ 35 | let fn := "gen_callback_wcoj_R"; (fn, compileSqliteCb fn [go l_dsR SQLCallback]), 36 | let fn := "gen_callback_wcoj_S"; (fn, compileSqliteCb fn [go l_dsS SQLCallback]), 37 | let fn := "gen_callback_wcoj_T"; (fn, compileSqliteCb fn [go l_dsT SQLCallback]), 38 | let fn := "wcoj"; (fn, compileFun ℕ fn out) 39 | ] 40 | 41 | end Etch.Benchmark.WCOJ 42 | -------------------------------------------------------------------------------- /Etch/Basic.lean: -------------------------------------------------------------------------------- 1 | --set_option trace.Meta.synthInstance.instances true 2 | --set_option pp.all true 3 | import Mathlib.Algebra.Ring.Basic 4 | 5 | -- not working yet 6 | def Fin.nil {β : Fin 0 → Type _} : (i : Fin 0) → β i := (nomatch .) 7 | syntax "!![" withoutPosition(sepBy(term, ", ")) "]" : term 8 | open Lean in 9 | macro_rules 10 | | `(!![ $elems,* ]) => (elems.getElems.foldrM (mkIdent `Fin.nil) (f := fun x xs ↦ `(Fin.cons $x $xs)) : MacroM Term) 11 | 12 | instance : Add Bool := ⟨ or ⟩ 13 | instance : Mul Bool := ⟨ and ⟩ 14 | 15 | -- todo, generalize? 
16 | abbrev Fin.mk1 {γ : Fin 1 → Type _} (a : γ 0) : (i : Fin 1) → (γ i) | 0 => a 17 | abbrev Fin.mk2 {γ : Fin 2 → Type _} (a : γ 0) (b : γ 1) : (i : Fin 2) → (γ i) | 0 => a | 1 => b 18 | abbrev Fin.mk3 {γ : Fin 3 → Type _} (a : γ 0) (b : γ 1) (c : γ 2) : (i : Fin 3) → (γ i) | 0 => a | 1 => b | 2 => c 19 | 20 | set_option quotPrecheck false 21 | notation "![]" => (λ i => nomatch i : (_ : Fin 0) → _) 22 | set_option quotPrecheck true 23 | notation "![" a "]" => Fin.mk1 a 24 | notation "![" a "," b "]" => Fin.mk2 a b 25 | notation "![" a "," b "," c "]" => Fin.mk3 a b c 26 | 27 | def rev_fmap_comp {f} [Functor f] (x : α → f β) (y : β → γ) := Functor.map y ∘ x 28 | infixr:90 "⊚" => rev_fmap_comp 29 | -- todo remove 30 | def rev_app : α → (α → β) → β := Function.swap (. $ .) 31 | infixr:9 "&" => rev_app 32 | 33 | abbrev DecidableLE (α : Type u) [LE α] := @DecidableRel α LE.le 34 | abbrev DecidableLT (α : Type u) [LT α] := @DecidableRel α LT.lt 35 | 36 | instance : Zero (Option α) where zero := none -------------------------------------------------------------------------------- /Etch/Mul.lean: -------------------------------------------------------------------------------- 1 | import Etch.Stream 2 | 3 | variable {ι : Type} [Tagged ι] [DecidableEq ι] 4 | 5 | def S.mul [HMul α β γ] [Max ι] (a : S ι α) (b : S ι β) : (S ι γ) where 6 | σ := a.σ × b.σ 7 | value p := a.value p.1 * b.value p.2 8 | skip p i := a.skip p.1 i;; b.skip p.2 i 9 | succ p i := a.succ p.1 i;; b.succ p.2 i 10 | ready p := a.ready p.1 * b.ready p.2 * (a.index p.1 == b.index p.2) 11 | index p := .call .max ![a.index p.1, b.index p.2] 12 | valid p := a.valid p.1 * b.valid p.2 13 | init := seqInit a b 14 | 15 | instance [Mul α] [Max ι] : Mul (S ι α) := ⟨S.mul⟩ 16 | instance [HMul α β γ] [Max ι] : HMul (S ι α) (S ι β) (S ι γ) := ⟨S.mul⟩ 17 | 18 | instance [HMul α β γ] : HMul (ι →ₛ α) (ι →ₐ β) (ι →ₛ γ) where hMul a b := {a with value := λ s => a.value s * b (a.index s)} 19 | instance [HMul β α γ] : HMul (ι →ₐ 
β) (ι →ₛ α) (ι →ₛ γ) where hMul b a := {a with value := λ s => b (a.index s) * a.value s} 20 | instance [HMul α β γ] : HMul (ι →ₐ α) (ι →ₐ β) (ι →ₐ γ) where hMul a b := λ v => a v * b v 21 | 22 | instance : HMul (ι →ₛ α) (ι →ₐ E Bool) (ι →ₛ α) where hMul a b := 23 | { a with ready := fun p => a.ready p * b (a.index p), 24 | skip := fun p i => 25 | .if1 (a.ready p * -(b (a.index p))) 26 | (a.succ p i);; 27 | (a.skip p i) } 28 | 29 | instance : HMul (ι →ₐ E Bool) (ι →ₛ β) (ι →ₛ β) where hMul a b := 30 | { b with ready := fun p => a (b.index p) * b.ready p, 31 | skip := fun p i => 32 | .if1 (-(a (b.index p)) * b.ready p) 33 | (b.succ p i);; 34 | (b.skip p i) } 35 | -------------------------------------------------------------------------------- /archive/src/verification/semantics/examples.lean: -------------------------------------------------------------------------------- 1 | import verification.semantics.nested_eval 2 | 3 | /-! 4 | # Examples 5 | 6 | This file instantiates the abstract theorems in `LawfulEval` 7 | in some concrete cases. The examples here correspond to the figures in the paper. 
8 | 9 | -/ 10 | 11 | variables {ι₁ ι₂ ι₃ : Type} [linear_order ι₁] [linear_order ι₂] 12 | [linear_order ι₃] {R : Type} [semiring R] 13 | 14 | open Eval (eval) 15 | 16 | example (a b c d : ι₁ ⟶ₛ ι₂ ⟶ₛ ι₃ ⟶ₛ R) : 17 | eval (a * (b + c) * d) = 18 | (eval a) * ((eval b) + (eval c)) * (eval d) := 19 | by simp 20 | 21 | local notation `∑ᵢ ` s := s.contract 22 | 23 | open_locale big_operators 24 | 25 | /-- This is the more Lean appropriate way to state the next example -/ 26 | example (a b : ι₁ ⟶ₛ ι₂ ⟶ₛ R) (j : ι₂) : 27 | (eval (∑ᵢ (a * b)) : unit →₀ ι₂ →₀ R) () j = 28 | (finsupp.sum_range (eval a * eval b) : ι₂ →₀ R) j := 29 | by simp 30 | 31 | -- Unfortunately, Lean doesn't like the notation `eval s x y` because it doesn't know `eval s x` is going to be a function 32 | @[reducible] noncomputable def eval₂ {ι₁ ι₂ R : Type*} [linear_order ι₁] [linear_order ι₂] [semiring R] 33 | (x : ι₁ ⟶ₛ ι₂ ⟶ₛ R) : ι₁ →₀ ι₂ →₀ R := eval x 34 | 35 | /-- This is the same as the previous example, but `finsupp.sum_range` 36 | is changed to a summation notation that might be more understandable 37 | because it is closer to "math" notation. 
-/ 38 | example (a b : ι₁ ⟶ₛ ι₂ ⟶ₛ R) (j : ι₂) : 39 | (eval (∑ᵢ (a * b)) : unit →₀ ι₂ →₀ R) () j = 40 | ∑ i in (eval₂ a * eval₂ b).support, 41 | (eval₂ a i j * eval₂ b i j) := 42 | by simp [finsupp.sum_range_eq_sum, finsupp.sum, finsupp.finset_sum_apply] 43 | -------------------------------------------------------------------------------- /bench/matmul-datagen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sqlite3 3 | import sys 4 | from pathlib import Path 5 | 6 | 7 | def makeA(p=0.1, nonzeros=400000): 8 | # 2000 × 2000 × 0.1 = 400000 9 | 10 | n = int(np.round(np.sqrt(nonzeros / p))) 11 | matrix_size = n * n 12 | m = np.random.rand(nonzeros) 13 | m *= (matrix_size - nonzeros) / np.sum(m) 14 | m += 1 15 | 16 | # Extra correction 17 | # print(np.sum(np.round(m))) 18 | m /= np.sum(np.round(m)) / matrix_size 19 | # print(np.sum(np.round(m))) 20 | m /= np.sum(np.round(m)) / matrix_size 21 | # print(np.sum(np.round(m))) 22 | 23 | result = set() 24 | last = 0 25 | for r in m: 26 | if last >= matrix_size: 27 | break 28 | i, j = last // n, last % n 29 | result.add((i, j, float(r))) 30 | last += int(np.round(r)) 31 | print( 32 | f"expected={nonzeros} actual={len(result)} expect_sparsity={p} actual_sparsity={len(result) / matrix_size}" 33 | ) 34 | return result 35 | 36 | 37 | def main(db: Path = Path("data/pldi.db"), p: float = 0.0002, nonzeros: int = 20000): 38 | c = sqlite3.connect(str(db)) 39 | c.execute("DROP TABLE IF EXISTS A") 40 | c.execute("DROP TABLE IF EXISTS B") 41 | c.execute("CREATE TABLE A(i INTEGER NOT NULL, j INTEGER NOT NULL, v REAL NOT NULL)") 42 | c.execute("CREATE TABLE B(i INTEGER NOT NULL, j INTEGER NOT NULL, v REAL NOT NULL)") 43 | print("A") 44 | c.executemany(f"INSERT INTO A VALUES(?,?,?)", makeA(p, nonzeros)) 45 | print("B") 46 | c.executemany(f"INSERT INTO B VALUES(?,?,?)", makeA(p, nonzeros)) 47 | c.commit() 48 | 49 | 50 | main(Path(sys.argv[1]), float(sys.argv[2]), 
int(sys.argv[3])) 51 | -------------------------------------------------------------------------------- /impls_readable.h: -------------------------------------------------------------------------------- 1 | #if 0 2 | // Not the actual implementations, but substitution rules to make Etch compiler output more readable. 3 | // See Makefile. 4 | #endif 5 | 6 | #define num_add(a, b) (a + b) 7 | #define num_sub(a, b) (a - b) 8 | #define num_mul(a, b) (a * b) 9 | #define num_one() 1 10 | #define num_zero() 0 11 | #define num_lt(a, b) (a < b) 12 | #define num_le(a, b) (a <= b) 13 | #define num_eq(a, b) (a == b) 14 | #define num_max(a, b) max(a, b) 15 | #define num_min(a, b) min(a, b) 16 | #define num_succ(a) (a + 1) 17 | #define num_neg(a) (!a) 18 | #define num_ofBool(x) (x ? 1. : 0.) 19 | #define num_toMin(x) (x) 20 | #define num_toMax(x) (x) 21 | #define num_toNum(x) (x) 22 | 23 | #define nat_add(a, b) (a + b) 24 | #define nat_sub(a, b) (a - b) 25 | #define nat_mul(a, b) (a * b) 26 | #define nat_one() 1 27 | #define nat_zero() 0 28 | #define nat_lt(a, b) (a < b) 29 | #define nat_le(a, b) (a <= b) 30 | #define nat_eq(a, b) (a == b) 31 | #define nat_max(a, b) max(a, b) 32 | #define nat_min(a, b) min(a, b) 33 | #define nat_succ(a) (a + 1) 34 | #define nat_neg(a) (!a) 35 | 36 | #define int_add(a, b) (a + b) 37 | #define int_sub(a, b) (a - b) 38 | #define int_mul(a, b) (a * b) 39 | #define int_one() 1 40 | #define int_zero() 0 41 | #define int_lt(a, b) (a < b) 42 | #define int_le(a, b) (a <= b) 43 | #define int_eq(a, b) (a == b) 44 | #define int_max(a, b) max(a, b) 45 | #define int_min(a, b) min(a, b) 46 | #define int_succ(a) (a + 1) 47 | 48 | #define bool_add(a, b) (a || b) 49 | #define bool_mul(a, b) (a && b) 50 | #define bool_one() true 51 | #define bool_zero() false 52 | #define bool_neg(a) (!a) 53 | 54 | #define str_zero() "" 55 | #define macro_ternary(c, x, y) ((c) ? x : y) 56 | #define index_map(a, ...) 
&a[{__VA_ARGS__}] 57 | -------------------------------------------------------------------------------- /archive/src/verification/code_generation/frames.lean: -------------------------------------------------------------------------------- 1 | import data.set.function 2 | 3 | namespace function 4 | variables {α β γ δ : Type*} (f : (α → β) → γ) 5 | 6 | def has_frame (S : set α) : Prop := 7 | ∃ (g : (S → β) → γ), f = g ∘ (set.restrict S) 8 | 9 | variables {f} {S : set α} 10 | 11 | theorem has_frame_iff [nonempty β] : has_frame f S ↔ ∀ ⦃c₁ c₂ : α → β⦄, (∀ x ∈ S, c₁ x = c₂ x) → f c₁ = f c₂ := 12 | begin 13 | split, 14 | { rintro ⟨g, rfl⟩, intros c₁ c₂ h, dsimp only [function.comp_app], 15 | congr' 1, ext, simp only [set.restrict_apply], apply h, exact subtype.mem x, }, 16 | classical, 17 | intro h, 18 | use (λ c : S → β, f (λ v, if h : v ∈ S then c ⟨v, h⟩ else nonempty.some infer_instance)), 19 | ext c, simp only [function.comp_app], apply h, intros x hx, simp [hx], 20 | end 21 | 22 | theorem has_frame.mono {S'} (h : has_frame f S) (hS' : S ⊆ S') : has_frame f S' := 23 | by { rcases h with ⟨g, rfl⟩, use (λ c : S' → β, g (λ v, c ⟨v.1, hS' v.2⟩)), ext x, simp, congr, } 24 | 25 | theorem has_frame.const (α β : Type*) (C : γ) : has_frame (const (α → β) C) ∅ := 26 | ⟨λ _, C, by { ext, simp, }⟩ 27 | 28 | theorem has_frame.postcomp (h : has_frame f S) (g : γ → δ) : 29 | has_frame (g ∘ f) S := 30 | by { rcases h with ⟨g', rfl⟩, use (g ∘ g'), } 31 | 32 | end function 33 | 34 | section examples 35 | 36 | def test_fun (f : ℕ → ℤ) : ℤ := (f 0) + (f 1) * (f 2) 37 | 38 | theorem test_fun_frame : function.has_frame test_fun {0, 1, 2} := 39 | begin 40 | rw function.has_frame_iff, 41 | intros c₁ c₂ h, 42 | simp [test_fun, h], 43 | end 44 | 45 | def test_fun₂ (f : ℕ → ℤ) : ℤ := if test_fun f = 5 then f 3 else -(f 3) 46 | 47 | theorem test_fun₂_frame : function.has_frame test_fun₂ {0, 1, 2, 3} := 48 | begin 49 | rw function.has_frame_iff, 50 | intros c₁ c₂ h, 51 | simp [test_fun₂], 52 | 
end 53 | 54 | end examples 55 | -------------------------------------------------------------------------------- /Etch/StreamFusion/Proofs/NestedEval.lean: -------------------------------------------------------------------------------- 1 | import Etch.StreamFusion.Proofs.StreamProof 2 | 3 | namespace Etch.Verification.Stream 4 | 5 | section bdd_stream 6 | 7 | variable (ι : Type) [Preorder ι] 8 | 9 | structure BddSStream (α : Type*) extends SStream ι α := 10 | (bdd : IsBounded toStream) 11 | 12 | attribute [instance] BddSStream.bdd 13 | 14 | infixr:25 " →ₛb " => BddSStream 15 | 16 | variable {ι} {α : Type*} 17 | 18 | @[macro_inline] def BddSStream.map {α β : Type*} (f : α → β) (s : ι →ₛb α) : ι →ₛb β := 19 | { s with 20 | value := f ∘ s.value 21 | bdd := ⟨s.bdd.out⟩ } 22 | 23 | @[simp] lemma BddSStream.map_eq_map {α β : Type*} (f : α → β) (s : ι →ₛb α) : 24 | (BddSStream.map f s).toSStream = s.toSStream.map f := rfl 25 | 26 | @[inline, simp] def BddSStream.fold {α : Type*} (f : β → ι → α → β) (s : ι →ₛb α) (b : β) : β := 27 | s.toStream.fold_wf f s.q b 28 | 29 | noncomputable def BddSStream.eval [AddZeroClass α] (s : ι →ₛb α) : ι →₀ α := 30 | s.toStream.eval s.q 31 | 32 | instance : Zero (ι →ₛb α) where 33 | zero := { toSStream := 0, bdd := inferInstanceAs (IsBounded 0) } 34 | 35 | @[simp] lemma zero_toStream : (0 : ι →ₛb α).toStream = 0 := rfl 36 | @[simp] lemma zero_state : (0 : ι →ₛb α).q = () := rfl 37 | 38 | end bdd_stream 39 | 40 | class EvalToFinsupp (α : Type*) (β : outParam (Type*)) [Zero α] [Zero β] where 41 | evalFinsupp : ZeroHom α β 42 | 43 | open EvalToFinsupp 44 | 45 | @[simps] 46 | instance [Scalar α] [AddZeroClass α] : EvalToFinsupp α α where 47 | evalFinsupp := ⟨id, rfl⟩ 48 | 49 | @[simps] 50 | noncomputable instance BddSStream.instEvalToFinsupp [LinearOrder ι] [Zero α] [AddZeroClass β] [EvalToFinsupp α β] : EvalToFinsupp (ι →ₛb α) (ι →₀ β) where 51 | evalFinsupp := ⟨fun f => (f.map evalFinsupp).eval, by 52 | change (0 : ι →ₛb β).eval = 0 53 | 
dsimp [BddSStream.eval] 54 | simp 55 | ⟩ 56 | 57 | end Etch.Verification.Stream 58 | -------------------------------------------------------------------------------- /Etch/StreamFusion/SequentialStream.lean: -------------------------------------------------------------------------------- 1 | import Mathlib.Data.Prod.Lex 2 | import Mathlib.Data.String.Basic 3 | import Init.Data.Array.Basic 4 | import Std.Data.RBMap 5 | import Std.Data.HashMap 6 | 7 | import Etch.StreamFusion.Basic 8 | 9 | open Std (RBMap HashMap) 10 | 11 | namespace Etch.Verification 12 | 13 | structure SequentialStream (ι : Type) (α : Type u) where 14 | σ : Type 15 | q : σ 16 | valid : σ → Bool 17 | index : {x // valid x} → ι 18 | next : {x // valid x} → σ 19 | ready : {x // valid x} → Bool 20 | value : {x // ready x} → α 21 | 22 | infixr:25 " →ₛ! " => SequentialStream 23 | 24 | namespace SequentialStream 25 | 26 | class ToStream (α : Type u) (β : outParam $ Type v) where 27 | stream : α → β 28 | 29 | --instance instBase [Scalar α] [Add α] : OfStream α α where 30 | -- eval := Add.add 31 | 32 | @[inline] partial def fold (f : β → ι → α → β) (s : SequentialStream ι α) (acc : β) : β := 33 | let rec @[specialize] go f 34 | (valid : s.σ → Bool) (ready : (x : s.σ) → valid x → Bool) 35 | (index : (x : s.σ) → valid x → ι) (value : (x : s.σ) → (h : valid x) → ready x h → α) 36 | (next : {x // valid x} → s.σ) 37 | --(next : (x : s.σ) → valid x → Bool → s.σ) 38 | (acc : β) (q : s.σ) := 39 | if hv : valid q then 40 | let i := index q hv 41 | let hr := ready q hv 42 | let acc' := if hr : hr then f acc i (value q hv hr) else acc 43 | let q' := next ⟨q, hv⟩ 44 | go f valid ready index value next acc' q' 45 | else acc 46 | go f s.valid (fun q h => s.ready ⟨q,h⟩) (fun q h => s.index ⟨q,h⟩) (fun q v r => s.value ⟨⟨q,v⟩,r⟩) s.next 47 | acc s.q 48 | 49 | @[macro_inline] 50 | def map (f : α → β) (s : ι →ₛ! α) : ι →ₛ! 
β := { 51 | s with value := fun x => f (s.value x) 52 | } 53 | 54 | @[simps, macro_inline] 55 | def contract (s : ι →ₛ! α) : Unit →ₛ! α := { s with index := default } 56 | 57 | end SequentialStream 58 | 59 | end Etch.Verification 60 | -------------------------------------------------------------------------------- /Etch/Compile.lean: -------------------------------------------------------------------------------- 1 | import Etch.Basic 2 | import Etch.Stream 3 | import Etch.LVal 4 | import Etch.Add 5 | import Etch.Mul 6 | import Etch.ShapeInference 7 | 8 | class Compile (location value : Type _) where compile : Name → location → value → P 9 | 10 | section Compile 11 | open Compile 12 | 13 | variable {L R} 14 | 15 | instance base_var [Tagged α] [Add α] : Compile (Var α) (E α) where 16 | compile _ l v := .store_var l (E.var l + v) 17 | 18 | instance base_mem [Tagged α] [Add α] : Compile (MemLoc α) (E α) where 19 | compile _ l v := .store_mem l.arr l.ind (l.access + v) 20 | 21 | instance base_dump [Tagged α] : Compile (Dump α) (E α) where 22 | compile _ _ _ := .skip 23 | 24 | instance S.step [Compile L R] [TaggedC ι] : Compile (lvl ι L) (ι →ₛ R) where 25 | compile n l r := 26 | let (init, s) := r.init n 27 | let (push, position) := l.push (r.index s) 28 | let temp := ("index_lower_bound" : Var ι).fresh n 29 | init;; .while (r.valid s) 30 | (.decl temp (r.index s);; 31 | .branch (r.ready s) 32 | (push;; compile (n.fresh 0) position (r.value s);; (r.succ s temp)) 33 | (r.skip s temp)) 34 | 35 | instance S.step' {n} [Compile L R] [TaggedC ι] : Compile (lvl ι L) (n × ι ⟶ₛ R) where 36 | compile := fun n l (.str s) => S.step.compile n l s 37 | 38 | instance contract [Compile α β] : Compile α (Contraction β) where 39 | compile n := λ storage ⟨ι, _, v⟩ => 40 | let (init, s) := v.init n 41 | let temp := ("index_lower_bound" : Var ι).fresh n 42 | init ;; .while (v.valid s) 43 | (.decl temp (v.index s);; 44 | .branch (v.ready s) 45 | (Compile.compile (n.fresh 0) storage (v.value 
s);; v.succ s temp) 46 | (v.skip s temp)) 47 | 48 | -- Used only to generate callback for data loading 49 | instance [Compile α β] : Compile (lvl ι α) (E ι × β) where 50 | compile n := λ storage v => 51 | let (push, position) := storage.push v.1 52 | push;; Compile.compile n.freshen position v.2 53 | 54 | end Compile 55 | 56 | def go [Compile α β] (l : α) (r : β) : String := (Compile.compile emptyName l r).compile.emit.run 57 | -------------------------------------------------------------------------------- /Etch/Benchmark/TPCHq5.lean: -------------------------------------------------------------------------------- 1 | import Etch.Benchmark.Basic 2 | import Etch.Benchmark.SQL 3 | import Etch.LVal 4 | import Etch.Mul 5 | import Etch.ShapeInference 6 | import Etch.Stream 7 | 8 | namespace Etch.Benchmark.TPCHq5 9 | 10 | -- Schema 11 | 12 | abbrev orderkey := (0, ℕ) 13 | abbrev orderdate := (1, ℕ) 14 | abbrev custkey := (2, ℕ) 15 | abbrev suppkey := (3, ℕ) 16 | abbrev nationkey := (4, ℕ) 17 | abbrev regionkey := (5, ℕ) 18 | abbrev nationname := (6, String) 19 | abbrev regionname := (7, String) 20 | abbrev extendedprice := (8, R) 21 | abbrev discount := (9, R) 22 | 23 | def lineitem : orderkey ↠ₛ suppkey ↠ₛ extendedprice ↠ₛ discount ↠ₛ E R := 24 | (SQL.ss__ "tpch_lineitem" : ℕ →ₛ ℕ →ₛ R →ₛ R →ₛ E R) 25 | 26 | def revenue_calc' : R →ₐ R →ₐ E R := fun p d => p * (1 - d) 27 | def revenue_calc : extendedprice ↠ₐ discount ↠ₐ E R := revenue_calc' 28 | def lineitem_revenue : orderkey ↠ₛ suppkey ↠ₛ extendedprice ↠ₛ discount ↠ₛ E R := lineitem * revenue_calc 29 | 30 | def orders : orderkey ↠ₐ orderdate ↠ₛ custkey ↠ₛ E R := (SQL.dss "tpch_orders" .binarySearch : ℕ →ₐ ℕ →ₛ ℕ →ₛ E R) 31 | def customer : custkey ↠ₐ nationkey ↠ₛ E R := (SQL.ds "tpch_customer" : ℕ →ₐ ℕ →ₛ E R) 32 | def supplier : suppkey ↠ₐ nationkey ↠ₛ E R := (SQL.ds "tpch_supplier" : ℕ →ₐ ℕ →ₛ E R) 33 | def nation : nationkey ↠ₐ regionkey ↠ₛ nationname ↠ₛ E R := (SQL.ds_ "tpch_nation" : ℕ →ₐ ℕ →ₛ String →ₛ E R) 34 | 
def region : regionkey ↠ₐ regionname ↠ₛ E R := (SQL.ds "tpch_region" : ℕ →ₐ String →ₛ E R) 35 | 36 | -- Query 37 | 38 | def asia : regionname ↠ₛ E R := (S.predRangeIncl "ASIA" "ASIA" : String →ₛ E R) 39 | 40 | def year1994unix := 757411200 41 | def year1995unix := 788947200 42 | def orders1994 : orderdate ↠ₛ E R := (S.predRange year1994unix year1995unix : ℕ →ₛ E R) 43 | 44 | def q5 := ∑ orderkey, orderdate, custkey: ∑ suppkey, nationkey, regionkey: ∑ regionname, extendedprice, discount: 45 | lineitem_revenue * orders * orders1994 * customer * supplier * (nation * region * asia) 46 | 47 | def funcs : List (String × String) := [ 48 | let fn := "q5"; (fn, compileFunMap String R fn q5) 49 | ] 50 | 51 | end Etch.Benchmark.TPCHq5 52 | -------------------------------------------------------------------------------- /graphs/Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1 2 | 3 | # This Dockerfile creates an environment that can generate the graphs in our paper. 4 | # Steps: 5 | # 6 | # 1. Install Docker (make sure the "Buildx" plugin is installed as well) 7 | # https://docs.docker.com/build/architecture/#buildx 8 | # https://docs.docker.com/engine/reference/commandline/buildx/ 9 | # 10 | # 2. Build Docker image. From the etch4/ directory, run: 11 | # docker build -t etch-bench -f graphs/Dockerfile . 12 | # 13 | # 3. Run benchmarks. Run: 14 | # docker run --rm -v '.:/mnt' -e HOME=/mnt -w /mnt -u $(id -u):$(id -g) etch-bench bash graphs/run.sh 15 | # 16 | # 4. Generate graphs. 
Run: 17 | # docker run --rm -v '.:/mnt' -e HOME=/mnt -w /mnt -u $(id -u):$(id -g) etch-bench python3 graphs/graph.py 18 | 19 | FROM python:3.11-bullseye 20 | 21 | # Install clang and other apt dependencies 22 | COPY < 18 | ((csr.of f 2 ι₂).inherit <$> ·) ⊚ 19 | Functor.mapConst 1 20 | def d_ : ℕ →ₐ ι₂ →ₛ E α := 21 | range |> 22 | ((csr.of f 2 ι₂).inherit <$> ·) ⊚ 23 | Functor.mapConst 1 24 | def ss : ι₁ →ₛ ι₂ →ₛ E α := 25 | ((csr.of f 1 ι₁).level t₁ 0) |> 26 | ((csr.of f 2 ι₂).level t₂ <$> ·) ⊚ 27 | Functor.mapConst 1 28 | def ds : ℕ →ₐ ι₂ →ₛ E α := 29 | range |> 30 | ((csr.of f 2 ι₂).level .step <$> ·) ⊚ 31 | Functor.mapConst 1 32 | def dss : ℕ →ₐ ι₂ →ₛ ι₃ →ₛ E α := 33 | range |> 34 | ((csr.of f 2 ι₂).level t₂ <$> ·) ⊚ 35 | ((csr.of f 3 ι₃).level t₃ <$> ·) ⊚ 36 | Functor.mapConst 1 37 | def ds_ : ℕ →ₐ ι₂ →ₛ ι₃ →ₛ E α := 38 | range |> 39 | ((csr.of f 2 ι₂).level t₂ <$> ·) ⊚ 40 | ((csr.of f 3 ι₃).inherit <$> ·) ⊚ 41 | Functor.mapConst 1 42 | def ss__ : ι₁ →ₛ ι₂ →ₛ ι₃ →ₛ ι₄ →ₛ E α := 43 | ((csr.of f 1 ι₁).level t₁ 0) |> 44 | ((csr.of f 2 ι₂).level t₂ <$> ·) ⊚ 45 | ((csr.of f 3 ι₃).inherit <$> ·) ⊚ 46 | ((csr.of f 4 ι₄).inherit <$> ·) ⊚ 47 | Functor.mapConst 1 48 | def sss___ : ι₁ →ₛ ι₂ →ₛ ι₃ →ₛ ι₄ →ₛ ι₅ →ₛ ι₆ →ₛ E α := 49 | ((csr.of f 1 ι₁).level t₁ 0) |> 50 | ((csr.of f 2 ι₂).level t₂ <$> ·) ⊚ 51 | ((csr.of f 3 ι₃).level t₃ <$> ·) ⊚ 52 | ((csr.of f 4 ι₄).inherit <$> ·) ⊚ 53 | ((csr.of f 5 ι₅).inherit <$> ·) ⊚ 54 | ((csr.of f 6 ι₆).inherit <$> ·) ⊚ 55 | Functor.mapConst 1 56 | 57 | end Etch.Benchmark.SQL 58 | -------------------------------------------------------------------------------- /archive/src/verification/semantics/README.md: -------------------------------------------------------------------------------- 1 | # Verification of Semantics of Indexed Streams 2 | 3 | In this folder, we verify the semantics of indexed streams. 
There are two key claims in the paper: that the evaluation of indexed streams to functions is a homomorphism; and the operations on strictly monotonic indexed streams form strictly monotonic indexed streams. 4 | 5 | Both of these claims are proved in `nested_eval.lean`, which defines nested stream evaluation inductively on the depth of the nesting using typeclasses. We define the type of strictly monotonic lawful streams `ι ⟶ₛ α` and prove that the sum, product etc. of strictly lawful streams is strictly lawful. This is implicit in the type signatures of the operators e.g. `+ : (ι ⟶ₛ α) → (ι ⟶ₛ α) → (ι ⟶ₛ α)` produces a strictly lawful stream as output. 6 | 7 | We then define the typeclass `LawfulEval`, where `LawfulEval s f` indicates that evaluation is an addition and multiplication preserving map from `s` to `f`. We inductively give an instance for `LawfulEval (ι ⟶ₛ α) (ι →₀ β)` given an instance `LawfulEval α β`. We also prove correctness for `contract` and `replicate` (expand). Thus, we get the indexed stream correctness theorem (theorem 6.1 from the paper), which states that evaluation is a homomorphism on strictly monotonic lawful streams. 8 | 9 | 10 | The folder is organized as follows: 11 | 12 | - `skip_stream.lean`: This is the file which contains the main definitions of (skippable) indexed streams. 13 | - `add.lean`: Defines the sum of indexed streams; proves that addition is sound (i.e. `(a + b).eval = a.eval + b.eval`) 14 | - `mul.lean`: Defines the product of indexed streams; prove that multiplication is sound (i.e. `(a * b).eval = a.eval * b.eval`) 15 | - `contract.lean`: Defines the contraction of indexed streams; proves soundness (i.e. `s.contract.eval = sum of values of s.eval`, where "sum of values" is `finsupp.sum_range` defined in finsupp_lemmas) 16 | - `replicate.lean`: Defines replication of indexed streams; proves soundness (i.e. 
`(replicate n v).eval` is the constant function which always returns `v`) 17 | - `nested_eval.lean`: Defines nested evaluation of streams, and puts together the results from the above files to prove that evaluation is a homomorphism. 18 | 19 | The figures from the paper are generated in `examples.lean`. 20 | -------------------------------------------------------------------------------- /bench-wcoj.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "common.h" 6 | #include "operators.h" 7 | #include "sqlite3.h" 8 | 9 | // populated later 10 | #define MAX_SCALE 1000000 11 | 12 | int dsR1_pos[2]; 13 | int dsR1_crd[MAX_SCALE + 10]; 14 | int dsR2_pos[MAX_SCALE + 10]; 15 | int dsR2_crd[2 * MAX_SCALE + 10]; 16 | int dsS1_pos[2]; 17 | int dsS1_crd[MAX_SCALE + 10]; 18 | int dsS2_pos[MAX_SCALE + 10]; 19 | int dsS2_crd[2 * MAX_SCALE + 10]; 20 | int dsT1_pos[2]; 21 | int dsT1_crd[MAX_SCALE + 10]; 22 | int dsT2_pos[MAX_SCALE + 10]; 23 | int dsT2_crd[2 * MAX_SCALE + 10]; 24 | 25 | #include "gen_wcoj.c" 26 | 27 | static int populate_wcoj(sqlite3* db) { 28 | char* zErrMsg; 29 | int rc; 30 | void* data = NULL; 31 | 32 | #define GET_TBL2(tbl_name, col1, col2) \ 33 | do { \ 34 | rc = sqlite3_exec( \ 35 | db, "SELECT " #col1 ", " #col2 " FROM " #tbl_name " ORDER BY 1, 2", \ 36 | gen_callback_wcoj_##tbl_name, (void*)data, &zErrMsg); \ 37 | if (rc != SQLITE_OK) { \ 38 | printf("%s:%d: %s\n", __FILE__, __LINE__, zErrMsg); \ 39 | return rc; \ 40 | } \ 41 | } while (false) 42 | 43 | GET_TBL2(R, a, b); 44 | GET_TBL2(S, b, c); 45 | GET_TBL2(T, a, c); 46 | 47 | return rc; 48 | } 49 | 50 | static sqlite3* db; 51 | int res; 52 | 53 | int main(int argc, char* argv[]) { 54 | int rc = SQLITE_OK; 55 | 56 | sqlite3_initialize(); 57 | rc = sqlite3_open(argc >= 1 ? 
argv[1] : "data/pldi.db", &db); 58 | 59 | if (rc) { 60 | fprintf(stderr, "Can't open database: %s\n", sqlite3_errmsg(db)); 61 | return (0); 62 | } else { 63 | fprintf(stderr, "Opened database successfully\n"); 64 | } 65 | 66 | time([]() { return populate_wcoj(db); }, "populate_wcoj", 1); 67 | printf("Loaded\n"); 68 | 69 | time([&]() { 70 | for (int i = 0; i < 1000; ++i) { 71 | res = wcoj(); 72 | } 73 | return res; 74 | }, "wcojx1000", 5); 75 | 76 | sqlite3_close(db); 77 | return 0; 78 | } 79 | -------------------------------------------------------------------------------- /lake-manifest.json: -------------------------------------------------------------------------------- 1 | {"version": 7, 2 | "packagesDir": ".lake/packages", 3 | "packages": 4 | [{"url": "https://github.com/leanprover/std4", 5 | "type": "git", 6 | "subDir": null, 7 | "rev": "e840c18f7334c751efbd4cfe531476e10c943cdb", 8 | "name": "std", 9 | "manifestFile": "lake-manifest.json", 10 | "inputRev": "main", 11 | "inherited": true, 12 | "configFile": "lakefile.lean"}, 13 | {"url": "https://github.com/leanprover-community/quote4", 14 | "type": "git", 15 | "subDir": null, 16 | "rev": "64365c656d5e1bffa127d2a1795f471529ee0178", 17 | "name": "Qq", 18 | "manifestFile": "lake-manifest.json", 19 | "inputRev": "master", 20 | "inherited": true, 21 | "configFile": "lakefile.lean"}, 22 | {"url": "https://github.com/leanprover-community/aesop", 23 | "type": "git", 24 | "subDir": null, 25 | "rev": "5fefb40a7c9038a7150e7edd92e43b1b94c49e79", 26 | "name": "aesop", 27 | "manifestFile": "lake-manifest.json", 28 | "inputRev": "master", 29 | "inherited": true, 30 | "configFile": "lakefile.lean"}, 31 | {"url": "https://github.com/leanprover-community/ProofWidgets4", 32 | "type": "git", 33 | "subDir": null, 34 | "rev": "fb65c476595a453a9b8ffc4a1cea2db3a89b9cd8", 35 | "name": "proofwidgets", 36 | "manifestFile": "lake-manifest.json", 37 | "inputRev": "v0.0.30", 38 | "inherited": true, 39 | "configFile": "lakefile.lean"}, 40 
| {"url": "https://github.com/leanprover/lean4-cli", 41 | "type": "git", 42 | "subDir": null, 43 | "rev": "be8fa79a28b8b6897dce0713ef50e89c4a0f6ef5", 44 | "name": "Cli", 45 | "manifestFile": "lake-manifest.json", 46 | "inputRev": "main", 47 | "inherited": true, 48 | "configFile": "lakefile.lean"}, 49 | {"url": "https://github.com/leanprover-community/import-graph.git", 50 | "type": "git", 51 | "subDir": null, 52 | "rev": "61a79185b6582573d23bf7e17f2137cd49e7e662", 53 | "name": "importGraph", 54 | "manifestFile": "lake-manifest.json", 55 | "inputRev": "main", 56 | "inherited": true, 57 | "configFile": "lakefile.lean"}, 58 | {"url": "https://github.com/leanprover-community/mathlib4/", 59 | "type": "git", 60 | "subDir": null, 61 | "rev": "3897434e80c1e66658416557947b9b9604e336a7", 62 | "name": "mathlib", 63 | "manifestFile": "lake-manifest.json", 64 | "inputRev": "3897434e80c1e66658416557947b9b9604e336a7", 65 | "inherited": false, 66 | "configFile": "lakefile.lean"}], 67 | "name": "etch", 68 | "lakeDir": ".lake"} 69 | -------------------------------------------------------------------------------- /Etch/Add.lean: -------------------------------------------------------------------------------- 1 | import Etch.Stream 2 | 3 | variable {ι : Type} {α : Type _} [Tagged ι] [TaggedC ι] [DecidableEq ι] 4 | [LT ι] [LE ι] [DecidableRel (LT.lt : ι → ι → Prop)] 5 | [DecidableRel (LE.le : ι → ι → _)] 6 | 7 | -- `guard b s` returns a stream which returns `0` (empty stream) if `b` is false 8 | -- and acts identically to `s` if `b` is true. 
9 | class Guard (α : Type _) where 10 | guard : E Bool → α → α 11 | 12 | instance [Tagged α] [OfNat α (nat_lit 0)] : Guard (E α) where 13 | guard b v := .call Op.ternary ![b, v, (0 : E α)] 14 | 15 | instance : Guard (S ι α) where 16 | guard b s := {s with valid := λ l => b * s.valid l} 17 | 18 | instance [Tagged α] [OfNat α (nat_lit 0)] : Guard (ι →ₐ E α) where 19 | guard b s := Guard.guard b ∘ s 20 | 21 | -- Returns an expression which evaluates to `true` iff `a.index' ≤ b.index'` 22 | def S_le (a : S ι α) (b : S ι β) (l : a.σ × b.σ) : E Bool := 23 | (.call Op.neg ![b.valid l.2]) + (a.valid l.1 * (a.index l.1 <= b.index l.2)) 24 | 25 | infixr:40 "≤ₛ" => S_le 26 | 27 | def Prod.symm (f : α × β) := (f.2, f.1) 28 | 29 | -- Local temporary variables for `add` 30 | structure AddTmp (ι : Type) [TaggedC ι] where 31 | (ci : Var ι) 32 | 33 | def AddTmp.ofName (n : Name) : AddTmp ι := 34 | ⟨(Var.mk "ci").fresh n⟩ 35 | 36 | def S.add [HAdd α β γ] [Guard α] [Guard β] (a : S ι α) (b : S ι β) : S ι γ where 37 | σ := (a.σ × b.σ) × AddTmp ι 38 | value := λ (p, _) => 39 | (Guard.guard ((S_le a b p) * a.ready p.1) $ a.value p.1) + 40 | (Guard.guard ((S_le b a p.symm) * b.ready p.2) $ b.value p.2) 41 | skip := λ (p, _) i => a.skip p.1 i ;; b.skip p.2 i 42 | succ := λ (p, t) i => 43 | t.ci.decl i;; 44 | a.succ p.1 t.ci;; b.succ p.2 t.ci 45 | ready := λ (p, _) => (S_le a b p) * a.ready p.1 + (S_le b a p.symm) * b.ready p.2 46 | index := λ (p, _) => .call Op.ternary ![S_le a b p, a.index p.1, b.index p.2] 47 | valid := λ (p, _) => a.valid p.1 + b.valid p.2 48 | init := λ n => let (i, s) := seqInit a b n; (i, (s, .ofName n)) 49 | 50 | instance [Add α] [Guard α] : Add (ι →ₛ α) := ⟨S.add⟩ 51 | instance [HAdd α β γ] [Guard α] [Guard β] : HAdd (S ι α) (S ι β) (S ι γ) := ⟨S.add⟩ 52 | instance [HAdd α β γ] : HAdd (ι →ₐ α) (ι →ₐ β) (ι →ₐ γ) where hAdd a b := λ v => a v + b v 53 | instance [HAdd α β γ] : HAdd (ι →ₛ α) (ι →ₐ β) (ι →ₛ γ) where hAdd a b := {a with value := λ s => a.value s + b 
(a.index s)} 54 | instance [HAdd β α γ] : HAdd (ι →ₐ β) (ι →ₛ α) (ι →ₛ γ) where hAdd b a := {a with value := λ s => b (a.index s) + a.value s} 55 | -------------------------------------------------------------------------------- /Etch/InductiveStreamTest.lean: -------------------------------------------------------------------------------- 1 | import Etch.InductiveStreamCompile 2 | import Etch.InductiveStreamDeriving 3 | 4 | namespace Etch.Stream.Test 5 | 6 | inductive A 7 | | attr1 | attr2 | attr3 8 | deriving Repr, DecidableEq 9 | def A.toTag : A → String 10 | | attr1 => "attr1" 11 | | attr2 => "attr2" 12 | | attr3 => "attr3" 13 | instance A.represented : Represented A := ⟨A.toTag⟩ 14 | open A (attr1 attr2 attr3) 15 | 16 | set_option trace.Elab.Deriving.attr_order_total true in 17 | deriving instance AttrOrderTotal with { order := [.attr1, .attr2, .attr3] } for A 18 | 19 | inductive A' 20 | | attr' : Fin 10 → A' 21 | deriving Repr, DecidableEq, 22 | AttrOrderTotal with { order := [.attr' 0, .attr' 1, .attr' 2, .attr' 3, .attr' 4, 23 | .attr' 5, .attr' 6, .attr' 7, .attr' 8, .attr' 9] } 24 | def A'.toTag : A' → String 25 | | attr' _ => "attr1" 26 | instance A'.represented : Represented A' := ⟨A'.toTag⟩ 27 | open A' (attr') 28 | 29 | variable (i j k : A) 30 | #check (contract' i $ default [i, j]) 31 | #check let a : [j] →ₛ i := (contract' i (default [i, j])); a 32 | 33 | variable 34 | (s₁ : [attr1] →ₛ attr3) 35 | (s₂ : [attr2] →ₛ attr3) 36 | (srev : [attr2, attr1] →ₛ attr3) 37 | 38 | section v1 39 | def attr1Sublist : [attr1].SublistT A.order.val := .check' (by decide) 40 | def attr2Sublist : [attr2].SublistT A.order.val := .check' (by decide) 41 | 42 | #eval (mergeAttr attr1Sublist attr2Sublist).fst.val 43 | 44 | #check (s₂.mulMerge attr2Sublist attr1Sublist s₁) -- `: (mergeAttr attr2Sublist attr1Sublist).fst.val →ₛ attr3` 45 | #check (s₂.mulMerge attr2Sublist attr1Sublist s₁ : [attr1, attr2] →ₛ attr3) 46 | 47 | #check (s₂.mulMerge (.check' (by decide)) (.check' 
(by decide)) s₁) -- works 48 | -- Doesn't work ☹ 49 | -- #check (s₂.mulMerge (.check' (by decide)) (.check' (by decide)) s₁ : [attr1, attr2] →ₛ attr3) 50 | 51 | end v1 52 | 53 | section v2 54 | #eval merge [attr1] [attr2] 55 | #check s₂.mul' s₁ -- `: [attr1, attr2] →ₛ attr3` 56 | 57 | -- #check srev.mul'' s₁ 58 | -- failed to synthesize instance 59 | -- `AttrMerge' A.order.val [attr2, attr1] [attr1] ?m.287939` 60 | 61 | #eval decide (attr' 0 < attr' 1) 62 | #eval decide (attr' 9 < attr' 1) 63 | #eval merge [attr' 0, attr' 9] [attr' 8] 64 | variable 65 | (s₁ : [attr' 0, attr' 3, attr' 4, attr' 5, attr' 6, attr' 7, attr' 9] →ₛ attr' 4) 66 | (s₂ : [attr' 0, attr' 1, attr' 2, attr' 3, attr' 8] →ₛ attr' 4) 67 | -- Notice how fast this is. 68 | #check s₂.mul' s₁ -- `: [attr1, attr2] →ₛ attr3` 69 | 70 | end v2 71 | 72 | end Etch.Stream.Test 73 | -------------------------------------------------------------------------------- /bench-matmul.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "common.h" 6 | #include "operators.h" 7 | #include "sqlite3.h" 8 | 9 | int array_size = 4000000; 10 | 11 | int* ssA1_crd = (int*)calloc(array_size, sizeof(int)); 12 | int* ssA1_pos = (int*)calloc(array_size, sizeof(int)); 13 | int* ssA2_crd = (int*)calloc(array_size, sizeof(int)); 14 | int* ssA2_pos = (int*)calloc(array_size, sizeof(int)); 15 | double* ssA_vals = (double*)calloc(array_size, sizeof(double)); 16 | int* ssB1_crd = (int*)calloc(array_size, sizeof(int)); 17 | int* ssB1_pos = (int*)calloc(array_size, sizeof(int)); 18 | int* ssB2_crd = (int*)calloc(array_size, sizeof(int)); 19 | int* ssB2_pos = (int*)calloc(array_size, sizeof(int)); 20 | double* ssB_vals = (double*)calloc(array_size, sizeof(double)); 21 | 22 | #include "gen_matmul.c" 23 | 24 | static sqlite3* db; 25 | 26 | static int populate_matmul(sqlite3* db) { 27 | char* zErrMsg; 28 | int rc; 29 | void* data = NULL; 30 | 31 | #define 
GET_TBL3(out_name, tbl_name, col1, col2, col3) \ 32 | do { \ 33 | rc = sqlite3_exec(db, \ 34 | "SELECT " #col1 ", " #col2 ", " #col3 " FROM " #tbl_name \ 35 | " ORDER BY 1, 2, 3", \ 36 | gen_##out_name##_callback, (void*)data, &zErrMsg); \ 37 | if (rc != SQLITE_OK) { \ 38 | printf("%s:%d: %s\n", __FILE__, __LINE__, zErrMsg); \ 39 | return rc; \ 40 | } \ 41 | } while (false) 42 | 43 | GET_TBL3(ssA, A, i, j, v); 44 | GET_TBL3(ssB, B, i, j, v); 45 | 46 | return rc; 47 | } 48 | 49 | int main(int argc, char* argv[]) { 50 | int rc = SQLITE_OK; 51 | 52 | sqlite3_initialize(); 53 | rc = sqlite3_open(argc > 1 ? argv[1] : "./data/pldi.db", &db); 54 | 55 | if (rc) { 56 | fprintf(stderr, "Can't open database: %s\n", sqlite3_errmsg(db)); 57 | return 1; 58 | } else { 59 | fprintf(stderr, "Opened database successfully\n"); 60 | } 61 | 62 | time([]() { return populate_matmul(db); }, "populate_matmul", 1); 63 | printf("Loaded\n"); 64 | 65 | time(mul_inner, "mul_inner", 5); 66 | time(mul_rowcb, "mul_rowcb", 5); 67 | 68 | sqlite3_close(db); 69 | return 0; 70 | } 71 | -------------------------------------------------------------------------------- /bench/filtered-spmv-datagen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sqlite3 3 | import sys 4 | from pathlib import Path 5 | 6 | 7 | def makeA(p=0.1, nonzeros=400000): 8 | # 2000 × 2000 × 0.1 = 400000 9 | 10 | n = int(np.round(np.sqrt(nonzeros / p))) 11 | matrix_size = n * n 12 | m = np.random.rand(nonzeros) 13 | m *= (matrix_size - nonzeros) / np.sum(m) 14 | m += 1 15 | 16 | # Extra correction 17 | # print(np.sum(np.round(m))) 18 | m /= np.sum(np.round(m)) / matrix_size 19 | # print(np.sum(np.round(m))) 20 | m /= np.sum(np.round(m)) / matrix_size 21 | # print(np.sum(np.round(m))) 22 | 23 | R = np.random.rand(nonzeros) 24 | 25 | result = set() 26 | last = 0 27 | for idx, r in enumerate(m): 28 | if last >= matrix_size: 29 | break 30 | i, j = last // n, last % n 31 
| result.add((i, j, float(R[idx]))) 32 | last += int(np.round(r)) 33 | print( 34 | f"expected={nonzeros} actual={len(result)} expect_sparsity={p} actual_sparsity={len(result) / matrix_size}" 35 | ) 36 | return result 37 | 38 | 39 | def makeV(p=0.1, nonzeros=200): 40 | # 2000 × 0.1 = 200 41 | 42 | n = int(np.round(nonzeros / p)) 43 | matrix_size = n 44 | m = np.random.rand(nonzeros) 45 | m *= (matrix_size - nonzeros) / np.sum(m) 46 | m += 1 47 | 48 | # Extra correction 49 | # print(np.sum(np.round(m))) 50 | m /= np.sum(np.round(m)) / matrix_size 51 | # print(np.sum(np.round(m))) 52 | m /= np.sum(np.round(m)) / matrix_size 53 | # print(np.sum(np.round(m))) 54 | 55 | R = np.random.rand(nonzeros) 56 | 57 | result = set() 58 | last = 0 59 | for i, r in enumerate(m): 60 | if last >= matrix_size: 61 | break 62 | result.add((last, float(R[i]))) 63 | last += int(np.round(r)) 64 | print( 65 | f"expected={nonzeros} actual={len(result)} expect_sparsity={p} actual_sparsity={len(result) / matrix_size}" 66 | ) 67 | return result 68 | 69 | 70 | def main( 71 | db: Path = Path("data/pldi.db"), 72 | pA: float = 0.0002, 73 | pV: float = 0.1, 74 | nonzeros: int = 20000, 75 | ): 76 | c = sqlite3.connect(str(db)) 77 | c.execute("DROP TABLE IF EXISTS A") 78 | c.execute("DROP TABLE IF EXISTS V") 79 | c.execute("CREATE TABLE A(i INTEGER NOT NULL, j INTEGER NOT NULL, v REAL NOT NULL)") 80 | c.execute("CREATE TABLE V(i INTEGER NOT NULL, v REAL NOT NULL)") 81 | print("A") 82 | c.executemany(f"INSERT INTO A VALUES(?,?,?)", makeA(pA, nonzeros)) 83 | print("V") 84 | v_nonzeros = int(np.sqrt(nonzeros / pA) * pV) 85 | c.executemany(f"INSERT INTO V VALUES(?,?)", makeV(pV, v_nonzeros)) 86 | c.commit() 87 | 88 | 89 | main(Path(sys.argv[1]), float(sys.argv[2]), float(sys.argv[3]), int(sys.argv[4])) 90 | -------------------------------------------------------------------------------- /Etch/Benchmark/Basic.lean: -------------------------------------------------------------------------------- 1 | import 
Etch.C 2 | import Etch.Op 3 | import Etch.Compile 4 | 5 | instance : TaggedC R := ⟨⟨"double"⟩⟩ 6 | 7 | -- For C++ maps. 8 | def Op.indexMap {α γ : Type} [Inhabited γ] : Op (ℕ → γ) where 9 | argTypes := ![α → γ, α] 10 | spec := fun a => fun | 0 => (a 0) (a 1) 11 | | _ => default 12 | opName := "index_map" 13 | def Op.indexMap2 {α β γ : Type} [Inhabited γ] : Op (ℕ → γ) where 14 | argTypes := ![α × β → γ, α, β] 15 | spec := fun a => fun | 0 => (a 0) (a 1, a 2) 16 | | _ => default 17 | opName := "index_map" 18 | 19 | namespace Etch.Benchmark 20 | 21 | section 22 | 23 | variable (I : Type _) [TaggedC I] 24 | (J : Type _) [TaggedC J] 25 | (X : Type _) [TaggedC X] [Inhabited X] [Tagged X] [Zero X] 26 | {Z : Type _} 27 | 28 | local instance (α) : ToString (Var α) := ⟨Var.toString⟩ 29 | local instance : ToString (DeclType) := ⟨fun | .mk s => s⟩ 30 | open TaggedC (tag) 31 | 32 | def Arg := (a : Type) × TaggedC a × Var a 33 | def Arg.mk {a} [inst : TaggedC a] (v : Var a) : Arg := ⟨a, ⟨inst, v⟩⟩ 34 | def Arg.toC : Arg → String 35 | | ⟨a, ⟨_, v⟩⟩ => s!"{tag a} {v}" 36 | 37 | def compileFun [Compile (Var X) Z] (name : String) (exp : Z) (args : List Arg := []) : String := 38 | let val : Var X := "val" 39 | let decl := (val.decl 0).compile.emit.run 40 | let argStr := args.map Arg.toC |> String.intercalate ", " 41 | s!"{tag X} {name}({argStr}) \{\n {decl}\n {go val exp}\n return {val};\n}" 42 | 43 | def compileFunMap [Compile (lvl I (MemLoc X)) Z] (name : String) (exp : Z) : String := 44 | let T := s!"std::unordered_map<{tag I}, {tag X}>" 45 | let out : Var (I → X) := "out"; 46 | let out_loc : Var (ℕ → X) := "out_loc"; 47 | let outVal : lvl I (MemLoc X) := { 48 | push := fun i => 49 | (out_loc.store_var (E.call Op.indexMap ![out.expr, i]), ⟨out_loc, 0⟩) 50 | } 51 | s!"{T} {name}() \{\n {T} {out};\n {tag X}* {out_loc};\n {go outVal exp}\n return out;\n}" 52 | 53 | def compileFunMap2 [Compile (lvl I (lvl J (MemLoc X))) Z] (name : String) (exp : Z) : String := 54 | let tpl := 
s!"std::tuple<{tag I}, {tag J}>" 55 | let T := s!"std::unordered_map<{tpl}, {tag X}, hash_tuple::hash<{tpl}>>" 56 | let out : Var (I × J → X) := "out"; 57 | let out_loc : Var (ℕ → X) := "out_loc"; 58 | let outVal : lvl I (lvl J (MemLoc X)) := { 59 | push := fun i => 60 | (.skip, ⟨fun j => 61 | (out_loc.store_var (E.call Op.indexMap2 ![out.expr, i, j]), 62 | ⟨out_loc, 0⟩)⟩) } 63 | s!"{T} {name}() \{\n {T} {out};\n {tag X}* {out_loc};\n {go outVal exp}\n return out;\n}" 64 | end 65 | 66 | section 67 | 68 | def compileSqliteCb (name : String) (body : List String) : String := 69 | s!"int {name}(void *data, int argc, char **argv, char **azColName) \{\n {String.join body}\n return 0;\n}" 70 | 71 | end 72 | 73 | end Etch.Benchmark 74 | -------------------------------------------------------------------------------- /archive/src/verification/semantics/contract.lean: -------------------------------------------------------------------------------- 1 | import verification.semantics.skip_stream 2 | 3 | /-! 4 | # Contraction of indexed streams 5 | 6 | In this file, we define the contraction of indexed streams `Stream.contract`. 7 | This replaces the indexing axis with `() : Unit`, implicitly summing over the 8 | values of the stream. 
9 | 10 | ## Main results 11 | - `contract_eval`: Correctness for `contract`; evaluating `contract s` results in 12 | the sum of the values of `s` 13 | - `is_lawful (Stream.contract s)`: `s.contract` is lawful assuming `s` is 14 | 15 | -/ 16 | 17 | variables {ι : Type} {α : Type*} 18 | 19 | @[simps] def Stream.contract (s : Stream ι α) : Stream unit α := 20 | { σ := s.σ, 21 | valid := s.valid, 22 | ready := s.ready, 23 | skip := λ q hq i, s.skip q hq (s.index q hq, i.2), 24 | index := default, 25 | value := s.value } 26 | 27 | variables [linear_order ι] 28 | 29 | section index_lemmas 30 | 31 | instance (s : Stream ι α) [is_bounded s] : is_bounded (Stream.contract s) := 32 | ⟨⟨s.wf_rel, s.wf, λ q hq, begin 33 | rintro ⟨⟨⟩, b⟩, 34 | simp only [Stream.contract_skip], 35 | refine (s.wf_valid q hq (s.index q hq, b)).imp_right (and.imp_left _), 36 | simp [Stream.to_order], exact id, 37 | end⟩⟩ 38 | 39 | @[simp] lemma contract_next (s : Stream ι α) (q : s.σ) : (Stream.contract s).next q = s.next q := rfl 40 | 41 | lemma contract_map {β : Type*} (f : α → β) (s : Stream ι α) : 42 | (s.map f).contract = s.contract.map f := rfl 43 | 44 | end index_lemmas 45 | 46 | section value_lemmas 47 | variables [add_comm_monoid α] 48 | 49 | lemma contract_eval₀ (s : Stream ι α) (q : s.σ) (hq : s.valid q) : 50 | (Stream.contract s).eval₀ q hq () = finsupp.sum_range (s.eval₀ q hq) := 51 | by { simp only [Stream.eval₀], dsimp, split_ifs with hr; simp, } 52 | 53 | lemma contract_eval (s : Stream ι α) [is_bounded s] [add_comm_monoid α] (q : s.σ) : 54 | (Stream.contract s).eval q () = finsupp.sum_range (s.eval q) := 55 | begin 56 | refine @well_founded.induction _ (Stream.contract s).wf_rel (Stream.contract s).wf _ q _, 57 | clear q, intros q ih, 58 | by_cases hq : s.valid q, swap, { simp [hq], }, 59 | simp only [s.eval_valid _ hq, (Stream.contract s).eval_valid _ hq, finsupp.coe_add, pi.add_apply, 60 | map_add, ih _ ((Stream.contract s).next_wf q hq)], rw [contract_next, contract_eval₀], 61 | 
end 62 | 63 | lemma contract_mono (s : Stream ι α) : (Stream.contract s).is_monotonic := 64 | λ q hq i, by { rw [Stream.index'_val hq, punit_eq_star ((Stream.contract s).index q hq)], exact bot_le, } 65 | 66 | instance (s : Stream ι α) [is_lawful s] : is_lawful (Stream.contract s) := 67 | { mono := contract_mono s, 68 | skip_spec := λ q hq i j hj, begin 69 | cases j, 70 | obtain rfl : i = ((), ff) := le_bot_iff.mp hj, 71 | simp only [Stream.contract_skip, contract_eval, Stream.eval_skip_eq_of_ff], 72 | end } 73 | 74 | end value_lemmas 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /bench/tpch-q5-duckdb-export-to-csv.sql: -------------------------------------------------------------------------------- 1 | -- Load TPC-H dataset for Query 5 2 | 3 | .echo on 4 | 5 | INSTALL sqlite; 6 | LOAD sqlite; 7 | 8 | SET GLOBAL sqlite_all_varchar=true; 9 | 10 | CREATE TABLE REGION ( R_REGIONKEY INTEGER NOT NULL, 11 | R_NAME CHAR(25) NOT NULL, 12 | PRIMARY KEY (R_REGIONKEY)); 13 | CREATE TABLE NATION ( N_NATIONKEY INTEGER NOT NULL, 14 | N_REGIONKEY INTEGER NOT NULL REFERENCES REGION (R_REGIONKEY), 15 | N_NAME CHAR(25) NOT NULL, 16 | PRIMARY KEY (N_NATIONKEY)); 17 | CREATE TABLE SUPPLIER ( S_SUPPKEY INTEGER NOT NULL, 18 | S_NATIONKEY INTEGER NOT NULL REFERENCES NATION (N_NATIONKEY), 19 | PRIMARY KEY (S_SUPPKEY)); 20 | CREATE TABLE CUSTOMER ( C_CUSTKEY INTEGER NOT NULL, 21 | C_NATIONKEY INTEGER NOT NULL REFERENCES NATION (N_NATIONKEY), 22 | PRIMARY KEY (C_CUSTKEY)); 23 | CREATE TABLE ORDERS ( O_ORDERKEY INTEGER NOT NULL, 24 | O_CUSTKEY INTEGER NOT NULL REFERENCES CUSTOMER (C_CUSTKEY), 25 | O_ORDERDATE DATE NOT NULL, 26 | PRIMARY KEY (O_ORDERKEY)); 27 | CREATE TABLE LINEITEM ( L_ORDERKEY INTEGER NOT NULL REFERENCES ORDERS (O_ORDERKEY), 28 | L_SUPPKEY INTEGER NOT NULL REFERENCES SUPPLIER (S_SUPPKEY), 29 | L_LINENUMBER INTEGER NOT NULL, 30 | L_EXTENDEDPRICE DOUBLE NOT NULL, -- actually DECIMAL(15,2), but etch uses double 31 | 
L_DISCOUNT DOUBLE NOT NULL, 32 | PRIMARY KEY (L_ORDERKEY, L_LINENUMBER)); 33 | 34 | INSERT INTO REGION 35 | SELECT R_REGIONKEY, R_NAME 36 | FROM sqlite_scan('TPC-H.db', 'REGION'); 37 | 38 | INSERT INTO NATION 39 | SELECT N_NATIONKEY, N_REGIONKEY, N_NAME 40 | FROM sqlite_scan('TPC-H.db', 'NATION'); 41 | 42 | INSERT INTO SUPPLIER 43 | SELECT S_SUPPKEY, S_NATIONKEY 44 | FROM sqlite_scan('TPC-H.db', 'SUPPLIER'); 45 | 46 | INSERT INTO CUSTOMER 47 | SELECT C_CUSTKEY, C_NATIONKEY 48 | FROM sqlite_scan('TPC-H.db', 'CUSTOMER'); 49 | 50 | INSERT INTO ORDERS 51 | SELECT O_ORDERKEY, O_CUSTKEY, O_ORDERDATE 52 | FROM sqlite_scan('TPC-H.db', 'ORDERS'); 53 | 54 | INSERT INTO LINEITEM 55 | SELECT L_ORDERKEY, L_SUPPKEY, L_LINENUMBER, L_EXTENDEDPRICE, L_DISCOUNT 56 | FROM sqlite_scan('TPC-H.db', 'LINEITEM'); 57 | 58 | COPY REGION TO 'tpch-csv/region.csv' (HEADER, DELIMITER ','); 59 | COPY NATION TO 'tpch-csv/nation.csv' (HEADER, DELIMITER ','); 60 | COPY SUPPLIER TO 'tpch-csv/supplier.csv' (HEADER, DELIMITER ','); 61 | COPY CUSTOMER TO 'tpch-csv/customer.csv' (HEADER, DELIMITER ','); 62 | COPY ORDERS TO 'tpch-csv/orders.csv' (HEADER, DELIMITER ','); 63 | COPY LINEITEM TO 'tpch-csv/lineitem.csv' (HEADER, DELIMITER ','); 64 | -------------------------------------------------------------------------------- /bench/tpch-q9-duckdb-export-to-csv.sql: -------------------------------------------------------------------------------- 1 | -- Load TPC-H dataset for Query 9 2 | 3 | .echo on 4 | 5 | INSTALL sqlite; 6 | LOAD sqlite; 7 | 8 | SET GLOBAL sqlite_all_varchar=true; 9 | 10 | CREATE TABLE NATION ( N_NATIONKEY INTEGER NOT NULL, 11 | N_NAME CHAR(25) NOT NULL, 12 | PRIMARY KEY (N_NATIONKEY)); 13 | CREATE TABLE PART ( P_PARTKEY INTEGER NOT NULL, 14 | P_NAME VARCHAR(55) NOT NULL, 15 | PRIMARY KEY (P_PARTKEY)); 16 | CREATE TABLE SUPPLIER ( S_SUPPKEY INTEGER NOT NULL, 17 | S_NATIONKEY INTEGER NOT NULL REFERENCES NATION (N_NATIONKEY), 18 | PRIMARY KEY (S_SUPPKEY)); 19 | CREATE TABLE PARTSUPP ( 
PS_PARTKEY INTEGER NOT NULL REFERENCES PART (P_PARTKEY), 20 | PS_SUPPKEY INTEGER NOT NULL REFERENCES SUPPLIER (S_SUPPKEY), 21 | PS_SUPPLYCOST DOUBLE NOT NULL, 22 | PRIMARY KEY (PS_PARTKEY, PS_SUPPKEY)); 23 | CREATE TABLE ORDERS ( O_ORDERKEY INTEGER NOT NULL, 24 | O_ORDERDATE DATE NOT NULL, 25 | PRIMARY KEY (O_ORDERKEY)); 26 | CREATE TABLE LINEITEM ( L_PARTKEY INTEGER NOT NULL REFERENCES PART (P_PARTKEY), 27 | L_SUPPKEY INTEGER NOT NULL REFERENCES SUPPLIER (S_SUPPKEY), 28 | L_ORDERKEY INTEGER NOT NULL REFERENCES ORDERS (O_ORDERKEY), 29 | L_LINENUMBER INTEGER NOT NULL, 30 | L_QUANTITY DOUBLE NOT NULL, 31 | L_EXTENDEDPRICE DOUBLE NOT NULL, 32 | L_DISCOUNT DOUBLE NOT NULL, 33 | PRIMARY KEY (L_ORDERKEY, L_LINENUMBER)); 34 | 35 | 36 | INSERT INTO NATION 37 | SELECT N_NATIONKEY, N_NAME FROM sqlite_scan('TPC-H.db', 'NATION'); 38 | 39 | INSERT INTO PART 40 | SELECT P_PARTKEY, P_NAME FROM sqlite_scan('TPC-H.db', 'PART'); 41 | 42 | INSERT INTO SUPPLIER 43 | SELECT S_SUPPKEY, S_NATIONKEY FROM sqlite_scan('TPC-H.db', 'SUPPLIER'); 44 | 45 | INSERT INTO PARTSUPP 46 | SELECT PS_PARTKEY, PS_SUPPKEY, PS_SUPPLYCOST FROM sqlite_scan('TPC-H.db', 'PARTSUPP'); 47 | 48 | INSERT INTO ORDERS 49 | SELECT O_ORDERKEY, O_ORDERDATE FROM sqlite_scan('TPC-H.db', 'ORDERS'); 50 | 51 | INSERT INTO LINEITEM 52 | SELECT L_PARTKEY, L_SUPPKEY, L_ORDERKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT FROM sqlite_scan('TPC-H.db', 'LINEITEM'); 53 | 54 | COPY NATION TO 'tpch-csv/nation.csv' (HEADER, DELIMITER ','); 55 | COPY PART TO 'tpch-csv/part.csv' (HEADER, DELIMITER ','); 56 | COPY SUPPLIER TO 'tpch-csv/supplier.csv' (HEADER, DELIMITER ','); 57 | COPY PARTSUPP TO 'tpch-csv/partsupp.csv' (HEADER, DELIMITER ','); 58 | COPY ORDERS TO 'tpch-csv/orders.csv' (HEADER, DELIMITER ','); 59 | COPY LINEITEM TO 'tpch-csv/lineitem.csv' (HEADER, DELIMITER ','); 60 | -------------------------------------------------------------------------------- /bench/tpch-q5-sqlite-prep.sql: 
-------------------------------------------------------------------------------- 1 | CREATE TABLE REGION ( R_REGIONKEY INTEGER NOT NULL, 2 | R_NAME CHAR(25) NOT NULL, 3 | PRIMARY KEY (R_REGIONKEY)); 4 | CREATE TABLE NATION ( N_NATIONKEY INTEGER NOT NULL, 5 | N_REGIONKEY INTEGER NOT NULL REFERENCES REGION (R_REGIONKEY), 6 | N_NAME CHAR(25) NOT NULL, 7 | PRIMARY KEY (N_NATIONKEY)); 8 | CREATE TABLE SUPPLIER ( S_SUPPKEY INTEGER NOT NULL, 9 | S_NATIONKEY INTEGER NOT NULL REFERENCES NATION (N_NATIONKEY), 10 | PRIMARY KEY (S_SUPPKEY)); 11 | CREATE TABLE CUSTOMER ( C_CUSTKEY INTEGER NOT NULL, 12 | C_NATIONKEY INTEGER NOT NULL REFERENCES NATION (N_NATIONKEY), 13 | PRIMARY KEY (C_CUSTKEY)); 14 | CREATE TABLE ORDERS ( O_ORDERKEY INTEGER NOT NULL, 15 | O_CUSTKEY INTEGER NOT NULL REFERENCES CUSTOMER (C_CUSTKEY), 16 | O_ORDERDATE DATE NOT NULL, 17 | PRIMARY KEY (O_ORDERKEY)); 18 | CREATE TABLE LINEITEM ( L_ORDERKEY INTEGER NOT NULL REFERENCES SUPPLIER (O_ORDERKEY), 19 | L_SUPPKEY INTEGER NOT NULL REFERENCES SUPPLIER (S_SUPPKEY), 20 | L_LINENUMBER INTEGER NOT NULL, 21 | L_EXTENDEDPRICE DOUBLE NOT NULL, -- actually DECIMAL(15,2), but etch uses double 22 | L_DISCOUNT DOUBLE NOT NULL, 23 | PRIMARY KEY (L_ORDERKEY, L_LINENUMBER)); 24 | 25 | CREATE INDEX REGION_idx_q5 ON REGION(R_REGIONKEY, R_NAME); 26 | CREATE INDEX NATION_idx_q5 ON NATION(N_NATIONKEY, N_REGIONKEY, N_NAME); 27 | CREATE INDEX SUPPLIER_idx_q5 ON SUPPLIER(S_SUPPKEY, S_NATIONKEY); 28 | CREATE INDEX ORDERS_idx_q5 ON ORDERS(O_ORDERKEY, O_ORDERDATE, O_CUSTKEY); 29 | CREATE INDEX CUSTOMER_idx_q5 ON CUSTOMER(C_CUSTKEY, C_NATIONKEY); 30 | CREATE INDEX LINEITEM_idx_q5 ON LINEITEM(L_ORDERKEY, L_SUPPKEY, L_EXTENDEDPRICE, L_DISCOUNT); 31 | 32 | ATTACH DATABASE 'TPC-H.db' AS t; 33 | 34 | INSERT INTO region 35 | SELECT r_regionkey, r_name 36 | FROM t.region 37 | ORDER BY 1, 2; 38 | 39 | INSERT INTO nation 40 | SELECT n_nationkey, n_regionkey, n_name 41 | FROM t.nation 42 | ORDER BY 1, 2, 3; 43 | 44 | INSERT INTO supplier 45 | 
SELECT s_suppkey, s_nationkey 46 | FROM t.supplier 47 | ORDER BY 1, 2; 48 | 49 | INSERT INTO customer 50 | SELECT c_custkey, c_nationkey 51 | FROM t.customer 52 | ORDER BY 1, 2; 53 | 54 | INSERT INTO orders 55 | SELECT o_orderkey, o_custkey, o_orderdate 56 | FROM t.orders 57 | ORDER BY 1, 2, 3; 58 | 59 | INSERT INTO lineitem 60 | SELECT l_orderkey, l_suppkey, l_linenumber, l_extendedprice, l_discount 61 | FROM t.lineitem 62 | ORDER BY 1, 2, 4, 5; 63 | 64 | DETACH DATABASE t; 65 | 66 | SELECT page_count * page_size as size 67 | FROM pragma_page_count(), pragma_page_size(); 68 | -------------------------------------------------------------------------------- /bench/tpch-q9-sqlite-prep.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE NATION ( N_NATIONKEY INTEGER NOT NULL, 2 | N_NAME CHAR(25) NOT NULL, 3 | PRIMARY KEY (N_NATIONKEY)); 4 | CREATE TABLE PART ( P_PARTKEY INTEGER NOT NULL, 5 | P_NAME VARCHAR(55) NOT NULL, 6 | PRIMARY KEY (P_PARTKEY)); 7 | CREATE TABLE SUPPLIER ( S_SUPPKEY INTEGER NOT NULL, 8 | S_NATIONKEY INTEGER NOT NULL REFERENCES NATION (N_NATIONKEY), 9 | PRIMARY KEY (S_SUPPKEY)); 10 | CREATE TABLE PARTSUPP ( PS_PARTKEY INTEGER NOT NULL REFERENCES PART (P_PARTKEY), 11 | PS_SUPPKEY INTEGER NOT NULL REFERENCES SUPPLIER (S_SUPPKEY), 12 | PS_SUPPLYCOST DOUBLE NOT NULL, 13 | PRIMARY KEY (PS_PARTKEY, PS_SUPPKEY)); 14 | CREATE TABLE ORDERS ( O_ORDERKEY INTEGER NOT NULL, 15 | O_ORDERDATE DATE NOT NULL, 16 | PRIMARY KEY (O_ORDERKEY)); 17 | CREATE TABLE LINEITEM ( L_PARTKEY INTEGER NOT NULL REFERENCES PART (P_PARTKEY), 18 | L_SUPPKEY INTEGER NOT NULL REFERENCES SUPPLIER (S_SUPPKEY), 19 | L_ORDERKEY INTEGER NOT NULL REFERENCES ORDERS (O_ORDERKEY), 20 | L_LINENUMBER INTEGER NOT NULL, 21 | L_QUANTITY DOUBLE NOT NULL, 22 | L_EXTENDEDPRICE DOUBLE NOT NULL, 23 | L_DISCOUNT DOUBLE NOT NULL, 24 | PRIMARY KEY (L_ORDERKEY, L_LINENUMBER)); 25 | 26 | CREATE INDEX LINEITEM_idx_q9 ON LINEITEM(L_PARTKEY, L_SUPPKEY, 
L_ORDERKEY, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT); 27 | CREATE INDEX PART_idx_q9 ON PART(P_PARTKEY, P_NAME); 28 | CREATE INDEX ORDERS_idx_q9 ON ORDERS(O_ORDERKEY, O_ORDERDATE); 29 | CREATE INDEX SUPPLIER_idx_q9 ON SUPPLIER(S_SUPPKEY, S_NATIONKEY); 30 | CREATE INDEX PARTSUPP_idx_q9 ON PARTSUPP(PS_PARTKEY, PS_SUPPKEY, PS_SUPPLYCOST); 31 | CREATE INDEX NATION_idx_q9 ON NATION(N_NATIONKEY, N_NAME); 32 | 33 | ATTACH DATABASE 'TPC-H.db' AS t; 34 | 35 | INSERT INTO NATION 36 | SELECT N_NATIONKEY, N_NAME 37 | FROM t.NATION 38 | ORDER BY 1; 39 | 40 | INSERT INTO PART 41 | SELECT P_PARTKEY, P_NAME 42 | FROM t.PART 43 | ORDER BY 1; 44 | 45 | INSERT INTO SUPPLIER 46 | SELECT S_SUPPKEY, S_NATIONKEY 47 | FROM t.SUPPLIER 48 | ORDER BY 1; 49 | 50 | INSERT INTO PARTSUPP 51 | SELECT PS_PARTKEY, PS_SUPPKEY, PS_SUPPLYCOST 52 | FROM t.PARTSUPP 53 | ORDER BY 1, 2; 54 | 55 | INSERT INTO ORDERS 56 | SELECT O_ORDERKEY, O_ORDERDATE 57 | FROM t.ORDERS 58 | ORDER BY 1; 59 | 60 | INSERT INTO LINEITEM 61 | SELECT L_PARTKEY, L_SUPPKEY, L_ORDERKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT 62 | FROM t.LINEITEM 63 | ORDER BY 1, 2, 3, 4, 5, 6, 7; 64 | 65 | DETACH DATABASE t; 66 | 67 | SELECT page_count * page_size as size 68 | FROM pragma_page_count(), pragma_page_size(); 69 | -------------------------------------------------------------------------------- /Etch/StreamFusion/Proofs/Imap.lean: -------------------------------------------------------------------------------- 1 | import Etch.StreamFusion.Proofs.StreamProof 2 | 3 | namespace Etch.Verification.Stream 4 | open Streams 5 | 6 | variable {ι ι' : Type} [LinearOrder ι] [LinearOrder ι'] 7 | 8 | theorem lt_iff_of_eq_compare {a b : ι} {c d : ι'} (h : compareOfLessAndEq a b = compareOfLessAndEq c d) : 9 | a < b ↔ c < d := by 10 | simp only [compareOfLessAndEq] at h 11 | split_ifs at h <;> simp [*] 12 | 13 | theorem eq_iff_of_eq_compare {a b : ι} {c d : ι'} (h : compareOfLessAndEq a b = compareOfLessAndEq c d) : 14 | a = b ↔ c = d := 
by 15 | simp only [compareOfLessAndEq] at h 16 | split_ifs at h with h₁ h₂ <;> aesop 17 | 18 | theorem le_iff_of_eq_compare {a b : ι} {c d : ι'} (h : compareOfLessAndEq a b = compareOfLessAndEq c d) : 19 | a ≤ b ↔ c ≤ d := by 20 | simp [le_iff_lt_or_eq, lt_iff_of_eq_compare h, eq_iff_of_eq_compare h] 21 | 22 | theorem gt_iff_of_eq_compare {a b : ι} {c d : ι'} (h : compareOfLessAndEq a b = compareOfLessAndEq c d) : 23 | b < a ↔ d < c := by 24 | simpa using (le_iff_of_eq_compare h).not 25 | 26 | -- theorem attach_lex_of_eq_compare [LinearOrder γ] (a b : ι) (c : ι') 27 | 28 | theorem IsBounded.imap (s : Stream ι α) [IsBounded s] 29 | (f : ι → ι') (g : ι → ι' → ι) 30 | (hfg : ∀ (i : ι) (j : ι'), compareOfLessAndEq j (f i) = compareOfLessAndEq (g i j) i) : 31 | IsBounded (s.imap_general f g) := by 32 | refine ⟨s.wfRel, ?_⟩ 33 | rintro q ⟨j, b⟩ 34 | refine (s.wf_valid q (g (s.index q) j, b)).imp_right (And.imp_left ?_) 35 | dsimp [toOrder] 36 | convert id 37 | simp only [Prod.Lex.lt_iff', ← lt_iff_of_eq_compare (hfg (s.index q) j), 38 | eq_iff_of_eq_compare (hfg (s.index q) j)] 39 | 40 | theorem IsMonotonic.imap {s : Stream ι α} (hs : s.IsMonotonic) {f : ι → ι'} (g : ι → ι' → ι) (hf : Monotone f) : 41 | IsMonotonic (s.imap_general f g) := by 42 | rw [Stream.isMonotonic_iff] 43 | rintro q ⟨i, b⟩ hq 44 | dsimp at q hq ⊢ 45 | apply hf 46 | apply Stream.isMonotonic_iff.mp hs q 47 | 48 | theorem IsLawful.imap {s : Stream ι α} [AddZeroClass α] [IsLawful s] {f : ι → ι'} {g : ι → ι' → ι} 49 | (hf : Monotone f) (hg : ∀ j, Monotone (g · j)) 50 | (hfg : ∀ (i : ι) (j : ι'), compareOfLessAndEq j (f i) = compareOfLessAndEq (g i j) i) : 51 | IsLawful (s.imap_general f g) := by 52 | haveI : IsBounded (s.imap_general f g) := IsBounded.imap s f g hfg 53 | refine ⟨s.mono.imap g hf, ?_⟩ 54 | rintro q ⟨j₁, b⟩ j₂ hj 55 | dsimp only [imap_general_seek, imap_general_σ, imap_general_valid] 56 | dsimp at q 57 | suffices ∀ i, f i = j₂ → s.eval (s.seek q (g (s.index q) j₁, b)) i = s.eval q i by 
sorry 58 | rintro i rfl 59 | by_cases le : s.index q ≤ i 60 | · apply ‹IsLawful s›.seek_spec 61 | have : (g i j₁, b) ≤ₗ (i, false) := sorry 62 | refine le_trans ?_ this 63 | simp only [Prod.Lex.mk_snd_mono_le_iff] 64 | exact hg _ le 65 | · rw [s.mono.eq_zero_of_lt_index, s.mono.eq_zero_of_lt_index] 66 | · simpa using le 67 | · refine lt_of_lt_of_le ?_ (s.mono q _) 68 | simpa using le 69 | 70 | end Etch.Verification.Stream 71 | -------------------------------------------------------------------------------- /bench/tpch-q9-duckdb-prep.sql: -------------------------------------------------------------------------------- 1 | ---------------- load data into memory 2 | 3 | 4 | CREATE TABLE NATION ( N_NATIONKEY INTEGER NOT NULL, 5 | N_NAME CHAR(25) NOT NULL, 6 | PRIMARY KEY (N_NATIONKEY)); 7 | CREATE TABLE PART ( P_PARTKEY INTEGER NOT NULL, 8 | P_NAME VARCHAR(55) NOT NULL, 9 | PRIMARY KEY (P_PARTKEY)); 10 | CREATE TABLE SUPPLIER ( S_SUPPKEY INTEGER NOT NULL, 11 | S_NATIONKEY INTEGER NOT NULL, 12 | PRIMARY KEY (S_SUPPKEY)); 13 | CREATE TABLE PARTSUPP ( PS_PARTKEY INTEGER NOT NULL, 14 | PS_SUPPKEY INTEGER NOT NULL, 15 | PS_SUPPLYCOST DOUBLE NOT NULL, 16 | PRIMARY KEY (PS_PARTKEY, PS_SUPPKEY)); 17 | CREATE TABLE ORDERS ( O_ORDERKEY INTEGER NOT NULL, 18 | O_ORDERDATE DATE NOT NULL, 19 | PRIMARY KEY (O_ORDERKEY)); 20 | CREATE TABLE LINEITEM ( L_PARTKEY INTEGER NOT NULL, 21 | L_SUPPKEY INTEGER NOT NULL, 22 | L_ORDERKEY INTEGER NOT NULL, 23 | L_LINENUMBER INTEGER NOT NULL, 24 | L_QUANTITY DOUBLE NOT NULL, 25 | L_EXTENDEDPRICE DOUBLE NOT NULL, 26 | L_DISCOUNT DOUBLE NOT NULL, 27 | PRIMARY KEY (L_ORDERKEY, L_LINENUMBER)); 28 | 29 | CREATE INDEX LINEITEM_idx_q9 ON LINEITEM(L_PARTKEY, L_SUPPKEY, L_ORDERKEY, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT); 30 | CREATE INDEX PART_idx_q9 ON PART(P_PARTKEY, P_NAME); 31 | CREATE INDEX ORDERS_idx_q9 ON ORDERS(O_ORDERKEY, O_ORDERDATE); 32 | CREATE INDEX SUPPLIER_idx_q9 ON SUPPLIER(S_SUPPKEY, S_NATIONKEY); 33 | CREATE INDEX PARTSUPP_idx_q9 ON 
PARTSUPP(PS_PARTKEY, PS_SUPPKEY, PS_SUPPLYCOST); 34 | CREATE INDEX NATION_idx_q9 ON NATION(N_NATIONKEY, N_NAME); 35 | 36 | INSERT INTO NATION 37 | SELECT N_NATIONKEY, N_NAME 38 | FROM read_csv_auto('tpch-csv/nation.csv', delim=',', header=True) 39 | ORDER BY 1, 2; 40 | 41 | INSERT INTO PART 42 | SELECT P_PARTKEY, P_NAME 43 | FROM read_csv_auto('tpch-csv/part.csv', delim=',', header=True) 44 | ORDER BY 1, 2; 45 | 46 | INSERT INTO SUPPLIER 47 | SELECT S_SUPPKEY, S_NATIONKEY 48 | FROM read_csv_auto('tpch-csv/supplier.csv', delim=',', header=True) 49 | ORDER BY 1, 2; 50 | 51 | INSERT INTO PARTSUPP 52 | SELECT PS_PARTKEY, PS_SUPPKEY, PS_SUPPLYCOST 53 | FROM read_csv_auto('tpch-csv/partsupp.csv', delim=',', header=True) 54 | ORDER BY 1, 2, 3; 55 | 56 | INSERT INTO ORDERS 57 | SELECT O_ORDERKEY, O_ORDERDATE 58 | FROM read_csv_auto('tpch-csv/orders.csv', delim=',', header=True) 59 | ORDER BY 1, 2; 60 | 61 | INSERT INTO LINEITEM 62 | SELECT L_PARTKEY, L_SUPPKEY, L_ORDERKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT 63 | FROM read_csv_auto('tpch-csv/lineitem.csv', delim=',', header=True) 64 | ORDER BY 1, 2, 3, 5, 6, 7; 65 | 66 | PRAGMA database_size; 67 | 68 | PRAGMA threads=1; 69 | -------------------------------------------------------------------------------- /bench/tpch-q5-duckdb-prep.sql: -------------------------------------------------------------------------------- 1 | ---------------- load data into memory 2 | 3 | 4 | CREATE TABLE REGION ( R_REGIONKEY INTEGER NOT NULL, 5 | R_NAME CHAR(25) NOT NULL, 6 | PRIMARY KEY (R_REGIONKEY)); 7 | CREATE TABLE NATION ( N_NATIONKEY INTEGER NOT NULL, 8 | N_REGIONKEY INTEGER NOT NULL, 9 | N_NAME CHAR(25) NOT NULL, 10 | PRIMARY KEY (N_NATIONKEY)); 11 | CREATE TABLE SUPPLIER ( S_SUPPKEY INTEGER NOT NULL, 12 | S_NATIONKEY INTEGER NOT NULL, 13 | PRIMARY KEY (S_SUPPKEY)); 14 | CREATE TABLE CUSTOMER ( C_CUSTKEY INTEGER NOT NULL, 15 | C_NATIONKEY INTEGER NOT NULL, 16 | PRIMARY KEY (C_CUSTKEY)); 17 | CREATE TABLE ORDERS ( 
O_ORDERKEY INTEGER NOT NULL, 18 | O_CUSTKEY INTEGER NOT NULL, 19 | O_ORDERDATE DATE NOT NULL, 20 | PRIMARY KEY (O_ORDERKEY)); 21 | CREATE TABLE LINEITEM ( L_ORDERKEY INTEGER NOT NULL, 22 | L_SUPPKEY INTEGER NOT NULL, 23 | L_LINENUMBER INTEGER NOT NULL, 24 | L_EXTENDEDPRICE DOUBLE NOT NULL, -- actually DECIMAL(15,2), but etch uses double 25 | L_DISCOUNT DOUBLE NOT NULL, 26 | PRIMARY KEY (L_ORDERKEY, L_LINENUMBER)); 27 | 28 | -- Note: adding these indices (and also the ORDER BY in INSERT) don't help DuckDB, but we do it out of fairness. 29 | 30 | CREATE INDEX REGION_idx_q5 ON REGION(R_REGIONKEY, R_NAME); 31 | CREATE INDEX NATION_idx_q5 ON NATION(N_NATIONKEY, N_REGIONKEY, N_NAME); 32 | CREATE INDEX SUPPLIER_idx_q5 ON SUPPLIER(S_SUPPKEY, S_NATIONKEY); 33 | CREATE INDEX ORDERS_idx_q5 ON ORDERS(O_ORDERKEY, O_ORDERDATE, O_CUSTKEY); 34 | CREATE INDEX CUSTOMER_idx_q5 ON CUSTOMER(C_CUSTKEY, C_NATIONKEY); 35 | CREATE INDEX LINEITEM_idx_q5 ON LINEITEM(L_ORDERKEY, L_SUPPKEY, L_EXTENDEDPRICE, L_DISCOUNT); 36 | 37 | INSERT INTO REGION 38 | SELECT R_REGIONKEY, R_NAME 39 | FROM read_csv_auto('tpch-csv/region.csv', delim=',', header=True) 40 | ORDER BY 1, 2; 41 | 42 | INSERT INTO NATION 43 | SELECT N_NATIONKEY, N_REGIONKEY, N_NAME 44 | FROM read_csv_auto('tpch-csv/nation.csv', delim=',', header=True) 45 | ORDER BY 1, 2, 3; 46 | 47 | INSERT INTO SUPPLIER 48 | SELECT S_SUPPKEY, S_NATIONKEY 49 | FROM read_csv_auto('tpch-csv/supplier.csv', delim=',', header=True) 50 | ORDER BY 1, 2; 51 | 52 | INSERT INTO CUSTOMER 53 | SELECT C_CUSTKEY, C_NATIONKEY 54 | FROM read_csv_auto('tpch-csv/customer.csv', delim=',', header=True) 55 | ORDER BY 1, 2; 56 | 57 | INSERT INTO ORDERS 58 | SELECT O_ORDERKEY, O_CUSTKEY, O_ORDERDATE 59 | FROM read_csv_auto('tpch-csv/orders.csv', delim=',', header=True) 60 | ORDER BY 1, 2, 3; 61 | 62 | INSERT INTO LINEITEM 63 | SELECT L_ORDERKEY, L_SUPPKEY, L_LINENUMBER, L_EXTENDEDPRICE, L_DISCOUNT 64 | FROM read_csv_auto('tpch-csv/lineitem.csv', delim=',', header=True) 65 
| ORDER BY 1, 2, 3, 4, 5; 66 | 67 | PRAGMA database_size; 68 | 69 | PRAGMA threads=1; 70 | -------------------------------------------------------------------------------- /bench-duckdb.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "common.h" 11 | #include "duckdb/duckdb.hpp" 12 | 13 | namespace { 14 | 15 | namespace d = duckdb; 16 | 17 | std::string_view DB_PREFIX = "db:"; 18 | 19 | std::pair ParseArg(std::string_view s) { 20 | int reps = -1; 21 | bool found = false; 22 | 23 | for (int i = 0; i < s.size(); ++i) { 24 | if ('0' <= s[i] && s[i] <= '9') { 25 | found = true; 26 | } else if (s[i] == ':' && found) { 27 | reps = std::stoul(std::string(s.substr(0, i))); 28 | s.remove_prefix(i + 1); 29 | break; 30 | } else { 31 | break; 32 | } 33 | } 34 | 35 | return {reps, s}; 36 | } 37 | 38 | } // namespace 39 | 40 | int main(int argc, char* argv[]) { 41 | std::unique_ptr db_ptr; 42 | 43 | int argi = 1; 44 | if (argi < argc && std::string_view(argv[argi]).starts_with(DB_PREFIX)) { 45 | std::string_view db_file(argv[argi]); 46 | db_file.remove_prefix(DB_PREFIX.size()); 47 | db_ptr = std::make_unique(std::string(db_file)); ++argi;  // fixed: consume the "db:" argument; without this the loop below re-reads it as a query file 48 | } else { 49 | db_ptr = std::make_unique(); 50 | } 51 | 52 | d::DuckDB& db = *db_ptr; 53 | d::Connection con(db); 54 | 55 | int query_idx = 1; 56 | 57 | for (; argi < argc; ++argi) { 58 | auto [reps, file] = ParseArg(argv[argi]); 59 | 60 | std::ifstream f((std::string(file))); 61 | std::stringstream file_ss; 62 | file_ss << f.rdbuf(); 63 | std::string sql_str = std::move(file_ss).str(); 64 | 65 | std::stringstream str_ss; 66 | str_ss << "q" << query_idx; 67 | std::string str_i = std::move(str_ss).str(); 68 | 69 | if (reps > 0) { 70 | std::unique_ptr prepare; 71 | time( 72 | [&]() { 73 | prepare = con.Prepare(sql_str); 74 | return 0; 75 | }, 76 | (str_i + " prep").c_str(), 1); 77 | 78 | if
(prepare->HasError()) { 79 | std::cout << "Preparing " << file << " failed: " << prepare->GetError() 80 | << '\n'; 81 | return 1; 82 | } 83 | 84 | std::unique_ptr res; 85 | std::vector values; 86 | res = prepare->Execute(values, /*allow_stream_result=*/false); 87 | res = prepare->Execute(values, /*allow_stream_result=*/false); 88 | 89 | time( 90 | [&]() { 91 | res = prepare->Execute(values, 92 | /*allow_stream_result=*/false); 93 | return 0; 94 | }, 95 | str_i.c_str(), reps); 96 | std::cout << res->ToString() << '\n'; 97 | } else { 98 | std::unique_ptr res; 99 | time( 100 | [&]() { 101 | res = con.Query(sql_str); 102 | return 0; 103 | }, 104 | str_i.c_str(), 1); 105 | std::cout << res->ToString() << '\n'; 106 | } 107 | 108 | ++query_idx; 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /bench/tpch-q5-duckdb-prep-foreign-key.sql: -------------------------------------------------------------------------------- 1 | ---------------- load data into memory 2 | 3 | 4 | CREATE TABLE REGION ( R_REGIONKEY INTEGER NOT NULL, 5 | R_NAME CHAR(25) NOT NULL, 6 | PRIMARY KEY (R_REGIONKEY)); 7 | CREATE TABLE NATION ( N_NATIONKEY INTEGER NOT NULL, 8 | N_REGIONKEY INTEGER NOT NULL REFERENCES REGION (R_REGIONKEY), 9 | N_NAME CHAR(25) NOT NULL, 10 | PRIMARY KEY (N_NATIONKEY)); 11 | CREATE TABLE SUPPLIER ( S_SUPPKEY INTEGER NOT NULL, 12 | S_NATIONKEY INTEGER NOT NULL REFERENCES NATION (N_NATIONKEY), 13 | PRIMARY KEY (S_SUPPKEY)); 14 | CREATE TABLE CUSTOMER ( C_CUSTKEY INTEGER NOT NULL, 15 | C_NATIONKEY INTEGER NOT NULL REFERENCES NATION (N_NATIONKEY), 16 | PRIMARY KEY (C_CUSTKEY)); 17 | CREATE TABLE ORDERS ( O_ORDERKEY INTEGER NOT NULL, 18 | O_CUSTKEY INTEGER NOT NULL REFERENCES CUSTOMER (C_CUSTKEY), 19 | O_ORDERDATE DATE NOT NULL, 20 | PRIMARY KEY (O_ORDERKEY)); 21 | CREATE TABLE LINEITEM ( L_ORDERKEY INTEGER NOT NULL REFERENCES ORDERS (O_ORDERKEY), 22 | L_SUPPKEY INTEGER NOT NULL REFERENCES SUPPLIER (S_SUPPKEY), 23 | L_LINENUMBER 
INTEGER NOT NULL, 24 | L_EXTENDEDPRICE DOUBLE NOT NULL, -- actually DECIMAL(15,2), but etch uses double 25 | L_DISCOUNT DOUBLE NOT NULL, 26 | PRIMARY KEY (L_ORDERKEY, L_LINENUMBER)); 27 | 28 | -- CREATE INDEX REGION_idx_q5 ON REGION(R_REGIONKEY, R_NAME); 29 | -- CREATE INDEX NATION_idx_q5 ON NATION(N_NATIONKEY, N_REGIONKEY, N_NAME); 30 | -- CREATE INDEX SUPPLIER_idx_q5 ON SUPPLIER(S_SUPPKEY, S_NATIONKEY); 31 | -- CREATE INDEX ORDERS_idx_q5 ON ORDERS(O_ORDERKEY, O_ORDERDATE, O_CUSTKEY); 32 | -- CREATE INDEX CUSTOMER_idx_q5 ON CUSTOMER(C_CUSTKEY, C_NATIONKEY); 33 | -- CREATE INDEX LINEITEM_idx_q5 ON LINEITEM(L_ORDERKEY, L_SUPPKEY, L_EXTENDEDPRICE, L_DISCOUNT); 34 | 35 | INSERT INTO REGION 36 | SELECT R_REGIONKEY, R_NAME 37 | FROM read_csv_auto('tpch-csv/region.csv', delim=',', header=True) 38 | ORDER BY 1, 2; 39 | 40 | INSERT INTO NATION 41 | SELECT N_NATIONKEY, N_REGIONKEY, N_NAME 42 | FROM read_csv_auto('tpch-csv/nation.csv', delim=',', header=True) 43 | ORDER BY 1, 2, 3; 44 | 45 | INSERT INTO SUPPLIER 46 | SELECT S_SUPPKEY, S_NATIONKEY 47 | FROM read_csv_auto('tpch-csv/supplier.csv', delim=',', header=True) 48 | ORDER BY 1, 2; 49 | 50 | INSERT INTO CUSTOMER 51 | SELECT C_CUSTKEY, C_NATIONKEY 52 | FROM read_csv_auto('tpch-csv/customer.csv', delim=',', header=True) 53 | ORDER BY 1, 2; 54 | 55 | INSERT INTO ORDERS 56 | SELECT O_ORDERKEY, O_CUSTKEY, O_ORDERDATE 57 | FROM read_csv_auto('tpch-csv/orders.csv', delim=',', header=True) 58 | ORDER BY 1, 2, 3; 59 | 60 | INSERT INTO LINEITEM 61 | SELECT L_ORDERKEY, L_SUPPKEY, L_LINENUMBER, L_EXTENDEDPRICE, L_DISCOUNT 62 | FROM read_csv_auto('tpch-csv/lineitem.csv', delim=',', header=True) 63 | ORDER BY 1, 2, 3, 4, 5; 64 | 65 | PRAGMA database_size; 66 | 67 | PRAGMA threads=1; 68 | -------------------------------------------------------------------------------- /Etch/Util/Labels.lean: -------------------------------------------------------------------------------- 1 | import Std 2 | 3 | /-! 
4 | Defines a type that provides endless labels between any two labels. 5 | -/ 6 | 7 | namespace Etch 8 | 9 | /- 10 | -- Implementation with standard Nat ordering, for testing 11 | 12 | structure LabelIdx where 13 | data : Nat 14 | deriving BEq, Inhabited 15 | 16 | instance : Ord LabelIdx := ⟨fun x y => compare x.data y.data⟩ 17 | instance : LT LabelIdx := ⟨fun x y => x.data < y.data⟩ 18 | instance (x y : LabelIdx) : Decidable (x < y) := inferInstanceAs <| Decidable (x.data < y.data) 19 | 20 | def LabelIdx.nth (n : Nat) : LabelIdx := {data := n} 21 | -/ 22 | 23 | /-- 24 | A label is a natural number with the "binary-revlex ordering". 25 | 26 | The ordering is given by writing the binary representation with the least-significant bit first 27 | as an infinite sequence and then doing lex ordering. 28 | The number `0` is least in this order. 29 | -/ 30 | structure LabelIdx where 31 | data : Nat 32 | deriving BEq 33 | 34 | def LabelIdx.bitsAux (x : Nat) : List Nat := 35 | if h : x = 0 then 36 | [] 37 | else 38 | have : x / 2 < x := by omega 39 | (x % 2) :: LabelIdx.bitsAux (x / 2) 40 | 41 | instance : Repr LabelIdx where 42 | reprPrec x _ := 43 | let bits := LabelIdx.bitsAux x.data 44 | "LabelIdx(" ++ Lean.Format.joinSep bits ", " ++ ")" 45 | 46 | def LabelIdx.compareAux (x y : Nat) : Ordering := 47 | if h : x = y then .eq 48 | else 49 | have : x / 2 + y / 2 < x + y := by 50 | cases x <;> cases y <;> omega 51 | match compare (x % 2) (y % 2) with 52 | | .eq => compareAux (x / 2) (y / 2) 53 | | .lt => .lt 54 | | .gt => .gt 55 | termination_by x + y 56 | 57 | instance : Ord LabelIdx := ⟨fun x y => LabelIdx.compareAux x.data y.data⟩ 58 | instance : LT LabelIdx := ⟨fun x y => compare x y == .lt⟩ 59 | 60 | /-- Finds the first `0` with no `1`s after and sets it to `1`. 61 | This is injective. 
-/ 62 | def LabelIdx.freshAfterAux (x : Nat) : Nat := 63 | if h : x = 0 then 64 | 1 65 | else 66 | have : x / 2 < x := by omega 67 | (x % 2) + 2 * LabelIdx.freshAfterAux (x / 2) 68 | 69 | def LabelIdx.freshAfter (x : LabelIdx) : LabelIdx := LabelIdx.mk (freshAfterAux x.data) 70 | 71 | def LabelIdx.freshBeforeAux (x y : Nat) : Nat := 72 | if h : y = 0 then 73 | panic! "x < y not true in freshBeforeAux" 74 | else 75 | let xb := x % 2 76 | let yb := y % 2 77 | if xb = yb then 78 | have : y / 2 < y := by omega 79 | xb + 2 * freshBeforeAux (x / 2) (y / 2) 80 | else if xb < yb then 81 | -- x's bit is 0, y's is 1: a leading 0 keeps the result below y, and any tail revlex-above x's tail keeps it above x. Fixed: recurse on x / 2; the old `xb / 2` is always 0, so the result was the constant 2, which can fall below x (e.g. x = 6). 2 * freshAfterAux (x / 2) 82 | else 83 | panic! "x < y not true in freshBeforeAux" 84 | termination_by y 85 | 86 | /-- Gives a label that is between `x` and `y`, assuming that `x < y`. -/ 87 | def LabelIdx.freshBefore (x y : LabelIdx) : LabelIdx := 88 | LabelIdx.mk <| LabelIdx.freshBeforeAux x.data y.data 89 | 90 | /-- An injection from `Nat` into LabelIdxs. -/ 91 | def LabelIdx.nth (n : Nat) : LabelIdx := 92 | LabelIdx.mk <| 2 ^ (n + 1) - 1 93 | 94 | instance : Inhabited LabelIdx := ⟨LabelIdx.nth 0⟩ 95 | 96 | /- 97 | #eval List.range 10 |>.map LabelIdx.nth 98 | #eval List.range 10 |>.map fun n => LabelIdx.nextBefore (LabelIdx.nth n) (LabelIdx.nth (n + 1)) 99 | -/ 100 | -------------------------------------------------------------------------------- /bench/tpch-q9-duckdb-prep-foreign-key.sql: -------------------------------------------------------------------------------- 1 | ---------------- load data into memory 2 | 3 | 4 | CREATE TABLE NATION ( N_NATIONKEY INTEGER NOT NULL, 5 | N_NAME CHAR(25) NOT NULL, 6 | PRIMARY KEY (N_NATIONKEY)); 7 | CREATE TABLE PART ( P_PARTKEY INTEGER NOT NULL, 8 | P_NAME VARCHAR(55) NOT NULL, 9 | PRIMARY KEY (P_PARTKEY)); 10 | CREATE TABLE SUPPLIER ( S_SUPPKEY INTEGER NOT NULL, 11 | S_NATIONKEY INTEGER NOT NULL REFERENCES NATION (N_NATIONKEY), 12 | PRIMARY KEY (S_SUPPKEY)); 13 | CREATE TABLE PARTSUPP ( PS_PARTKEY INTEGER NOT NULL REFERENCES PART (P_PARTKEY),
14 | PS_SUPPKEY INTEGER NOT NULL REFERENCES SUPPLIER (S_SUPPKEY), 15 | PS_SUPPLYCOST DOUBLE NOT NULL, 16 | PRIMARY KEY (PS_PARTKEY, PS_SUPPKEY)); 17 | CREATE TABLE ORDERS ( O_ORDERKEY INTEGER NOT NULL, 18 | O_ORDERDATE DATE NOT NULL, 19 | PRIMARY KEY (O_ORDERKEY)); 20 | CREATE TABLE LINEITEM ( L_PARTKEY INTEGER NOT NULL REFERENCES PART (P_PARTKEY), 21 | L_SUPPKEY INTEGER NOT NULL REFERENCES SUPPLIER (S_SUPPKEY), 22 | L_ORDERKEY INTEGER NOT NULL REFERENCES ORDERS (O_ORDERKEY), 23 | L_LINENUMBER INTEGER NOT NULL, 24 | L_QUANTITY DOUBLE NOT NULL, 25 | L_EXTENDEDPRICE DOUBLE NOT NULL, 26 | L_DISCOUNT DOUBLE NOT NULL, 27 | PRIMARY KEY (L_ORDERKEY, L_LINENUMBER)); 28 | 29 | -- Indices don't help 30 | -- CREATE INDEX LINEITEM_idx_q9 ON LINEITEM(L_PARTKEY, L_SUPPKEY, L_ORDERKEY, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT); 31 | -- CREATE INDEX PART_idx_q9 ON PART(P_PARTKEY, P_NAME); 32 | -- CREATE INDEX ORDERS_idx_q9 ON ORDERS(O_ORDERKEY, O_ORDERDATE); 33 | -- CREATE INDEX SUPPLIER_idx_q9 ON SUPPLIER(S_SUPPKEY, S_NATIONKEY); 34 | -- CREATE INDEX PARTSUPP_idx_q9 ON PARTSUPP(PS_PARTKEY, PS_SUPPKEY, PS_SUPPLYCOST); 35 | -- CREATE INDEX NATION_idx_q9 ON NATION(N_NATIONKEY, N_NAME); 36 | 37 | INSERT INTO NATION 38 | SELECT N_NATIONKEY, N_NAME 39 | FROM read_csv_auto('tpch-csv/nation.csv', delim=',', header=True) 40 | ORDER BY 1, 2; 41 | 42 | INSERT INTO PART 43 | SELECT P_PARTKEY, P_NAME 44 | FROM read_csv_auto('tpch-csv/part.csv', delim=',', header=True) 45 | ORDER BY 1, 2; 46 | 47 | INSERT INTO SUPPLIER 48 | SELECT S_SUPPKEY, S_NATIONKEY 49 | FROM read_csv_auto('tpch-csv/supplier.csv', delim=',', header=True) 50 | ORDER BY 1, 2; 51 | 52 | INSERT INTO PARTSUPP 53 | SELECT PS_PARTKEY, PS_SUPPKEY, PS_SUPPLYCOST 54 | FROM read_csv_auto('tpch-csv/partsupp.csv', delim=',', header=True) 55 | ORDER BY 1, 2, 3; 56 | 57 | INSERT INTO ORDERS 58 | SELECT O_ORDERKEY, O_ORDERDATE 59 | FROM read_csv_auto('tpch-csv/orders.csv', delim=',', header=True) 60 | ORDER BY 1, 2; 61 | 62 | INSERT INTO 
LINEITEM 63 | SELECT L_PARTKEY, L_SUPPKEY, L_ORDERKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT 64 | FROM read_csv_auto('tpch-csv/lineitem.csv', delim=',', header=True) 65 | ORDER BY 1, 2, 3, 5, 6, 7; 66 | 67 | PRAGMA database_size; 68 | 69 | PRAGMA threads=1; 70 | -------------------------------------------------------------------------------- /bench-filtered-spmv.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "common.h" 8 | #include "operators.h" 9 | #include "sqlite3.h" 10 | 11 | int array_size = 4000000; 12 | 13 | int* ssA1_pos = (int*)calloc(2, sizeof(int)); 14 | int* ssA1_crd = (int*)calloc(array_size, sizeof(int)); 15 | int* ssA2_pos = (int*)calloc(array_size, sizeof(int)); 16 | int* ssA2_crd = (int*)calloc(array_size, sizeof(int)); 17 | double* ssA_vals = (double*)calloc(array_size, sizeof(double)); 18 | 19 | int* sV1_pos = (int*)calloc(2, sizeof(int)); 20 | int* sV1_crd = (int*)calloc(array_size, sizeof(int)); 21 | double* sV_vals = (double*)calloc(array_size, sizeof(double)); 22 | 23 | #include "gen_filtered_spmv.c" 24 | 25 | static sqlite3* db; 26 | 27 | static int populate_filtered_spmv(sqlite3* db) { 28 | char* zErrMsg; 29 | int rc; 30 | void* data = NULL; 31 | 32 | #define GET_TBL2(out_name, tbl_name, ...) \ 33 | do { \ 34 | rc = sqlite3_exec( \ 35 | db, "SELECT " #__VA_ARGS__ " FROM " #tbl_name " ORDER BY 1, 2", \ 36 | gen_##out_name##_callback, (void*)data, &zErrMsg); \ 37 | if (rc != SQLITE_OK) { \ 38 | printf("%s:%d: %s\n", __FILE__, __LINE__, zErrMsg); \ 39 | return rc; \ 40 | } \ 41 | } while (false) 42 | 43 | #define GET_TBL3(out_name, tbl_name, ...) 
\ 44 | do { \ 45 | rc = sqlite3_exec( \ 46 | db, "SELECT " #__VA_ARGS__ " FROM " #tbl_name " ORDER BY 1, 2, 3", \ 47 | gen_##out_name##_callback, (void*)data, &zErrMsg); \ 48 | if (rc != SQLITE_OK) { \ 49 | printf("%s:%d: %s\n", __FILE__, __LINE__, zErrMsg); \ 50 | return rc; \ 51 | } \ 52 | } while (false) 53 | 54 | GET_TBL3(ssA, A, i, j, v); 55 | GET_TBL2(sV, V, i, v); 56 | 57 | return rc; 58 | } 59 | 60 | int main(int argc, char* argv[]) { 61 | int rc = SQLITE_OK; 62 | 63 | sqlite3_initialize(); 64 | rc = sqlite3_open(argc > 1 ? argv[1] : "./data/pldi.db", &db); 65 | 66 | if (rc) { 67 | fprintf(stderr, "Can't open database: %s\n", sqlite3_errmsg(db)); 68 | return 1; 69 | } else { 70 | fprintf(stderr, "Opened database successfully\n"); 71 | } 72 | 73 | time([&rc]() { return rc = populate_filtered_spmv(db); }, 74 | "populate_filtered_spmv", 1); 75 | if (rc != SQLITE_OK) { 76 | return 1; 77 | } 78 | printf("Loaded\n"); 79 | 80 | for (double threshold : 81 | std::initializer_list{0.0, 0.2, 0.4, 0.6, 0.8, 1.0}) { 82 | time([threshold]() { return filter_spmv(threshold); }, 83 | ("filter_spmv_" + std::to_string(threshold)).c_str(), 50); 84 | } 85 | 86 | sqlite3_close(db); 87 | return 0; 88 | } 89 | -------------------------------------------------------------------------------- /Etch/StreamFusion/ExpandSeq.lean: -------------------------------------------------------------------------------- 1 | /- TODO: make "indexed functor" class and generalize Expand; remove this file -/ 2 | /- This is largely duplicated from Expand! Please ensure any changes stay in sync. 
-/ 3 | import Etch.StreamFusion.SequentialStream 4 | import Etch.Util.ExpressionTree 5 | 6 | open Etch.ExpressionTree 7 | 8 | namespace Etch.Verification.SequentialStream 9 | 10 | section 11 | 12 | local infixr:25 " →ₛ " => SequentialStream 13 | 14 | -- todo: decide on a nicer notation 15 | local notation n:30 "~" i:30 => LabeledIndex n i 16 | 17 | variable (i : LabelIdx) (ι : Type) 18 | @[inline] instance [LinearOrder ι] : LinearOrder (i~ι) := by change LinearOrder ι; exact inferInstance 19 | @[inline] instance [Inhabited ι] : Inhabited (i~ι) := by change Inhabited ι; exact inferInstance 20 | 21 | instance : TypeHasIndex (i~ι →ₛ β) i ι β where 22 | instance : TypeHasIndex (i~ι → β) i ι β where 23 | 24 | instance [Scalar α] : Label [] α α := ⟨id⟩ 25 | instance [Label is α β] : Label (i::is) (ι →ₛ α) (i~ι →ₛ β) := ⟨map (Label.label is)⟩ 26 | instance [Label is α β] : Label (i::is) (ι → α) (i~ι → β) := ⟨(Label.label is ∘ .)⟩ 27 | instance [Label is α β] : Label (i::is) (i'~ι →ₛ α) (i~ι →ₛ β) := ⟨map (Label.label is)⟩ 28 | instance [Label is α β] : Label (i::is) (i'~ι → α) (i~ι → β) := ⟨(Label.label is ∘ .)⟩ 29 | 30 | def idx (x : α) (shape : List LabelIdx) [Label shape α β] := Label.label shape x 31 | 32 | instance (I : Type) : MapIndex i α β (i~I →ₛ α) (i~I →ₛ β)where 33 | map f s := s.map f 34 | 35 | instance (J : Type) [IdxLt j i] [MapIndex i a b a' b'] : MapIndex i a b (j~J →ₛ a') (j~J →ₛ b') where 36 | map f s := s.map (MapIndex.map i f) 37 | 38 | instance : Contract i (i~ι →ₛ α) (i~Unit →ₛ α) := ⟨fun s => contract s⟩ 39 | instance [Contract j α β] [IdxLt i j] : Contract j (i~ι →ₛ α) (i~ι →ₛ β) := ⟨map (Contract.contract j)⟩ 40 | instance [Contract j α β] : Contract j (Unit →ₛ α) (Unit →ₛ β) := ⟨map (Contract.contract j)⟩ 41 | 42 | --notation "Σ " j ", " t => Contract.contract j t 43 | --notation "Σ " j ": " t => Contract.contract j t 44 | 45 | variable {σ : List (LabelIdx × Type)} 46 | 47 | section 48 | variable {α β : Type*} 49 | instance expBase : Expand [] α 
α := ⟨id⟩ 50 | instance expScalar {ι : Type} {i : LabelIdx} [Scalar α] [Expand σ α β] : Expand ((i,ι) :: σ) α (i~ι → β) := ⟨fun v _ => Expand.expand σ v⟩ 51 | instance expLt {ι : Type} {i j : LabelIdx} [IdxLt i j] [Expand σ (j~ι' →ₛ α) β] : Expand ((i,ι) :: σ) (j~ι' →ₛ α) (i~ι → β) := ⟨fun v _ => Expand.expand σ v⟩ 52 | instance expGt {ι : Type} {i j : LabelIdx} [IdxLt j i] [Expand ((i,ι) :: σ) α β] : Expand ((i,ι) :: σ) (j~ι' →ₛ α) (j~ι' →ₛ β) := ⟨fun v => map (Expand.expand ((i,ι)::σ)) v⟩ 53 | instance expEq {ι : Type} {i : LabelIdx} [Expand σ α β] : Expand ((i,ι) :: σ) (i~ι →ₛ α) (i~ι →ₛ β) := ⟨fun v => map (Expand.expand σ) v⟩ 54 | 55 | instance expLtFun {ι : Type} {i j : LabelIdx} [IdxLt i j] [Expand σ (j~ι' → α) β] : Expand ((i,ι) :: σ) (j~ι' → α) (i~ι → β) := ⟨fun v _ => Expand.expand σ v⟩ 56 | instance expGtFun {ι : Type} {i j : LabelIdx} [IdxLt j i] [Expand ((i,ι) :: σ) α β] : Expand ((i,ι) :: σ) (j~ι' → α) (j~ι' → β) := ⟨fun v => Expand.expand ((i,ι)::σ) ∘ v⟩ 57 | instance expEqFun {ι : Type} {i : LabelIdx} [Expand σ α β] : Expand ((i,ι) :: σ) (i~ι → α) (i~ι → β) := ⟨fun v => (Expand.expand σ) ∘ v⟩ 58 | end 59 | 60 | -- Ignoring `base` for now. It should be used for a coercion. 
61 | instance [Expand σ α β] : EnsureBroadcast σ base α β where 62 | broadcast := Expand.expand σ 63 | 64 | end 65 | 66 | end Etch.Verification.SequentialStream 67 | -------------------------------------------------------------------------------- /archive/src/test.lean: -------------------------------------------------------------------------------- 1 | import algebra.big_operators.finsupp 2 | import tactic.field_simp 3 | import data.rat.floor 4 | 5 | def BoundedStreamGen : Type := sorry 6 | 7 | def BoundedStreamGen.eval (x : BoundedStreamGen) : ℕ →₀ ℤ := sorry 8 | 9 | def contract (x : BoundedStreamGen) : BoundedStreamGen := sorry 10 | 11 | def externSparseVec (x : list ℕ) (y : list ℤ) : BoundedStreamGen := sorry 12 | 13 | lemma externSparseVec.spec (x : list ℕ) (y : list ℤ) : 14 | (externSparseVec x y).eval = (list.zip_with finsupp.single x y).sum := sorry 15 | 16 | lemma contract.spec (s : BoundedStreamGen) : 17 | (contract s).eval = s.eval.map_domain (λ _, 0) := sorry 18 | 19 | class HasCorrectEval (x : BoundedStreamGen) (gn : out_param $ ℕ →₀ ℤ) : Prop := 20 | (iseq [] : x.eval = gn) 21 | 22 | open HasCorrectEval (iseq) 23 | 24 | instance externSparseVec.correctEval (x : list ℕ) (y : list ℤ) [fact (x.length = y.length)] : 25 | HasCorrectEval (externSparseVec x y) (list.zip_with finsupp.single x y).sum := ⟨externSparseVec.spec _ _⟩ 26 | 27 | instance contract.correctEval (s : BoundedStreamGen) {gn : ℕ →₀ ℤ} 28 | [HasCorrectEval s gn] : HasCorrectEval (contract s) (finsupp.map_domain.add_monoid_hom (λ _, 0) gn) := 29 | ⟨by { rw [contract.spec, iseq], refl, }⟩ 30 | 31 | def sum_vec (x : list ℕ) (y : list ℤ) : BoundedStreamGen := contract (externSparseVec x y) 32 | 33 | lemma sum_vec.spec (x : list ℕ) (y : list ℤ) (hx : x.length = y.length) : 34 | (sum_vec x y).eval = finsupp.single 0 y.sum := 35 | begin 36 | haveI : fact _ := ⟨hx⟩, 37 | rw [sum_vec, iseq],-- ← list.sum_hom, list.map_zip_with], 38 | -- simp, 39 | end 40 | 41 | 42 | 43 | open HasCorrectEval (iseq) 



universe u

-- Direct proof that heterogeneous equality at a single type implies equality,
-- avoiding the `eq_of_heq` from core (exercise in manual `heq.rec_on`).
lemma eq_of_heq' {α : Sort u} {a a' : α} (h : a == a') : a = a' :=
have ∀ (α' : Sort u) (a' : α') (h₁ : @heq α a α' a') (h₂ : α = α'), (eq.rec_on h₂ a : α') = a', from
  λ (α' : Sort u) (a' : α') (h₁ : @heq α a α' a'), heq.rec_on h₁ (λ h₂ : α = α, rfl),
show (eq.rec_on (eq.refl α) a : α) = a', from
  this α a' h (eq.refl α)


-- Iterate `f` until it reaches 0, counting the steps.
-- Well-founded on ℕ because `hf` forces `f (n + 1) ≤ n`, i.e. each step strictly decreases.
def star (f : ℕ → ℕ) (hf : ∀ n : ℕ, f (n + 1) ≤ n) : ℕ → ℕ
| 0 := 0
| (n + 1) := have _ := nat.lt_succ_of_le (hf n), if f (n + 1) = 0 then 0 else 1 + star (f (n + 1))

-- `f*` abbreviates `star f _`, discharging the contraction hypothesis from context.
notation f`*`:9000 := star f (by assumption)

-- Unfolding equation for `star`, valid at every `n` once `f 0 = 0` is known.
lemma star_eq {f : ℕ → ℕ} (h₁ : f 0 = 0) (h₂ : ∀ n, f (n + 1) ≤ n) (n : ℕ) :
  f* n = if f n = 0 then 0 else 1 + f* (f n) :=
by cases n; simp [star, h₁]

@[simp] lemma star_zero {f : ℕ → ℕ} (h₂ : ∀ n, f (n + 1) ≤ n) : f* 0 = 0 := by simp [star]

-- `f 1 ≤ 0` forces `f 1 = 0`, so iteration from 1 stops immediately.
@[simp] lemma star_one {f : ℕ → ℕ} (h₂ : ∀ n, f (n + 1) ≤ n) : f* 1 = 0 :=
by simpa [star, imp_false] using h₂ 0

-- `star` of a contraction is itself a contraction (strong induction on the argument).
lemma star_contraction_of_contraction {f : ℕ → ℕ} (H : ∀ n, f (n + 1) ≤ n) (n : ℕ) :
  f* 0 = 0 ∧ f* (n + 1) ≤ n :=
begin
  split, { simp, },
  induction n using nat.strong_induction_on with n ih,
  rw star,
  split_ifs, { exact zero_le _, },
  specialize H n, refine trans _ H,
  cases (f (n + 1)) with m hm, { contradiction, },
  rw [nat.succ_eq_add_one, add_comm 1 _, add_le_add_iff_right],
  refine ih m _,
  rwa ← nat.succ_le_iff,
end

-- open_locale big_operators


-- lemma egyption_fraction_wf (r : ℚ) (h₁ : 0 < r) (h₂ : r < 1) :
--   (r - (1 : ℚ) / ⌈1/r⌉).num < r.num :=
-- begin

-- end

-- def egyptian_fraction : ∀ (r : ℚ) (h₁ : 0 ≤ r) (h₂ : r < 1), finset ℤ | r h₁ h₂ :=
-- if H : r = 0 then 0 else
-- let n : ℤ := ⌈1/r⌉ in
-- have wf : (r - (1 : ℚ) / n).num < r.num := sorry,
-- insert n (egyptian_fraction (r - (1 : ℚ) / n) _ _)
-- using_well_founded {rel_tac := λ _ _, `[exact ⟨_, measure_wf rat.num⟩]}

-------------------------------------------------------------------------------- /archive/src/verification/test.lean: --------------------------------------------------------------------------------
import verification.semantics.stream_add
import verification.semantics.stream_multiply
import verification.semantics.stream_replicate
import verification.semantics.stream_props

-- `ι ↠ α` abbreviates a `SimpleStream` from index type `ι` to values `α`.
local infixr ` ↠ `:50 := SimpleStream

section
variables {ι₁ ι₂ ι₃ : Type} [linear_order ι₁] [linear_order ι₂]
  [linear_order ι₃] {R : Type} [semiring R]

open Eval (eval)

local notation `∑ᵢ ` s := s.contract

local notation (name := bool_add) a ` && ` b := a + b

--
noncomputable instance SimpleStream.AddZeroEval_weird :
  AddZeroEval (ι₁ ↠ ι₂ ↠ ι₃ ↠ R) ι₁ (ι₂ →₀ ι₃ →₀ R) :=
SimpleStream.AddZeroEval

-- Evaluation is a (semi)ring homomorphism on 3-deep streams: `simp` alone closes it.
example (a b c d : ι₁ ↠ ι₂ ↠ ι₃ ↠ R) :
  eval (a * (b + c) * d) =
  (eval a) * ((eval b) + (eval c)) * (eval d) :=
by simp

example [semiring R] (a b c : SimpleStream ι₁ (SimpleStream ι₂ R)) :
  eval ((a + b) * c) = eval a * eval c + eval b * eval c :=
by simp [add_mul]

end

open_locale big_operators

section
variables {ι₁ ι₂ ι₃ : Type} [linear_order ι₁] [linear_order ι₂]
  [linear_order ι₃] {R : Type} [semiring R]

local notation `∑ᵢ ` s := s.contract

-- Unfortunately, Lean doesn't like the notation `eval s x y` because it doesn't know `eval s x` is going to be a function
-- TODO: Fix
@[reducible] def eval {ι₁ ι₂ α₁ R : Type*} [has_zero R] [Eval α₁ ι₁ (ι₂ →₀ R)]
  (x : α₁) : ι₁ →₀ ι₂ →₀ R := Eval.eval x

@[reducible] def eval3 {ι₁ ι₂ ι₃ α₁ R : Type*} [has_zero R] [Eval α₁ ι₁ (ι₂ →₀ ι₃ →₀ R)]
  (x : α₁) : ι₁ →₀ ι₂ →₀ ι₃ →₀ R := Eval.eval x

local attribute [simp] eval finsupp.sum_range_eq_sum finsupp.sum
  finsupp.finset_sum_apply

-- Contraction of a product evaluates to the support-indexed sum of pointwise products.
example (a b : ι₁ ↠ ι₂ ↠ R)
  (j : ι₂) : eval (∑ᵢ (a * b)) () j =
  ∑ i in (eval a * eval b).support,
    (eval a i j * eval b i j) :=
by rw Eval.contract'; simp

end

constants (n₁ n₂ n₃ : ℕ)

variables {R : Type} [semiring R]

variables {ι₁ ι₂ ι₃ : Type} [linear_order ι₁] [linear_order ι₂] [linear_order ι₃]
  (m₁ : fin n₁ ≃o ι₁)
  (m₂ : fin n₂ ≃o ι₂)
  (m₃ : fin n₃ ≃o ι₃)

-- `⇑ᵢ c` replicates the scalar `c` along the i-th index dimension via the order embedding.
local notation `⇑₁` := SimpleStream.replicate' (m₁ : fin n₁ ↪o ι₁)
local notation `⇑₂` := SimpleStream.replicate' (m₂ : fin n₂ ↪o ι₂)
local notation `⇑₃` := SimpleStream.replicate' (m₃ : fin n₃ ↪o ι₃)

section

local notation `∑ᵢ ` s := s.contract

local attribute [simp] eval finsupp.sum_range_eq_sum finsupp.sum
  finsupp.finset_sum_apply finsupp.const
  SimpleStream.replicate'.spec_equiv -- TODO: tag this as @[simp]?

example (c : R) (k : ι₃) : Eval.eval (⇑₃ c) k = c :=
by simp [Eval.eval]

example (v w : ι₃ ↠ R) (k : ι₃) : Eval.eval (v * w) k = (Eval.eval v k) * (Eval.eval w k) :=
by simp

example (c : R) (v : ι₃ ↠ R) (k : ι₃) : Eval.eval ((⇑₃ c) * v) k = c * (Eval.eval v k) :=
by { simp_rw [MulEval.hmul, Eval.eval], simp }

-- Help instance inferrer out a bit.
noncomputable instance test_instance :
  Eval (StreamExec unit R) unit R := infer_instance
noncomputable instance test_instance2 {ι} [linear_order ι] :
  Eval (StreamExec unit (ι ↠ R)) unit (ι →₀ R) := infer_instance

-- Matrix multiply as stream contraction: for each row `r` of `a`,
-- broadcast it along ι₃ and contract the product with `b` over ι₂.
noncomputable def matmul (a : ι₁ ↠ ι₂ ↠ R) (b : ι₂ ↠ ι₃ ↠ R) :=
  (λ (r : ι₂ ↠ R), ∑ᵢ ((⇑₃ <§₂> r) * b)) <§₂> a

example (a : ι₁ ↠ ι₂ ↠ R) (b : ι₂ ↠ ι₃ ↠ R) (i : ι₁) (k : ι₃) :
  eval3 (matmul m₃ a b) i () k =
  ∑ j in (eval a i).support ∪ (eval b).support,
    (eval a i j * eval b j k) :=
begin
  simp_rw [eval, eval3],
  sorry -- TODO: one day, hopefully soon
end

end
-------------------------------------------------------------------------------- /Etch/StreamFusion/TestUtil.lean: --------------------------------------------------------------------------------
import Etch.StreamFusion.Stream
import Etch.StreamFusion.Traversals

-- Header row written at the top of every benchmark CSV output file.
def csvHeader := "time,test\n"

-- Run `m`, print "[s] time: <ms>" using the monotonic clock, and return its result.
def time (s : String) (m : Unit → IO α) : IO α := do
  let t0 ← IO.monoMsNow
  let v ← m ()
  let t1 ← IO.monoMsNow
  IO.println s!"[{s}] time: {t1-t0}"
  pure v

-- Like `time`, but returns the elapsed milliseconds instead of printing them.
def time' (s : String) (m : Unit → IO α) : IO (α × ℕ) := do
  let t0 ← IO.monoMsNow
  let v ← m ()
  let t1 ← IO.monoMsNow
  pure (v, t1-t0)

open Etch.Verification
open SStream
open OfStream ToStream

variable
  {ι ι' : Type}
  [LinearOrder ι] [LinearOrder ι']

-- Sparse vector fixture: indices 0..num-1 with value equal to the index.
@[inline] def sparseVec (num : Nat) :=
  let v : Vec ℕ num := ⟨Array.range num, Array.size_range⟩
  SparseArray.mk v v

-- Same fixture as an RB tree map, for comparing container backends.
@[inline] def sparseVecRB (num : Nat) :=
  RB.TreeMap.ofArray $ (Array.range num).map fun n => (n,n)

@[inline] def vecStream (num : Nat) :=
  let v : Vec ℕ num := ⟨Array.range num, Array.size_range⟩
  stream $ SparseArray.mk v v

@[inline]
def SparseArray.range (num : Nat) : SparseArray ℕ ℕ :=
  let v := Vec.range num; SparseArray.mk v v

-- num × num matrix fixture with identical rows.
@[inline] def sparseMat (num : Nat) :=
  let v := SparseArray.range num
  v.mapVals fun _ => SparseArray.range num

-- num × num matrix fixture with entry (i, j) = f i j.
@[inline, specialize] def sparseMatFn (f : ℕ → ℕ → α) (num : Nat) :=
  let v := SparseArray.range num
  v.mapVals fun i => (SparseArray.range num |>.mapVals fun j => f i j)

@[inline] def boolStream (num : Nat) : ℕ →ₛ Bool:=
  stream $ Array.range num

-- todo investigate perf differences
-- Evaluate `setup num` `reps` times under `time`; the `i % 1000000` guard effectively
-- prints only the first iteration's result for small rep counts.
@[specialize]
def genCase [OfStream α β] [Zero β] (label : String) (setup : init → α) [ToString β'] (print : β → β') (num : init) (reps := 10) : IO Unit := do
  IO.println s!"reps: {reps}-----"
  let s := setup num
  time label fun _ => do
    for i in [0:reps] do
      let x := SStream.eval s
      if i % 1000000 = 0 then
        IO.println s!"{print x}"

-- Variant of `genCase` taking an arbitrary operation `op` instead of `SStream.eval`.
@[specialize]
def genCase'' (label : String) (setup : init → α) (op : α → β) [ToString β'] (print : β → β') (num : init) (reps := 10) : IO Unit := do
  IO.println s!"reps: {reps}-----"
  let s := setup num
  time label fun _ => do
    for i in [0:reps] do
      let x := op s
      if i % 1000000 = 0 then
        IO.println s!"{print x}"


def appendFile (fname : System.FilePath) (content : String) : IO Unit := do
  let h ← IO.FS.Handle.mk fname IO.FS.Mode.append
  h.putStr content

def resetFile (f : System.FilePath) := IO.FS.writeFile f csvHeader

-- Truncate `file` to the CSV header, then run each test case against it.
def recordTestCases (file : System.FilePath) (cases : List (System.FilePath → IO Unit)) : IO Unit := do
  resetFile file
  cases.forM fun case => case file

-- Run `op` on the prepared input: 5 warmup reps, then `reps` timed reps,
-- appending "time,label" rows to `file`.
@[specialize]
def recordTestCase (file : System.FilePath) (label : String) (setup : init → α) (op : α → β)
    [ToString β'] (print : β → β') (data : init) (reps := 10) : IO Unit := do
  IO.println s!"--- test case: {file}:{label} ---"
  let s := setup data
  let go := fun () => time' label fun _ => pure (op s)
  for _ in [0:5] do _ ← go () -- warmup
  let mut result := ""
  for _ in [0:reps] do -- test
    let (x, t) ← go ()
    result := result ++ s!"{t},{label}\n"
    IO.println s!"{print x}"
  appendFile file result

-- `num` random draws in [1, 2*num], sorted and deduplicated, rendered as strings.
def randStrings (num : Nat) : IO (Array String) := do
  let mut result := #[]
  for _ in [0:num] do
    let n ← IO.rand 1 (num*2)
    result := result.push n
  pure $ result.qsort (·<·) |>.deduplicateSorted |>.map toString

-- Same as `randStrings` but keeps the sorted, deduplicated naturals.
def randNats (num : Nat) : IO (Array Nat) := do
  let mut result := #[]
  for _ in [0:num] do
    let n ← IO.rand 1 (num*2)
    result := result.push n
  pure $ result.qsort (·<·) |>.deduplicateSorted
-------------------------------------------------------------------------------- /archive/src/verification/semantics/split.lean: --------------------------------------------------------------------------------
import verification.semantics.stream_props

variables {α ι₁ ι₂ : Type}

section streams

-- Restrict a pair-indexed stream to the fiber where the first index equals `i₁`:
-- same state space, validity additionally demands the first index component match.
@[simps]
def substream (s : Stream (ι₁ × ι₂) α) (i₁ : ι₁) : Stream ι₂ α :=
{ σ := s.σ,
  valid := λ p, ∃ (h : s.valid p), (s.index p h).1 = i₁,
  ready := λ p, s.ready p,
  next := λ p h, s.next p h.fst,
  index := λ p h, (s.index p h.fst).2,
  value := λ p h, s.value p h,
}

variables {s : Stream (ι₁ × ι₂) α} {i₁ : ι₁}

@[simp] lemma substream.next'_eq {x : s.σ} :
  (substream s i₁).valid x → (substream s i₁).next' x = s.next' x :=
λ h, by rw [Stream.next'_val h, substream_next, Stream.next'_val]

-- If the substream is still valid after n steps, so is the original stream.
lemma substream.valid_subsumes {n : ℕ} {x : s.σ} :
  (substream s i₁).valid ((substream s i₁).next'^[n] x) → s.valid (s.next'^[n] x) :=
begin
  induction n with n ih generalizing x,
  { simp only [function.iterate_zero, substream_valid],
    exact Exists.fst },
  { intro h,
    have hxv := Stream.next'_valid' _ h,
    rw [function.iterate_succ_apply] at h,
    simpa [hxv] using ih h }
end

-- A bound that is valid for `s` is valid for every fiber substream (contrapositive
-- of `substream.valid_subsumes`, lifted through `bound_valid_iff_next'_iterate`).
lemma substream.bound_valid {B : ℕ} {x : s.σ} :
  s.bound_valid B x → ∀ i₁, (substream s i₁).bound_valid B x :=
begin
  simp_rw bound_valid_iff_next'_iterate,
  induction B with n ih generalizing x,
  { simp_rw [function.iterate_zero_apply, substream_valid, not_exists],
    intros; contradiction },
  { intros hnv _,
    exact mt substream.valid_subsumes hnv }
end

end streams

section stream_exec

-- State of the split stream: the underlying state, the last first-component index
-- emitted (to detect a fresh outer index), and a remaining-steps bound kept valid.
structure split_state (s : StreamExec (ι₁ × ι₂) α) :=
(state : s.stream.σ)
(last : option ι₁)
(remaining : ℕ)
(bound_valid : s.stream.bound_valid remaining state)

-- Split a stream over a product index into a stream over ι₁ whose values are the
-- ι₂-fiber executions; `ready` only when a new first-component index is seen.
@[simps]
def Stream.split (s : StreamExec (ι₁ × ι₂) α) : Stream ι₁ (StreamExec ι₂ α) :=
{ σ := split_state s,
  valid := λ p, s.stream.valid p.1,
  ready := λ p, s.stream.ready p.1 ∧
    ∃ hv, p.last ≠ (s.stream.index p.1 hv).1,
  next := λ p h, ⟨s.stream.next p.1 h,
                  (s.stream.index p.1 h).1,
                  p.remaining.pred,
                  show s.stream.bound_valid p.remaining.pred _, by {
                    apply Stream.bound_valid_succ.1,
                    cases hp : p.remaining,
                    { have := p.bound_valid, rw hp at this,
                      cases this, contradiction, },
                    { simpa [hp] using p.bound_valid }
                  }⟩,
  index := λ p h, (s.stream.index p.1 h).1,
  value := λ p h, {
    stream := substream s.stream (s.stream.index p.1 h.2.fst).1,
    state := p.1,
    bound := p.remaining,
    bound_valid := substream.bound_valid p.bound_valid _,
  },
}

variables {s : StreamExec (ι₁ × ι₂) α}

-- Stepping the split stream steps the underlying state.
@[simp] lemma Stream.split_next'_state (p : split_state s) :
  ((Stream.split s).next' p).state = s.stream.next' p.state :=
by { by_cases H : s.stream.valid p.state, { simpa [H] }, { simp [H] } }

-- Iterated version of the previous lemma.
@[simp] lemma Stream.split_next'_state' (x : split_state s) (n) :
  ((Stream.split s).next'^[n] x).state = (s.stream.next'^[n] x.state) :=
begin
  induction n with _ ih generalizing x,
  { simp },
  { simp_rw [function.iterate_succ_apply, ← Stream.split_next'_state],
    exact ih _ }
end

-- Executable wrapper: start with `last := none` so the first index is always fresh.
def StreamExec.split (s : StreamExec (ι₁ × ι₂) α) : StreamExec ι₁ (StreamExec ι₂ α) :=
{ stream := Stream.split s,
  state := ⟨s.state, none, s.bound, s.bound_valid⟩,
  bound := s.bound,
  bound_valid := begin
    have bv := s.bound_valid,
    rw bound_valid_iff_next'_iterate at ⊢ bv,
    induction eq : s.bound; simpa [eq] using bv,
  end,
}

variables [add_comm_monoid α]

/-
TODOs:
- do we need a no-lookback hypothesis for `i₁`?
-/

theorem StreamExec.split.spec (i₁ i₂) :
  (StreamExec.eval <$₂> StreamExec.split s).eval i₁ i₂ = s.eval (i₁, i₂) :=
sorry

end stream_exec
-------------------------------------------------------------------------------- /Etch/Benchmark/TPCHq9.lean: --------------------------------------------------------------------------------
import Etch.Benchmark.Basic
import Etch.Benchmark.SQL
import Etch.Op
import Etch.Mul
import Etch.ShapeInference
import Etch.Stream

-- simple O(mn) algorithm
-- Byte position of the first occurrence of `f` in `s`, scanning one character
-- at a time; `none` when absent.
-- NOTE(review): the substrEq length argument is `f.length` (chars), while the
-- range check uses `f.endPos` (bytes) — confirm these agree for non-ASCII input.
partial def String.findStr? (s f : String) : Option String.Pos :=
  loop 0
where
  loop (off : String.Pos) :=
    if s.substrEq off f 0 f.length then
      some off
    else if off + f.endPos < s.endPos then
      loop (s.next off)
    else
      none

-- `![a, b]` means find `b` within `a`
-- if found, ≥0 is byte index; if not found, -1
private def Op.findStr : Op Int where
  argTypes := ![String, String]
  spec := fun x => match (x 0).findStr? (x 1) with
    | some off => off.byteIdx
    | none => -1
  opName := "str_find"
  -- We will implement this using C's strstr(), which is not exactly
  -- the same thing since it's not UTF-8 aware, but close enough.

private def E.findStr (s f : E String) : E Int := E.call Op.findStr ![s, f]
private def E.hasSubstr (s f : E String) : E Bool := s.findStr f >= (0 : E ℤ)

-- Seconds-since-epoch to calendar year, ignoring leap years (benchmark-grade only).
private def Op.dateToYear : Op ℤ where
  argTypes := ![ℤ]
  spec := fun a => 1970 + (a 0) / (365 * 24 * 60 * 60) -- not exactly
  opName := "date_to_year"

-- compute ι₂ from ι₁
-- Inner stream yields the single index `f pos` once, tracking a `visited` flag.
def S.deriveIdx {ι₁} [Tagged ι₁] [TaggedC ι₁] [Zero ι₁]
    {ι₂} [Tagged ι₂] [LT ι₂] [@DecidableRel ι₂ LT.lt]
    (α) [Tagged α] [One α] (f : E ι₁ → E ι₂) : ι₁ →ₛ ι₂ →ₛ E α where
  σ := Var ι₁
  skip pos := pos.store_var
  succ _ _ := .skip
  ready _ := 1
  valid _ := 1
  index pos := pos.expr
  -- value pos := S.predRangeIncl (f pos.expr) (f pos.expr) -- equiv but slower
  value pos := {
    σ := Var Bool
    skip := fun v i => .if1 (f pos.expr << i) (v.store_var 1)
    succ := fun v _ => v.store_var 1
    ready := fun v => E.call Op.neg ![v.expr]
    valid := fun v => E.call Op.neg ![v.expr]
    index := fun _ => f pos.expr
    value := fun _ => 1
    init := fun n => let v := Var.fresh "visited" n; ⟨v.decl 0, v⟩
  }
  init n := let v := Var.fresh "pos" n; ⟨v.decl 0, v⟩

namespace Etch.Benchmark.TPCHq9

-- Schema

abbrev partkey := (0, ℕ)
abbrev partname := (1, String)
abbrev suppkey := (2, ℕ)
abbrev orderkey := (3, ℕ)
abbrev nationkey := (4, ℕ)
abbrev orderdate := (5, ℤ)
abbrev orderyear := (6, ℤ)
abbrev nationname := (7, String)
abbrev supplycost := (8, R)
abbrev extendedprice := (9, R)
abbrev discount := (10, R)
abbrev quantity := (11, R)

def lineitem : partkey ↠ₛ suppkey ↠ₛ orderkey ↠ₛ extendedprice ↠ₛ discount ↠ₛ quantity ↠ₛ E R :=
  (SQL.sss___ "tpch9_lineitem" : ℕ →ₛ ℕ →ₛ ℕ →ₛ R →ₛ R →ₛ R →ₛ E R)

def part : partkey ↠ₐ partname ↠ₛ E R := (SQL.d_ "tpch9_part" : ℕ →ₐ String →ₛ E R)
def orders : orderkey ↠ₐ orderdate ↠ₛ E R := (SQL.d_ "tpch9_orders" : ℕ →ₐ ℤ →ₛ E R)
def supplier : suppkey ↠ₐ nationkey ↠ₛ E R := (SQL.d_ "tpch9_supplier" : ℕ →ₐ ℕ →ₛ E R)
def partsupp : partkey ↠ₐ suppkey ↠ₛ supplycost ↠ₛ E R := (SQL.ds_ "tpch9_partsupp" .binarySearch : ℕ →ₐ ℕ →ₛ R →ₛ E R)
def nation : nationkey ↠ₐ nationname ↠ₛ E R := (SQL.d_ "tpch9_nation" : ℕ →ₐ String →ₛ E R)

-- Query

def orderyearCalc : orderdate ↠ₛ orderyear ↠ₛ E R :=
  S.deriveIdx R (fun (d : E ℤ) => E.call Op.dateToYear ![d])
def partGreen : partname ↠ₐ E Bool := hasGreen
  where hasGreen : String →ₐ E Bool := fun v => v.hasSubstr "green"
def profitCalc : supplycost ↠ₐ extendedprice ↠ₐ discount ↠ₐ quantity ↠ₐ E R := profitCalc'
  where profitCalc' : R →ₐ R →ₐ R →ₐ R →ₐ E R := fun c p d q => p * (1 - d) - c * q

def q9 := ∑ partkey, partname, suppkey: ∑ orderkey, nationkey, orderdate: ∑ supplycost, extendedprice, discount, quantity:
  lineitem * part * partGreen * partsupp * supplier * nation * profitCalc * (orders * orderyearCalc)

def funcs : List (String × String) := [
  let fn := "q9"; (fn, compileFunMap2 ℤ String R fn q9)
]

end Etch.Benchmark.TPCHq9
-------------------------------------------------------------------------------- /bench/taco-datagen.py: --------------------------------------------------------------------------------
import numpy as np
import sqlite3
import sys
from pathlib import Path


def _make_coords(p, nonzeros, rank):
    """Generate random sparse tensor entries of the given rank.

    Returns a set of ``(i_1, ..., i_rank, value)`` tuples whose count
    approximates ``nonzeros`` at sparsity ``p``.  Shared implementation
    behind makeV (rank 1), makeA/makeB (rank 2), and makeC (rank 3);
    the rank only changes how a flat offset is decomposed into indices.
    """
    # Reduce the number of nonzeros for highly sparse datasets to help
    # account for overhead.
    nonzeros = int(np.round(nonzeros / max(1, -np.log2(p))))

    # Side length so that n**rank * p ≈ nonzeros.
    n = int(np.round((nonzeros / p) ** (1.0 / rank)))
    matrix_size = n**rank

    # Random gaps between consecutive nonzero offsets, scaled so the
    # gaps sum (approximately) to the total tensor size.
    m = np.random.rand(nonzeros)
    m *= (matrix_size - nonzeros) / np.sum(m)
    m += 1

    # Rounding correction, applied twice: rescale so the *rounded* gaps
    # still sum to roughly matrix_size.
    m /= np.sum(np.round(m)) / matrix_size
    m /= np.sum(np.round(m)) / matrix_size

    result = set()
    last = 0
    for r in m:
        if last >= matrix_size:
            break
        # Decompose the flat offset into `rank` mixed-radix coordinates
        # (most-significant digit first), matching i = last // (n*n),
        # j = (last // n) % n, k = last % n for rank 3.
        coords = tuple((last // n**d) % n for d in range(rank - 1, -1, -1))
        result.add(coords + (float(r),))
        last += int(np.round(r))
    print(
        f"expected={nonzeros} actual={len(result)} expect_sparsity={p} actual_sparsity={len(result) / matrix_size}"
    )
    return result


def makeV(p=0.1, nonzeros=200):
    # 2000 × 0.1 = 200
    return _make_coords(p, nonzeros, rank=1)


def makeA(p=0.1, nonzeros=400000):
    # 2000 × 2000 × 0.1 = 400000
    return _make_coords(p, nonzeros, rank=2)


def makeC(p=0.1, nonzeros=800000):
    # 200 × 200 × 200 * 0.1 = 800000
    return _make_coords(p, nonzeros, rank=3)


def main(db: Path = Path("data/pldi.db"), sparsity: float = 0.1):
    """(Re)create tables A, B, C, V in `db` and fill them with random sparse data."""
    c = sqlite3.connect(str(db))
    c.execute("DROP TABLE IF EXISTS A")
    c.execute("DROP TABLE IF EXISTS B")
    c.execute("DROP TABLE IF EXISTS C")
    c.execute("DROP TABLE IF EXISTS V")
    c.execute("CREATE TABLE A(i INTEGER NOT NULL, j INTEGER NOT NULL, v REAL NOT NULL)")
    c.execute("CREATE TABLE B(i INTEGER NOT NULL, j INTEGER NOT NULL, v REAL NOT NULL)")
    c.execute(
        "CREATE TABLE C(i INTEGER NOT NULL, j INTEGER NOT NULL, k INTEGER NOT NULL, v REAL NOT NULL)"
    )
    c.execute("CREATE TABLE V(i INTEGER NOT NULL, v REAL NOT NULL)")
    print("A")
    # reference factor = 20
    c.executemany("INSERT INTO A VALUES(?,?,?)", makeA(sparsity))
    print("B")
    c.executemany("INSERT INTO B VALUES(?,?,?)", makeA(sparsity))
    print("C")
    c.executemany("INSERT INTO C VALUES(?,?,?,?)", makeC(sparsity))
    print("V")
    c.executemany("INSERT INTO V VALUES(?,?)", makeV(sparsity))
    c.commit()


if __name__ == "__main__":
    main(Path(sys.argv[1]), float(sys.argv[2]))
-------------------------------------------------------------------------------- /taco/add2.c: -------------------------------------------------------------------------------- 1 | 2 | int evaluate(taco_tensor_t *out, taco_tensor_t *A, taco_tensor_t *B) { 3 | int out1_dimension = (int)(out->dimensions[0]); 4 | int* restrict out2_pos = (int*)(out->indices[1][0]); 5 | int* restrict out2_crd = (int*)(out->indices[1][1]); 6 | double* restrict out_vals = (double*)(out->vals); 7 | int A1_dimension = (int)(A->dimensions[0]); 8 | int* restrict A2_pos = (int*)(A->indices[1][0]); 9 | int* restrict A2_crd = (int*)(A->indices[1][1]); 10 | double* restrict A_vals = (double*)(A->vals); 11 | int B1_dimension = (int)(B->dimensions[0]); 12 | int* restrict B2_pos = (int*)(B->indices[1][0]); 13 | int* restrict B2_crd = (int*)(B->indices[1][1]); 14 | double* restrict B_vals = (double*)(B->vals); 15 | 16 | out2_pos = (int32_t*)malloc(sizeof(int32_t) * (out1_dimension + 1)); 17 | out2_pos[0] = 0; 18 | for (int32_t pout2 = 1; pout2 < (out1_dimension + 1); pout2++) { 19 | out2_pos[pout2] = 0; 20 | } 21 | int32_t out2_crd_size = 1048576; 22 | out2_crd = (int32_t*)malloc(sizeof(int32_t) * out2_crd_size); 23 | int32_t jout = 0; 24 | int32_t out_capacity = 1048576; 25 | out_vals = (double*)malloc(sizeof(double) * out_capacity); 26 | 27 | for (int32_t i = 0; i < B1_dimension; i++) { 28 | int32_t pout2_begin = jout; 29 | 30 | int32_t jA = A2_pos[i]; 31 | int32_t pA2_end = A2_pos[(i + 1)]; 32 | int32_t jB = B2_pos[i]; 33 | int32_t pB2_end = B2_pos[(i + 1)]; 34 | 35 | while (jA < pA2_end && jB < pB2_end) { 36 | int32_t jA0 = A2_crd[jA]; 37 | int32_t jB0 = B2_crd[jB]; 38 | int32_t j = TACO_MIN(jA0,jB0); 39 | if (jA0 == j && jB0 == j) { 40 | if (out_capacity <= jout) { 41 | out_vals = (double*)realloc(out_vals, sizeof(double) * (out_capacity * 2)); 42 | out_capacity *= 2; 43 | } 44 | out_vals[jout] = A_vals[jA] + B_vals[jB]; 45 | if (out2_crd_size <= jout) { 46 | out2_crd = 
(int32_t*)realloc(out2_crd, sizeof(int32_t) * (out2_crd_size * 2)); 47 | out2_crd_size *= 2; 48 | } 49 | out2_crd[jout] = j; 50 | jout++; 51 | } 52 | else if (jA0 == j) { 53 | if (out_capacity <= jout) { 54 | out_vals = (double*)realloc(out_vals, sizeof(double) * (out_capacity * 2)); 55 | out_capacity *= 2; 56 | } 57 | out_vals[jout] = A_vals[jA]; 58 | if (out2_crd_size <= jout) { 59 | out2_crd = (int32_t*)realloc(out2_crd, sizeof(int32_t) * (out2_crd_size * 2)); 60 | out2_crd_size *= 2; 61 | } 62 | out2_crd[jout] = j; 63 | jout++; 64 | } 65 | else { 66 | if (out_capacity <= jout) { 67 | out_vals = (double*)realloc(out_vals, sizeof(double) * (out_capacity * 2)); 68 | out_capacity *= 2; 69 | } 70 | out_vals[jout] = B_vals[jB]; 71 | if (out2_crd_size <= jout) { 72 | out2_crd = (int32_t*)realloc(out2_crd, sizeof(int32_t) * (out2_crd_size * 2)); 73 | out2_crd_size *= 2; 74 | } 75 | out2_crd[jout] = j; 76 | jout++; 77 | } 78 | jA += (int32_t)(jA0 == j); 79 | jB += (int32_t)(jB0 == j); 80 | } 81 | while (jA < pA2_end) { 82 | int32_t j = A2_crd[jA]; 83 | if (out_capacity <= jout) { 84 | out_vals = (double*)realloc(out_vals, sizeof(double) * (out_capacity * 2)); 85 | out_capacity *= 2; 86 | } 87 | out_vals[jout] = A_vals[jA]; 88 | if (out2_crd_size <= jout) { 89 | out2_crd = (int32_t*)realloc(out2_crd, sizeof(int32_t) * (out2_crd_size * 2)); 90 | out2_crd_size *= 2; 91 | } 92 | out2_crd[jout] = j; 93 | jout++; 94 | jA++; 95 | } 96 | while (jB < pB2_end) { 97 | int32_t j = B2_crd[jB]; 98 | if (out_capacity <= jout) { 99 | out_vals = (double*)realloc(out_vals, sizeof(double) * (out_capacity * 2)); 100 | out_capacity *= 2; 101 | } 102 | out_vals[jout] = B_vals[jB]; 103 | if (out2_crd_size <= jout) { 104 | out2_crd = (int32_t*)realloc(out2_crd, sizeof(int32_t) * (out2_crd_size * 2)); 105 | out2_crd_size *= 2; 106 | } 107 | out2_crd[jout] = j; 108 | jout++; 109 | jB++; 110 | } 111 | 112 | out2_pos[i + 1] = jout - pout2_begin; 113 | } 114 | 115 | int32_t csout2 = 0; 116 | for 
(int32_t pout20 = 1; pout20 < (out1_dimension + 1); pout20++) { 117 | csout2 += out2_pos[pout20]; 118 | out2_pos[pout20] = csout2; 119 | } 120 | 121 | out->indices[1][0] = (uint8_t*)(out2_pos); 122 | out->indices[1][1] = (uint8_t*)(out2_crd); 123 | out->vals = (uint8_t*)out_vals; 124 | return 0; 125 | } 126 | 127 | -------------------------------------------------------------------------------- /operators.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | static inline double num_add(double a, double b) { return a + b; } 7 | static inline double num_sub(double a, double b) { return a - b; } 8 | static inline double num_mul(double a, double b) { return a * b; } 9 | static inline double num_one() { return 1; } 10 | static inline double num_zero() { return 0; } 11 | static inline double num_lt(double a, double b) { return a < b; } 12 | static inline double num_le(double a, double b) { return a <= b; } 13 | // static inline double num_lt(double a, double b) { printf("%f < %f\n", a, 14 | // b); return a < b; } 15 | 16 | static inline double num_ofBool(bool x) { return x ? 1 : 0; } 17 | static inline double num_toMin(double x) { return x; } 18 | static inline double num_toMax(double x) { return x; } 19 | static inline double nat_toNum(int x) { return x; } 20 | 21 | static inline double min_add(double a, double b) { return a < b ? a : b; } 22 | static inline double min_mul(double a, double b) { return a + b; } 23 | static inline double min_one() { return 0; } 24 | static inline double min_zero() { return DBL_MAX; } 25 | 26 | static inline double max_add(double a, double b) { return a < b ? 
b : a; } 27 | static inline double max_mul(double a, double b) { return a + b; } 28 | static inline double max_one() { return 0; } 29 | static inline double max_zero() { return -DBL_MAX; } 30 | 31 | static inline int nat_add(int a, int b) { return a + b; } 32 | static inline int nat_mul(int a, int b) { return a * b; } 33 | static inline int nat_sub(int a, int b) { return a - b; } 34 | static inline bool nat_lt(int a, int b) { return a < b; } 35 | static inline bool nat_le(int a, int b) { return a <= b; } 36 | static inline bool nat_eq(int a, int b) { return a == b; } 37 | static inline int nat_max(int a, int b) { return a < b ? b : a; } 38 | static inline int nat_min(int a, int b) { return a < b ? a : b; } 39 | static inline int nat_succ(int a) { return a + 1; } 40 | static inline int nat_mid(int a, int b) { return (a + b) / 2; } 41 | static inline int nat_one() { return 1; } 42 | static inline int nat_zero() { return 0; } 43 | static inline int nat_ofBool(bool x) { return x; } 44 | 45 | static inline int int_add(int a, int b) { return a + b; } 46 | static inline int int_mul(int a, int b) { return a * b; } 47 | static inline int int_sub(int a, int b) { return a - b; } 48 | static inline bool int_lt(int a, int b) { return a < b; } 49 | static inline bool int_le(int a, int b) { return a <= b; } 50 | static inline bool int_eq(int a, int b) { return a == b; } 51 | static inline int int_max(int a, int b) { return a < b ? b : a; } 52 | static inline int int_min(int a, int b) { return a < b ? 
a : b; } 53 | static inline int int_succ(int a) { return a + 1; } 54 | static inline bool int_neg(int a) { return !a; } 55 | static inline int int_mid(int a, int b) { return (a + b) / 2; } 56 | static inline int int_one() { return 1; } 57 | static inline int int_zero() { return 0; } 58 | static inline int int_ofBool(bool x) { return x; } 59 | 60 | // static inline bool bool_add(bool a, bool b) { return a || b; } 61 | // static inline bool bool_mul(bool a, bool b) { return a && b; } 62 | #define bool_add(a, b) ((a) || (b)) 63 | #define bool_mul(a, b) ((a) && (b)) 64 | #define bool_one() true 65 | #define bool_zero() false 66 | #define bool_neg(a) (!(a)) 67 | 68 | // Treat NULL as the top value (e.g., empty space at the end of the array). 69 | static inline const char* str_zero() { return ""; } 70 | static inline bool str_lt(const char* a, const char* b) { 71 | if (!a) return false; 72 | if (!b) return true; 73 | return strcmp(a, b) < 0; 74 | } 75 | static inline bool str_le(const char* a, const char* b) { 76 | if (!a) return !b; 77 | if (!b) return true; 78 | return strcmp(a, b) <= 0; 79 | } 80 | static inline int str_find(const char* haystack, const char* needle) { 81 | if (!haystack) return -1; 82 | const char* res = strstr(haystack, needle); 83 | if (!res) return -1; 84 | return res - haystack; 85 | } 86 | static inline const char* str_max(const char* a, const char* b) { 87 | return str_lt(a, b) ? b : a; 88 | } 89 | static inline const char* str_min(const char* a, const char* b) { 90 | return str_lt(a, b) ? a : b; 91 | } 92 | static inline bool str_eq(const char* a, const char* b) { 93 | if (!a || !b) return a == b; 94 | return strcmp(a, b) == 0; 95 | } 96 | static inline int str_atoi(const char* a) { return atoi(a); } 97 | static inline double str_atof(const char* a) { return atof(a); } 98 | 99 | #define TACO_MIN(a, b) ((a) < (b) ? (a) : (b)) 100 | 101 | #define macro_ternary(c, x, y) ((c) ? x : y) 102 | #define index_map(a, ...) 
&a[{__VA_ARGS__}] 103 | -------------------------------------------------------------------------------- /common.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace { 11 | 12 | // https://stackoverflow.com/a/6245777/1937836 13 | namespace aux { 14 | template 15 | struct seq {}; 16 | 17 | template 18 | struct gen_seq : gen_seq {}; 19 | 20 | template 21 | struct gen_seq<0, Is...> : seq {}; 22 | 23 | template 24 | void print_tuple(std::basic_ostream& os, Tuple const& t, seq) { 25 | using swallow = int[]; 26 | (void)swallow{0, 27 | (void(os << (Is == 0 ? "" : ", ") << std::get(t)), 0)...}; 28 | } 29 | } // namespace aux 30 | 31 | template 32 | auto operator<<(std::basic_ostream& os, std::tuple const& t) 33 | -> std::basic_ostream& { 34 | aux::print_tuple(os, t, aux::gen_seq()); 35 | return os; 36 | } 37 | 38 | template 39 | std::ostream& operator<<(std::ostream& os, 40 | const std::unordered_map& m) { 41 | for (auto&& p : m) { 42 | os << p.first << ": " << p.second << '\n'; 43 | } 44 | return os; 45 | } 46 | 47 | template 48 | void time(F f, char const* tag, int reps) { 49 | using fsec = std::chrono::duration; 50 | using usec = std::chrono::microseconds; 51 | using std::chrono::steady_clock; 52 | using std::chrono::system_clock; 53 | 54 | auto as_fsec = [](auto dur) { return std::chrono::duration_cast(dur); }; 55 | auto tv_diff = [](timeval start, timeval end) { 56 | auto start_sec = system_clock::from_time_t(start.tv_sec); 57 | auto end_sec = system_clock::from_time_t(end.tv_sec); 58 | return std::chrono::duration_cast(end_sec - start_sec) + 59 | (usec(end.tv_usec) - usec(start.tv_usec)); 60 | }; 61 | 62 | // Unfortunately, std::chrono::duration::operator<<() (C++20) still isn't well-supported yet. 
63 | auto sec_to_str = [](auto secs) { 64 | return std::to_string(secs.count()) + "s"; 65 | }; 66 | 67 | steady_clock::duration total_real{0}; 68 | usec total_user{0}; 69 | usec total_sys{0}; 70 | 71 | for (int i = 0; i < reps; i++) { 72 | struct rusage s_start; 73 | getrusage(RUSAGE_SELF, &s_start); 74 | auto rep_start = steady_clock::now(); 75 | 76 | auto val = f(); 77 | 78 | auto rep_end = steady_clock::now(); 79 | struct rusage s_end; 80 | getrusage(RUSAGE_SELF, &s_end); 81 | 82 | auto udiff = tv_diff(s_start.ru_utime, s_end.ru_utime); 83 | auto sdiff = tv_diff(s_start.ru_stime, s_end.ru_stime); 84 | std::cout << tag << " val: " << std::fixed << val << std::endl; 85 | std::cout << tag << " took (s): real " << sec_to_str(as_fsec(rep_end - rep_start)) 86 | << " user " << sec_to_str(as_fsec(udiff)) << " sys " << sec_to_str(as_fsec(sdiff)) 87 | << std::endl; 88 | 89 | total_real += rep_end - rep_start; 90 | total_user += udiff; 91 | total_sys += sdiff; 92 | } 93 | 94 | std::cout << tag << " took (avg): real " << sec_to_str(as_fsec(total_real) / reps) 95 | << " user " << sec_to_str(as_fsec(total_user) / reps) << " sys " 96 | << sec_to_str(as_fsec(total_sys) / reps) << std::endl; 97 | } 98 | 99 | namespace hash_tuple { 100 | 101 | template 102 | struct hash { 103 | size_t operator()(TT const& tt) const { return std::hash()(tt); } 104 | }; 105 | 106 | namespace { 107 | template 108 | inline void hash_combine(std::size_t& seed, T const& v) { 109 | seed ^= hash()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); 110 | } 111 | 112 | template ::value - 1> 113 | struct HashValueImpl { 114 | static void apply(size_t& seed, Tuple const& tuple) { 115 | HashValueImpl::apply(seed, tuple); 116 | hash_combine(seed, std::get(tuple)); 117 | } 118 | }; 119 | 120 | template 121 | struct HashValueImpl { 122 | static void apply(size_t& seed, Tuple const& tuple) { 123 | hash_combine(seed, std::get<0>(tuple)); 124 | } 125 | }; 126 | } // namespace 127 | 128 | template 129 | struct hash> { 130 
| size_t operator()(std::tuple const& tt) const { 131 | size_t seed = 0; 132 | HashValueImpl>::apply(seed, tt); 133 | return seed; 134 | } 135 | }; 136 | 137 | } // namespace hash_tuple 138 | 139 | } // namespace 140 | -------------------------------------------------------------------------------- /Etch/InductiveStreamDeriving.lean: -------------------------------------------------------------------------------- 1 | import Etch.InductiveStreamCompile 2 | 3 | namespace Etch.Deriving.AttrOrder 4 | 5 | open Lean Elab Command Term Meta 6 | 7 | structure Options (α : Type u) where 8 | (order : List α) 9 | (localDecl : Bool := false) 10 | structure OptionsT where 11 | (order : TSyntaxArray `term) 12 | (localDecl : Bool) 13 | 14 | private def mkOptions (declName : Name) (stx : TSyntax ``Parser.Term.structInst) : TermElabM OptionsT := do 15 | -- First do a type check just to be sure 16 | let expectedType ← elabType (← `($(mkIdent ``Options) $(mkIdent declName))) 17 | let expr ← StructInst.elabStructInst stx.raw (some expectedType) 18 | let typ ← inferType expr 19 | if expr.hasSorry || typ.hasSorry then 20 | throwErrorAt stx "expression contains sorry" 21 | trace[Elab.Deriving.attr_order_total] m!"elabStructInst returns {expr}; with type {typ}; expected {expectedType}" 22 | 23 | -- Extract elements from syntax. 
24 | match stx.raw with 25 | | `(Parser.Term.structInst| { $fields* }) 26 | | `(Parser.Term.structInst| { $fields* : $_ }) => do 27 | let mut order := none 28 | let mut localDecl := false 29 | for field in fields.getElems do 30 | match field with 31 | | `(Parser.Term.structInstField| order := [$elems,*]) => 32 | order := some elems.getElems 33 | | `(Parser.Term.structInstField| order := $_) => 34 | throwErrorAt field "incorrect argument format; must be order := [attr1, attr2, …]" 35 | | `(Parser.Term.structInstField| localDecl := true) => 36 | localDecl := true 37 | | `(Parser.Term.structInstField| localDecl := false) => 38 | localDecl := false 39 | | `(Parser.Term.structInstField| localDecl := $v) => 40 | throwErrorAt v "incorrect argument format; must be true or false" 41 | | _ => throwErrorAt field "unknown field" 42 | match order with 43 | | some elems => return { order := elems, localDecl : OptionsT } 44 | | none => throwErrorAt stx "missing order field" 45 | | _ => throwError "incorrect argument format; must be \{ order := [attr1, attr2, …] }" 46 | 47 | def mkOrder (declName : Name) (options : OptionsT) : CommandElabM Bool := do 48 | let id := mkIdent declName 49 | let prepareName (n : String) : MacroM Name := 50 | if options.localDecl then 51 | -- Logic similar to `Lean.Elab.Command.mkInstanceName` 52 | let rec getSuffix (m : Name) := match m with 53 | | .anonymous => "" 54 | | .str m "" => getSuffix m 55 | | .str _ s => s 56 | | .num m _ => getSuffix m 57 | let n := if "inst".isPrefixOf n then 58 | n ++ (getSuffix declName.eraseMacroScopes).capitalize 59 | else 60 | n 61 | mkUnusedBaseName <| Name.mkSimple n 62 | else 63 | pure <| `_root_ ++ declName.str n 64 | let ordID := mkIdent <| (← liftMacroM <| prepareName "order") 65 | let instOrdID := mkIdent <| (← liftMacroM <| prepareName "instAttrOrder") 66 | let ordHereID := mkIdent <| (← liftMacroM <| prepareName "orderHere") 67 | let instOrdTotalID := mkIdent <| (← liftMacroM <| prepareName 
"instAttrOrderTotal") 68 | 69 | let ord ← `(command| 70 | @[reducible] def $ordID : $(mkIdent ``Shape) $id := 71 | ⟨[$(options.order):term,*], by decide⟩) 72 | let instOrd ← `(command| 73 | @[reducible] instance $instOrdID:ident : $(mkIdent ``AttrOrder) $id := 74 | ⟨$ordID⟩) 75 | 76 | let mut ordHereCases := #[] 77 | for arg in options.order do 78 | let alt ← `(Parser.Term.matchAltExpr| | $arg => $(mkIdent ``List.Find.mem)) 79 | ordHereCases := ordHereCases.push alt 80 | let ordHere ← `(command| 81 | def $ordHereID : ∀ (i : $id), $(mkIdent ``List.MemT) i ($(mkIdent ``Shape.val) ($(mkIdent ``AttrOrder.order) (self := $instOrdID))) 82 | $ordHereCases:matchAlt*) 83 | 84 | let instOrdTotal ← `(command| 85 | instance $instOrdTotalID:ident : $(mkIdent ``AttrOrderTotal) $id := 86 | ⟨$ordHereID⟩) 87 | 88 | let cmds := #[ord, instOrd, ordHere, instOrdTotal] 89 | for cmd in cmds do 90 | trace[Elab.Deriving.attr_order_total] cmd 91 | elabCommand cmd 92 | return true 93 | 94 | def mkOrderTotalInstanceHandler (declNames : Array Name) (args? : Option (TSyntax ``Parser.Term.structInst)) : CommandElabM Bool := do 95 | if declNames.size != 1 then 96 | return false -- mutually inductive types are not supported yet 97 | else if let some args := args? then 98 | let opts ← liftTermElabM <| mkOptions declNames[0]! args 99 | mkOrder declNames[0]! 
opts 100 | else 101 | return false -- don't support automatically forming an order yet 102 | 103 | initialize 104 | registerDerivingHandlerWithArgs ``AttrOrderTotal mkOrderTotalInstanceHandler 105 | registerTraceClass `Elab.Deriving.attr_order_total 106 | 107 | end Etch.Deriving.AttrOrder 108 | -------------------------------------------------------------------------------- /Etch/StreamFusion/Tutorial.lean: -------------------------------------------------------------------------------- 1 | /- very WIP tutorial for the library -/ 2 | 3 | import Etch.StreamFusion.Stream 4 | import Etch.StreamFusion.Expand 5 | import Etch.StreamFusion.TestUtil 6 | 7 | import Std.Data.RBMap 8 | import Std.Data.HashMap 9 | 10 | open Std (RBMap RBSet HashMap) 11 | 12 | namespace Etch.Verification 13 | 14 | open ToStream 15 | 16 | variable {I J K L α β : Type} 17 | [LinearOrder I] [LinearOrder J] [LinearOrder K] [LinearOrder L] 18 | [Scalar α] [Mul α] [Zero α] [Add α] 19 | 20 | def_index_enum_group i,j,k,l 21 | 22 | /- 23 | Some coercion examples 24 | -/ 25 | 26 | def mul_fns [ToStream t (I → J → α)] [ToStream t' (J → K → α)] (a : t) (b : t') 27 | : i~I → j~J → k~K → α := 28 | a(i,j) * b(j,k) 29 | 30 | def mul_fns' [ToStream t (I → J → α)] [ToStream t' (J → K → α)] (a : t) (b : t') := 31 | a(i,j) * b(j,k) 32 | 33 | section 34 | --set_option trace.Meta.synthInstance true 35 | #synth ExpressionTree.EnsureBroadcast [(0, I), (1, J), (2, K)] α (j~J → k~K →ₛ α) _ 36 | end 37 | 38 | 39 | -- Notice, no Broadcast helper class, it was unfolded 40 | #print mul_fns' 41 | 42 | --def testContractElab (A : I →ₛ J →ₛ α) (B : J →ₛ K →ₛ α) := Σ j k => (Σ i => A(i,j)) * B(j,k) 43 | -- i~Unit →ₛ j~Unit →ₛ k~K →ₛ α 44 | --#print testContractElab 45 | /- 46 | Contract.contract j 47 | (Contract.contract i ([(i, I), (j, J), (k, K)] ⇑ Label.label [i, j] A) * 48 | [(i, Unit), (j, J), (k, K)] ⇑ Label.label [j, k] B) 49 | -/ 50 | 51 | @[inline] def testSelect (m : I →ₛ J →ₛ α) (v : J →ₛ α) := memo SparseArray I α 
from select i => m(i, j) * v(j) 52 | -- I →ₛ α 53 | 54 | /- Some examples of notation 55 | 56 | notes: 57 | - a shape is a list of (Nat, Type) pairs 58 | - a collection needs to have a ToStream instance 59 | - indices are index names encoded as natural numbers for now 60 | -/ 61 | 62 | @[inline] def vecSum (v : I →ₛ α) := Σ i => v(i) 63 | @[inline] def matSum (m : I →ₛ J →ₛ α) (v : J →ₛ α) := Σ i j => m(i, j) * v(j) 64 | 65 | 66 | @[inline] def matMul_ijjk {α J} [LinearOrder J] [Mul α] [Scalar α] (a : I →ₛ J →ₛ α) (b : J →ₛ K →ₛ α) := 67 | Σ j => a(i,j) * b(j,k) 68 | 69 | open ToStream 70 | --open OfStream 71 | open SStream 72 | 73 | variable [Hashable K] 74 | 75 | -- todo: investigate these definitions and other approaches 76 | @[inline] def ABC_ 77 | (a : I →ₛ J →ₛ α) 78 | (b : J →ₛ K →ₛ α) 79 | (c : K →ₛ L →ₛ α) := 80 | let m1 := a(i,j) 81 | let m2 := b(j,k) 82 | let m3 := c(k,l) 83 | let x : SparseArray I (HashMap K α) := eval $ Σ j => m1 * m2 84 | let m := (stream x)(i,k) * m3 85 | Σ k => m 86 | 87 | @[inline] def ABC' (a : I →ₛ J →ₛ α) (b : J →ₛ K →ₛ α) (c : K →ₛ L →ₛ α) := 88 | let ijk := [(i,I),(j,J),(k,K)] 89 | let m1 := ijk ⇑ a(i,j) 90 | let m := m1.map fun row => memo HashMap K α from Σ j => row * b(j,k) 91 | let m := m(i,k) * c(k,l) 92 | Σ k => m 93 | 94 | @[inline] def ABC (a : I →ₛ J →ₛ α) (b : J →ₛ K →ₛ α) (c : K →ₛ L →ₛ α) := 95 | Σ j k => a(i,j)*b(j,k)*c(k,l) 96 | 97 | --@[inline] def ABC_memo' (a : I →ₛ J →ₛ α) (b : J →ₛ K →ₛ α) (c : K →ₛ L →ₛ α) := 98 | -- Σ k => memo(Σ j=> a(i,j)*b(j,k) with SparseArray I (HashMap K α)) * c(k,l) 99 | 100 | @[inline] def ABC_memo (a : I →ₛ J →ₛ α) (b : J →ₛ K →ₛ α) (c : K →ₛ L →ₛ α) := 101 | let ijk := [(i,I),(j,J),(k,K)] 102 | let m1 := ijk ⇑ a(i,j) 103 | let m := m1.map fun row => 104 | memo HashMap K α from Σ j => row * b(j,k) 105 | let m := m(i,k) * c(k,l) 106 | Σ k => m 107 | 108 | def mat' (num : ℕ) := sparseMat num.sqrt 109 | 110 | def matMul1 (num : ℕ) : IO Unit := do 111 | let m := stream $ mat' num 112 
| let x := matMul_ijjk m m 113 | time "matrix 1'" fun _ => 114 | for _ in [0:10] do 115 | let x : HashMap ℕ (HashMap ℕ ℕ) := eval x 116 | IO.println s!"{x.1.size}" 117 | 118 | def matMul1' (num : ℕ) : IO Unit := do 119 | let m := stream $ mat' num 120 | let x := matMul_ijjk m m 121 | let x := Σ i k => matMul_ijjk m m 122 | time "matrix 1'" fun _ => 123 | for _ in [0:10] do 124 | let x : ℕ := eval x 125 | IO.println s!"{x}" 126 | 127 | def testABC (num : ℕ) : IO Unit := do 128 | let m := stream $ mat' num 129 | time "matrix abc" fun _ => 130 | for _ in [0:10] do 131 | let x : SparseArray ℕ (HashMap ℕ ℕ) := eval $ ABC m m m 132 | IO.println s!"{x.1.size}" 133 | 134 | def testABC' (num : ℕ) : IO Unit := do 135 | let m := stream $ mat' num 136 | time "matrix abc'" fun _ => 137 | for _ in [0:10] do 138 | let x : SparseArray ℕ (HashMap ℕ ℕ) := eval $ ABC' m m m 139 | IO.println s!"{x.1.size}" 140 | 141 | def _root_.main (args : List String) : IO Unit := do 142 | let num := (args[0]!).toNat?.getD 1000 143 | IO.println s!"test of size {num}" 144 | IO.println "starting" 145 | --matMul1 num 146 | --matMul1' num 147 | testABC num 148 | testABC' num 149 | 150 | open ToStream 151 | -------------------------------------------------------------------------------- /Etch/KRelation.lean: -------------------------------------------------------------------------------- 1 | import Mathlib.Data.Nat.Basic 2 | import Mathlib.Data.Finset.Card 3 | import Mathlib.Data.Finsupp.Basic 4 | import Mathlib.Data.Fintype.Basic 5 | import Mathlib.Data.Option.Basic 6 | import Mathlib.Data.Set.Finite 7 | import Mathlib.Algebra.BigOperators.Basic 8 | import Mathlib.Logic.Function.Basic 9 | import Mathlib.Tactic.LibrarySearch 10 | import Etch.Basic 11 | 12 | -- the class does not need to reference K, the semiring of values 13 | class PositiveAlgebra {A : Type} [DecidableEq A] (α : Finset A → Type) where 14 | finite : ∀ {S : Finset A}, α S → Type _ 15 | equiv : A → A → Prop 16 | 17 | mul (a b : α S) : α S 
18 | expand (i : A) (S : Finset A) : α S → α (insert i S) 19 | expand_sub (sub : S ⊆ S') : α S → α S' 20 | contract (i : A) (S : Finset A) (s : α S) (fin : finite s) : α (Finset.erase S i) 21 | contract_sub (sub : S ⊆ S') (s : α S') (fin : finite s) : α S 22 | rename (S : Finset A) (ρ : S → A) (equiv : (i : S) → equiv i (ρ i)) : α S → α (S.attach.image ρ) 23 | 24 | section KRel 25 | variable (K : Type) [Semiring K] {A : Type} [DecidableEq A] (I : A → Type) (S : Finset A) [(i : A) → DecidableEq (I i)] 26 | 27 | abbrev Tuple := (s : S) → I s 28 | 29 | instance : EmptyCollection (Tuple I Finset.empty) := ⟨ (nomatch .) ⟩ 30 | instance : Inhabited (Tuple I Finset.empty) := ⟨ {} ⟩ 31 | 32 | #synth DecidableEq (Tuple I S) 33 | 34 | def KRel := Tuple I S → K 35 | instance : Semiring (KRel K I S) := Pi.semiring 36 | 37 | variable {I} {S} 38 | def Tuple.project {S S' : Finset A} (sub : S ⊆ S') (t : Tuple I S') : Tuple I S := fun ⟨i, mem⟩ ↦ t ⟨i, Finset.mem_of_subset sub mem⟩ 39 | def Tuple.erase (i : A) (t : Tuple I S) : Tuple I (S.erase i) := t.project (S.erase_subset _) 40 | def Tuple.erase' (i : A) (t : Tuple I (insert i S)) : Tuple I S := t.project (S.subset_insert _) 41 | 42 | instance KRel.positiveAlgebra : PositiveAlgebra (KRel K I) where 43 | finite f := Σ' supp : Finset (Tuple I _), ∀ x, f x ≠ 0 → x ∈ supp 44 | equiv := (I . = I .) 
45 | mul a b := a * b 46 | expand i _ f x := f (x.erase' i) 47 | contract i S f fin := fun t ↦ fin.1.filter (fun t' : Tuple I S ↦ t'.erase i = t) |>.sum f 48 | expand_sub sub f x := f (x.project sub) 49 | contract_sub sub f fin := fun t ↦ fin.1.filter (fun t' : Tuple I _ ↦ t'.project sub = t) |>.sum f 50 | rename S ρ equiv f t := f (fun (a : S) ↦ equiv a ▸ t ⟨ ρ a, Finset.mem_image_of_mem _ (Finset.mem_attach _ a) ⟩ ) 51 | 52 | #check @KRel.positiveAlgebra 53 | instance [h : PositiveAlgebra 𝓣] : Mul (𝓣 S) := ⟨ h.mul ⟩ 54 | 55 | instance : One (KRel K I S) := ⟨ fun _ ↦ 1 ⟩ 56 | instance : Zero (KRel K I S) := ⟨ fun _ ↦ 0 ⟩ 57 | 58 | namespace KRel 59 | variable {K} 60 | 61 | def singleton (t : Tuple I S) (v : K) : KRel K I S := fun t' ↦ if t = t' then v else 0 62 | 63 | @[simp] def singleton_zero {v : K} (t t' : Tuple I S) (h : t ≠ t') : singleton t v t' = 0 := by simp [singleton, h] 64 | @[simp] def add_hom (f g : KRel K I S) : (f + g) x = f x + g x := rfl -- forgot what this is called 65 | 66 | def ofList (l : List (Tuple I S × K)) : KRel K I S := l.map (fun (k, v) ↦ KRel.singleton k v) |>.sum 67 | 68 | @[simp] def ofList_cons (kv : Tuple I S × K) : ofList (kv :: l) = singleton kv.fst kv.snd + ofList l := by simp [ofList] 69 | @[simp] def ofList_nil_eq_zero : ofList [] = (0 : KRel K I S) := by simp [ofList] 70 | 71 | def finite_ofList (l : List (Tuple I S × K)) : PositiveAlgebra.finite (KRel.ofList l) where 72 | fst := l.map Prod.fst |>.toFinset 73 | snd t neq_zero := by 74 | induction l with 75 | | nil => cases neq_zero rfl 76 | | cons kv l ih => 77 | simp only [List.map, List.toFinset_cons, List.mem_toFinset, Finset.mem_insert] at ih ⊢ 78 | by_cases h : kv.fst = t 79 | . left; exact h.symm 80 | . 
right; apply ih; simpa [h] using neq_zero 81 | 82 | def Tuple.nil : Tuple I {} := {} 83 | def nil : Tuple I {} := Tuple.nil 84 | 85 | section examples 86 | 87 | abbrev I1 : Fin 2 → Type := fun _ ↦ Fin 3 88 | def t1 : Tuple I1 {0} := fun | ⟨0, _⟩ => 0 89 | def t2 : Tuple I1 {0} := fun | ⟨0, _⟩ => 1 90 | def l0 := [(t1, 1), (t2, 3)] 91 | def f0 : KRel ℕ I1 {0} := ofList l0 92 | def f1 : KRel ℕ I1 {0} := 1 93 | #synth PositiveAlgebra (KRel ℕ (fun _ : Fin 2 => Fin 2)) 94 | 95 | def f2 : KRel ℕ (fun _ : Fin 2 ↦ Fin 2) {0,1} := 1 96 | 97 | open PositiveAlgebra 98 | def f0_finite : finite f0 := KRel.finite_ofList l0 99 | 100 | notation:15 "∑ " i "," a => contract i _ (ofList a) (finite_ofList a) 101 | --notation:15 "∑ " s "," a => contract_sub (by decide : s ⊆ _) (ofList a) (finite_ofList a) 102 | 103 | #check contract 0 _ (ofList l0) (finite_ofList l0) 104 | def asdf : ({0} : Finset ℕ) ⊆ {0,1} := by decide 105 | #print asdf.proof_1 106 | 107 | #check l0 108 | def fff : ({} : Finset (Fin 2)) ⊆ {0} := by decide 109 | #reduce contract_sub fff (ofList l0) (finite_ofList l0) 110 | #reduce contract 0 _ f0 f0_finite 111 | #eval contract 0 _ f0 f0_finite nil 112 | #check contract 1 _ (ofList l0) (finite_ofList l0) 113 | #eval (∑ 0, l0) nil 114 | #eval (∑ 1, l0) t2 115 | 116 | end examples 117 | 118 | end KRel 119 | -------------------------------------------------------------------------------- /graphs/run.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | mkdir -p bench-output 4 | 5 | ############################# Preparations 6 | 7 | # Lean / Binaries 8 | make -j$(nproc) 9 | 10 | # filtered SpMV 11 | (for nonzeros in 2000000; do 12 | echo data/filtered-spmv-$nonzeros.db 13 | done) | xargs make -j$(nproc) 14 | 15 | # TACO 16 | (for size in 0.0001 0.0003 0.0007 0.001 0.003 0.007 0.01 0.03 0.07 0.1 0.3 0.5; do 17 | echo data/taco-s$size.db 18 | done) | xargs make -j$(nproc) 19 | 20 | # TPC-H 21 | (for size in x0.01 x0.025 x0.05 
x0.1 x0.25 x0.5 x1 x2 x4; do 22 | echo data/TPC-H-$size.db 23 | done) | xargs make -j$(nproc) 24 | 25 | # WCOJ 26 | (for size in x1 x3 x10 x30 x100 x300 x1000 x3000 x10000; do 27 | echo data/wcoj-csv-$size data/wcoj-$size.db 28 | done) | xargs make -j$(nproc) 29 | 30 | 31 | ############################# Benchmarks 32 | 33 | # filtered SpMV 34 | for nonzeros in 2000000; do 35 | for db in duckdb etch sqlite; do 36 | echo run-filtered-spmv-$nonzeros-$db 37 | 38 | # warm up 39 | make run-filtered-spmv-$nonzeros-$db >/dev/null 40 | make run-filtered-spmv-$nonzeros-$db >/dev/null 41 | 42 | rm -f bench-output/run-filtered-spmv-$nonzeros-$db.txt  # fixed: was run-wcoj-…, so stale filtered-spmv results accumulated across runs 43 | for i in `seq 5`; do 44 | make run-filtered-spmv-$nonzeros-$db >>bench-output/run-filtered-spmv-$nonzeros-$db.txt 45 | done 46 | done 47 | done 48 | 49 | # TACO 50 | for size in 0.0001 0.0003 0.0007 0.001 0.003 0.007 0.01 0.03 0.07 0.1 0.3 0.5; do 51 | echo run-taco-s$size 52 | 53 | # warm up 54 | make run-taco-s$size >/dev/null 55 | make run-taco-s$size >/dev/null 56 | 57 | rm -f bench-output/run-taco-s$size.txt 58 | for i in `seq 5`; do 59 | make run-taco-s$size >>bench-output/run-taco-s$size.txt 60 | done 61 | done 62 | 63 | # TPC-H Q5 64 | for size in x0.01 x0.025 x0.05 x0.1 x0.25 x0.5 x1 x2 x4; do 65 | dbs='duckdb duckdbforeign etch sqlite' 66 | 67 | for db in $dbs; do 68 | echo run-tpch-$size-q5-$db 69 | 70 | # warm up 71 | make run-tpch-$size-q5-$db >/dev/null 72 | make run-tpch-$size-q5-$db >/dev/null 73 | 74 | rm -f bench-output/run-tpch-$size-q5-$db.txt 75 | for i in `seq 5`; do 76 | # echo run-tpch-$size-q5-$db time $i 77 | make run-tpch-$size-q5-$db >>bench-output/run-tpch-$size-q5-$db.txt 78 | done 79 | done 80 | sed -n 's/^q2 took (s): real \([^ ]*\)s.*/\1/p' /dev/null 95 | make run-tpch-$size-q9-$db >/dev/null 96 | 97 | rm -f bench-output/run-tpch-$size-q9-$db.txt 98 | for i in `seq 5`; do 99 | # echo run-tpch-$size-q9-$db time $i 100 | make run-tpch-$size-q9-$db >>bench-output/run-tpch-$size-q9-$db.txt 101 | done 102 |
done 103 | sed -n 's/^q2 took (s): real \([^ ]*\)s.*/\1/p' /dev/null 123 | make run-wcoj-$size-$db >/dev/null 124 | 125 | rm -f bench-output/run-wcoj-$size-$db.txt 126 | for i in `seq 5`; do 127 | # echo run-wcoj-$size-$db time $i 128 | make run-wcoj-$size-$db >>bench-output/run-wcoj-$size-$db.txt 129 | done 130 | done 131 | 132 | if [[ $dbs =~ duckdb ]]; then 133 | sed -n 's/^q2 took (s): real \([^ ]*\)s.*/\1/p' .snd 6 | end C 7 | 8 | open C 9 | 10 | def emitString [ToString α] (a : α) : M Unit := modify (· ++ toString a) 11 | 12 | namespace String 13 | def emit (a : String) : M Unit := _root_.modify $ λ s => s ++ a 14 | def emitStart (a : String) : M Unit := λ indent => _root_.modify $ λ s => s ++ indent ++ a 15 | end String 16 | 17 | def emitStart (a : String) : M Unit := λ indent => _root_.modify $ λ s => s ++ indent ++ a 18 | 19 | def emitLine [ToString α] (a : α) : M Unit := 20 | emitString a *> emitString "\n" 21 | 22 | def emitLines [ToString α] (a : Array α) : Op := 23 | a.forM emitLine 24 | 25 | inductive Expr 26 | | lit (n : Int) 27 | | litf (f : Float) 28 | | lits (s : String) 29 | | var (v : Var) 30 | | index (base : Expr) (indices : List Expr) 31 | | star (addr : Expr) 32 | | mul (exprs : List Expr) 33 | | binOp : String → Expr → Expr → Expr 34 | | ternary : Expr → Expr → Expr → Expr 35 | | extern : String → Expr 36 | | call : String → List Expr → Expr 37 | | true 38 | | false 39 | deriving Repr 40 | -- todo? 
inductive LHS | var (v : Var) | index (base : LHS) (indices : List Expr) deriving Repr 41 | inductive DeclType | mk : String → DeclType 42 | deriving Repr 43 | inductive Stmt 44 | | forIn : (n : Nat) → Var → Stmt → Stmt 45 | | while : Expr → Stmt → Stmt 46 | | cond : Expr → Stmt → Stmt 47 | | conde : Expr → Stmt → Stmt → Stmt 48 | | accum : Expr → Expr → Stmt 49 | | store : Expr → Expr → Stmt 50 | | decl : DeclType → Var → Expr → Stmt 51 | | seq : Stmt → Stmt → Stmt 52 | | noop : Stmt 53 | | extern : String → Stmt 54 | | block : Stmt → Stmt 55 | | break_ 56 | deriving Repr 57 | 58 | class TaggedC (α : Type _) where 59 | tag : DeclType 60 | 61 | instance : TaggedC Nat := ⟨⟨"int"⟩⟩ 62 | instance : TaggedC Int := ⟨⟨"int"⟩⟩ 63 | instance : TaggedC Float := ⟨⟨"float"⟩⟩ 64 | instance : TaggedC Bool := ⟨⟨"bool"⟩⟩ 65 | instance : TaggedC String := ⟨⟨"const char *"⟩⟩ 66 | 67 | instance : OfNat Expr n where 68 | ofNat := Expr.lit n 69 | 70 | def Stmt.sequence : List Stmt → Stmt 71 | | [] => Stmt.noop 72 | | x :: xs => Stmt.seq x $ sequence xs 73 | 74 | def String.wrap (s : String) : String := "(" ++ s ++ ")" 75 | 76 | namespace Expr 77 | 78 | def wrap (s : String) : String := String.wrap s 79 | partial def toString : Expr → String 80 | | lit n => ToString.toString n 81 | | litf n => ToString.toString n 82 | | lits s => "\"" ++ s ++ "\"" -- TODO: escape 83 | | var v => v 84 | | index n indices => toString n ++ (indices.map λ i => "[" ++ toString i ++ "]").foldl String.append "" 85 | | star addr => "*" ++ addr.toString 86 | | mul es => wrap $ String.join (List.intersperse "*" (es.map toString)) 87 | | binOp op a b => wrap $ a.toString ++ op ++ b.toString 88 | | ternary cond a b => wrap $ cond.toString.wrap ++ "?" 
++ a.toString.wrap ++ ":" ++ b.toString.wrap 89 | | extern s => s 90 | | call f args => let as := String.join (List.intersperse "," (args.map toString)); s!"{f}({as})" 91 | | true => "true" 92 | | false => "false" 93 | 94 | instance : ToString Expr where 95 | toString := Expr.toString 96 | 97 | def emit : Expr → Op := String.emit ∘ toString 98 | 99 | end Expr 100 | 101 | def DeclType.emit 102 | | mk s => s.emit 103 | 104 | namespace Stmt 105 | 106 | def semicolon := emitLine ";" 107 | 108 | def lf := "".emitStart 109 | def indentUnit : String := " " 110 | def indent : M a → M a := ReaderT.adapt (. ++ indentUnit) 111 | def emit : Stmt → Op 112 | | extern s => emitString s 113 | | accum lhs value => do lf; lhs.emit; " += ".emit; value.emit; emitLine ";" 114 | | store lhs value => do lf; lhs.emit; " = ".emit; value.emit; emitLine ";" 115 | | forIn bound var body => do 116 | lf; emitString s!"for (int {var} = 0; {var} < {bound}; {var}++) \{\n" 117 | indent body.emit 118 | lf; emitLine "}" 119 | | Stmt.while condition body => do lf; "while (".emit ; condition.emit; ") {\n".emit; indent body.emit; lf; emitLine "}" 120 | | cond condition a => do lf; "if (".emit; condition.emit; ") {\n".emit; indent a.emit; lf; emitLine "}" 121 | | conde condition thenb elseb => do lf; "if (".emit; condition.emit; ") {\n".emit; indent thenb.emit; lf; "} else {\n".emit; indent elseb.emit; lf; emitLine "}" 122 | | decl type name value => do lf; type.emit; " ".emit; name.emit; emitString " = "; value.emit; semicolon 123 | | seq p1 p2 => do emit p1; emit p2 124 | | block s => do lf; emitLine "{"; s.emit; lf; emitLine "}" 125 | | break_ => do lf; emitLine "break;" 126 | | noop => pure () 127 | 128 | 129 | def emitIncludes : Op := do 130 | "#include \n".emit 131 | "#include \n".emit 132 | 133 | def emitPrintf : Op := emitLine "" 134 | 135 | def emitProgram (body : Stmt) : Op := do 136 | emitIncludes 137 | emitLine "int main() {" 138 | body.emit 139 | emitPrintf 140 | emitLine "return 0;" 141 | 
emitLine "}" 142 | 143 | def compile (p : Stmt) : String := p.emit |>.run 144 | def compileWithWrapper (p : Stmt) : String := p.emitProgram |>.run 145 | end Stmt 146 | -------------------------------------------------------------------------------- /Etch/Op.lean: -------------------------------------------------------------------------------- 1 | import Etch.Basic 2 | 3 | class Tagged (α : Type _) where 4 | tag : String 5 | 6 | def tag_mk_fun (α : Type _) [Tagged α] (fn : String) : String := 7 | (Tagged.tag α) ++ "_" ++ fn 8 | 9 | inductive R | mk 10 | 11 | instance : Tagged Unit := ⟨ "macro" ⟩ -- default type for actual monotypic function 12 | instance : Tagged ℕ := ⟨ "nat" ⟩ 13 | instance : Tagged Int := ⟨ "int" ⟩ 14 | instance : Tagged String := ⟨ "str" ⟩ 15 | instance : Tagged Bool := ⟨ "bool" ⟩ 16 | instance : Tagged R := ⟨ "num" ⟩ 17 | 18 | instance : Inhabited R := ⟨ R.mk ⟩ 19 | -- todo 20 | instance : Add R := ⟨ λ _ _ => default ⟩ 21 | instance : LT R := ⟨ λ _ _ => false ⟩ 22 | instance : DecidableRel (LT.lt : R → R → _) := λ .mk .mk => .isFalse (by simp [LT.lt] ) 23 | instance : LE R := ⟨ λ _ _ => false ⟩ 24 | instance : DecidableRel (LE.le : R → R → _) := λ .mk .mk => .isFalse (by simp [LE.le] ) 25 | 26 | instance : DecidableEq R := fun .mk .mk => .isTrue (by simp) 27 | instance : Max R := ⟨fun _ _ => default⟩ 28 | instance : Mul R := ⟨ λ _ _ => default ⟩ 29 | 30 | instance : Sub R := ⟨ λ _ _ => default ⟩ 31 | 32 | instance : OfNat R (nat_lit 0) := ⟨ default ⟩ 33 | instance : OfNat R (nat_lit 1) := ⟨ default ⟩ 34 | 35 | instance : Coe ℕ R := ⟨fun _ => default⟩ 36 | 37 | namespace String 38 | 39 | instance instLEString : LE String := ⟨fun s₁ s₂ ↦ s₁ < s₂ || s₁ = s₂⟩ 40 | 41 | instance decLe : @DecidableRel String (· ≤ ·) 42 | | s₁, s₂ => if h₁ : s₁ < s₂ then isTrue (by simp [instLEString, h₁]) 43 | else if h₂ : s₁ = s₂ then isTrue (by simp [instLEString, h₂]) 44 | else isFalse (by simp [instLEString, h₁, h₂]) 45 | 46 | instance zero : Zero String := ⟨""⟩ 
47 | 48 | instance max : Max String := ⟨fun s₁ s₂ ↦ if s₁ < s₂ then s₂ else s₁⟩ 49 | 50 | end String 51 | 52 | --attribute [irreducible] RMin 53 | --attribute [irreducible] RMax 54 | 55 | structure Op (α : Type _) where 56 | arity : ℕ 57 | argTypes : Fin arity → Type 58 | spec : ((n : Fin arity) → argTypes n) → α 59 | opName : String 60 | 61 | attribute [reducible] Op.argTypes 62 | attribute [simp] Op.spec 63 | 64 | -- def Op.name (f : Op β) : String := f.tag ++ "_" ++ f.opName 65 | 66 | def Op.lt [Tagged α] [LT α] [DecidableRel (LT.lt : α → α → _) ] : Op Bool where 67 | argTypes := ![α, α] 68 | spec := λ a => a 0 < a 1 69 | opName := tag_mk_fun α "lt" 70 | 71 | def Op.le [Tagged α] [LE α] [DecidableRel (LE.le : α → α → _) ] : Op Bool where 72 | argTypes := ![α, α] 73 | spec := λ a => a 0 ≤ a 1 74 | opName := tag_mk_fun α "le" 75 | 76 | def Op.max [Tagged α] [Max α] : Op α where 77 | argTypes := ![α, α] 78 | spec := λ a => Max.max (a 0) (a 1) 79 | opName := tag_mk_fun α "max" 80 | 81 | def Op.min [Tagged α] [Min α] : Op α where 82 | argTypes := ![α, α] 83 | spec := λ a => Min.min (a 0) (a 1) 84 | opName := tag_mk_fun α "min" 85 | 86 | @[simps, reducible] 87 | def Op.eq [Tagged α] [DecidableEq α] : Op Bool where 88 | argTypes := ![α, α] 89 | spec := λ a => a 0 = a 1 90 | opName := tag_mk_fun α "eq" 91 | 92 | @[simps] 93 | def Op.add [Tagged α] [Add α] : Op α where 94 | argTypes := ![α, α] 95 | spec := λ a => a 0 + a 1 96 | opName := tag_mk_fun α "add" 97 | 98 | @[simps] 99 | def Op.sub [Tagged α] [Sub α] : Op α where 100 | argTypes := ![α, α] 101 | spec := λ a => a 0 - a 1 102 | opName := tag_mk_fun α "sub" 103 | 104 | def Op.mid : Op ℕ where 105 | argTypes := ![ℕ, ℕ] 106 | spec := λ a => Nat.div (a 0 + a 1) 2 107 | opName := tag_mk_fun ℕ "mid" 108 | 109 | def Op.mul [Tagged α] [Mul α] : Op α where 110 | argTypes := ![α, α] 111 | spec := λ a => a 0 * a 1 112 | opName := tag_mk_fun α "mul" 113 | 114 | def Op.div [Tagged α] [HDiv α α β] : Op β where 115 | argTypes := 
![α, α] 116 | spec := λ a => a 0 / a 1 117 | opName := tag_mk_fun α "div" 118 | 119 | @[simps] 120 | def Op.neg : Op Bool where 121 | argTypes := ![Bool] 122 | spec := λ a => not $ a 0 123 | opName := tag_mk_fun Bool "neg" 124 | 125 | def Op.one [Tagged α] [OfNat α 1] : Op α where 126 | argTypes := ![] 127 | spec := λ _ => 1 128 | opName := tag_mk_fun α "one" 129 | 130 | def Op.zero [Tagged α] [OfNat α 0] : Op α where 131 | argTypes := ![] 132 | spec := λ _ => 0 133 | opName := tag_mk_fun α "zero" 134 | 135 | def Op.atoi : Op ℕ where 136 | argTypes := ![String] 137 | spec := λ _ => default 138 | opName := tag_mk_fun String "atoi" 139 | 140 | def Op.atof : Op R where 141 | argTypes := ![String] 142 | spec := λ _ => default -- todo 143 | opName := tag_mk_fun String "atof" 144 | 145 | def Op.ofBool [Tagged α] [OfNat α (nat_lit 0)] [OfNat α (nat_lit 1)] : Op α where 146 | argTypes := ![Bool] 147 | spec := λ a => if a 0 then 1 else 0 148 | opName := tag_mk_fun α "ofBool" 149 | 150 | def Op.toNum : Op R where 151 | argTypes := ![ℕ] 152 | spec := λ _ => default 153 | opName := tag_mk_fun ℕ "toNum" 154 | 155 | def Op.ternary : Op α where 156 | argTypes := ![Bool, α, α] 157 | spec := λ a => bif (a 0) then a 1 else a 2 158 | opName := "macro_ternary" 159 | 160 | def Op.access {ι α : Type} : Op α := 161 | { argTypes := ![ι → α, ι], 162 | spec := λ x => (x 0) (x 1), 163 | opName := "arr_access" } 164 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Etch 2 | 3 | This repository implements indexed streams, a representation for fused 4 | *contraction programs* like those found in sparse tensor algebra and relational 5 | algebra. 6 | 7 | Correctness proofs and compiler are written in the [Lean 4][lean4] language. 
8 | 9 | [lean4]: https://github.com/leanprover/lean4 10 | 11 | ## Directory structure 12 | 13 | ### Compiler and benchmarks 14 | 15 | ``` 16 | etch4 17 | ├── Etch # the compiler 18 | │   ├── Basic.lean # compiler core 19 | │   ├── C.lean 20 | │   ├── Compile.lean 21 | │   ├── LVal.lean 22 | │   ├── Op.lean 23 | │   ├── ShapeInference.lean 24 | │   ├── Stream.lean 25 | │   ├── Add.lean # basic streams 26 | │   ├── Mul.lean 27 | │   ├── Benchmark.lean # benchmark queries 28 | │   ├── Benchmark 29 | │   │   └── ... 30 | │   ├── Verification # stream model formalization 31 | │   │   └── README.md 32 | │   ├── KRelation.lean # work in progress 33 | │   ├── Omni.lean 34 | │   └── InductiveStream… 35 | ├── Makefile # workflows 36 | ├── bench # benchmark auxiliary files (SQL files, data generators, etc.) 37 | │   └── ... 38 | ├── bench-….cpp # benchmark drivers 39 | ├── common.h 40 | ├── operators.h 41 | ├── graphs # run benchmarks; plot graphs 42 | │   └── ... 43 | ├── taco # TACO-compiled kernels as baseline 44 | │   └── ... 45 | └── taco_kernels.c 46 | ``` 47 | 48 | ## Build compiler and proofs 49 | 50 | First install [Lean 4](https://leanprover.github.io/lean4/doc/quickstart.html). 51 | In the `etch4` directory, run 52 | ``` 53 | lake update 54 | lake exe cache get 55 | lake build 56 | ``` 57 | Then, load Etch/Benchmark.lean in your editor. 58 | 59 | ### Compile benchmarks 60 | 61 | While loading the Etch/Benchmark.lean in your editor, Lean will automatically 62 | compile all the benchmarks and write them to `etch4/gen_….c`. 63 | 64 | Alternatively, you can compile by running `make gen_tpch_q5.c FORCE_REGEN=1`. 65 | 66 | ### Running benchmarks 67 | 68 | Only Linux is supported right now. 
69 | 70 | Dependencies you need to have before running any benchmarks: 71 | * `make` ([GNU Make](https://www.gnu.org/software/make/)) 72 | * `clang` ([LLVM Clang](https://clang.llvm.org/)) 73 | * `bc` ([GNU bc](https://www.gnu.org/software/bc/)) 74 | 75 | Other dependencies (e.g., baselines) are automatically downloaded by the 76 | Makefile. 77 | 78 | #### Run individual benchmarks 79 | 80 | You can run an individual benchmark by calling `make`. Here are some examples: 81 | ``` 82 | make run-tpch-x1-q5-duckdb 83 | make run-tpch-x1-q5-duckdbforeign 84 | make run-tpch-x1-q5-etch 85 | make run-tpch-x1-q5-sqlite 86 | 87 | make run-wcoj-x1000-etch 88 | ``` 89 | 90 | If any test data need to be (re-)generated, the above commands will 91 | automatically do so. 92 | 93 | For a full list of supported targets, run `make list-benchmarks`. 94 | 95 | #### Run all benchmarks 96 | 97 | Run `graphs/run.sh`. 98 | This will run all the benchmarks shown in our PLDI 2023 paper, and save the 99 | results in bench-output/. 100 | 101 | Note: this will take a **long** time (≥1.5 hours). 102 | 103 | #### Generate graphs 104 | 105 | Make sure all the benchmarks have been run with `graphs/run.sh`. 106 | 107 | Make sure the benchmark results are stored in bench-output/. 108 | 109 | Make sure you have [Poetry](https://python-poetry.org/) (a Python package 110 | manager) installed. 111 | 112 | Then run: 113 | ``` 114 | cd graphs 115 | poetry install 116 | poetry shell 117 | cd .. 118 | python graphs/graph.py 119 | ``` 120 | 121 | Graphs will be generated in PDF form in the root directory. 122 | 123 | #### Running benchmarks in Docker 124 | 125 | We also provide a Dockerfile to simplify the process of setting up an 126 | environment. This is primarily useful for artifact evaluation. See 127 | `graphs/Dockerfile` for details. 128 | 129 | If you will be developing Etch locally, we recommend going through the previous 130 | steps instead. 
131 | 132 | ## Publications 133 | 134 | This repository implements indexed streams as defined in the paper: 135 | 136 | > Scott Kovach, Praneeth Kolichala, Tiancheng Gu, and Fredrik Kjolstad. 2023. 137 | > Indexed Streams: A Formal Intermediate Representation for Fused Contraction 138 | > Programs. To appear in Proc. ACM Program. Lang. 7, 139 | > PLDI, Article 154 (June 2023), 25 pages. https://doi.org/10.1145/3591268 140 | 141 | ## Old Correctness proofs 142 | 143 | These were written originally but recently automatically ported to Lean4. 144 | 145 | ``` 146 | . 147 | └── src 148 |    └── verification 149 |    ├── code_generation # WIP code generation proofs 150 |    │   └── ... 151 |    ├── semantics # correctness proofs 152 |    │   ├── README.md 153 |    │   └── ... 154 |    └── test.lean 155 | ``` 156 | 157 | ### Build old proofs 158 | 159 | First install [Lean 3](https://leanprover-community.github.io/get_started.html). 160 | In the root directory, run 161 | ``` 162 | leanproject get-mathlib-cache 163 | leanproject build 164 | ``` 165 | -------------------------------------------------------------------------------- /Etch/LVal.lean: -------------------------------------------------------------------------------- 1 | import Etch.Basic 2 | import Etch.Stream 3 | 4 | variable 5 | (ι : Type _) [Tagged ι] [DecidableEq ι] 6 | [LE ι] [DecidableRel (LE.le : ι → ι → _)] [LT ι] [DecidableRel (LT.lt : ι → ι → _)] 7 | {α : Type _} [Tagged α] [OfNat α (nat_lit 0)] 8 | 9 | abbrev loc := E ℕ 10 | structure il (ι : Type _) := (push' : (loc → P) → E ι → P × loc) 11 | structure vl (α : Type _) := (value : loc → α) (init : loc → P) 12 | structure lvl (ι α : Type _) := (push : E ι → P × α) -- (declare : P) (σ : Type) 13 | 14 | instance : Functor (lvl ι) := { map := λ f l => { push := Prod.map id f ∘ l.push } } 15 | 16 | def lvl.of {ι α} (i : il ι) (v : vl α) : lvl ι α := v.value <$> ⟨i.push' v.init⟩ 17 | 18 | variable {ι} 19 | 20 | infixl:20 "||" => Add.add 21 | 22 | structure 
MemLoc (α : Type) := (arr : Var (ℕ → α)) (ind : E ℕ) 23 | 24 | def MemLoc.access (m : MemLoc α) : E α := m.arr.access m.ind 25 | 26 | structure Dump (α : Type) where 27 | 28 | def sparse_il (ind_array : Var (ℕ → ι)) (bounds : MemLoc ℕ) : il ι := 29 | let array := bounds.arr 30 | let ind := bounds.ind 31 | let lower := array.access ind 32 | let upper := array.access (ind + 1) 33 | let loc := upper - 1 34 | let current := ind_array.access loc 35 | { push' := λ init i => 36 | let prog := P.if1 (lower == upper || i != current) 37 | (array.incr_array (ind + 1);; init loc);; 38 | P.store_mem ind_array loc i 39 | (prog, loc) } 40 | 41 | def dense_il (dim : E ℕ) (counter : Var ℕ) (base : E ℕ) : il ℕ := 42 | { push' := λ init i => 43 | let l (i : E ℕ) : loc := base * dim + i 44 | let prog : P := P.while (counter.expr <= i) (init (l counter);; counter.incr) 45 | (prog, l i) } 46 | 47 | def interval_vl (array : ArrayVar ℕ) : vl (MemLoc ℕ) := 48 | { value := λ loc => ⟨array, loc⟩, 49 | init := λ loc => .store_mem array (loc + 1) (.access array loc) } 50 | def dense_vl (array : ArrayVar α) : vl (MemLoc α) := 51 | { value := λ loc => ⟨array, loc⟩, 52 | init := λ loc => .store_mem array loc 0 } 53 | def implicit_vl : vl (E ℕ) := { value := id, init := λ _ => P.skip } 54 | def dump_vl : vl (Dump α) := { value := fun _ => .mk, init := fun _ => P.skip } 55 | 56 | -- this combinator combines an il with a vl to form a lvl. 57 | -- the extra parameter α is used to thread the primary argument to a level through ⊚. 
58 | -- see dcsr/csr_mat/dense below 59 | def with_values : (α → il ι) → vl β → α → lvl ι β := λ i v e => lvl.of (i e) v 60 | 61 | -- somehow with_values doesn't work with dump_vl… 62 | def without_values : (α → il ι) → α → lvl ι (Dump β) := 63 | fun i e => 64 | lvl.mk ((Prod.map id fun _ => Dump.mk) ∘ (i e).push' (fun _ => .skip)) 65 | 66 | def dcsr (l : String) : lvl ℕ (lvl ℕ (MemLoc α)) := 67 | (interval_vl $ l ++ "1_pos").value 0 |> 68 | (with_values (sparse_il (l ++ "1_crd" : ArrayVar ℕ)) (interval_vl $ l ++ "2_pos")) ⊚ 69 | (with_values (sparse_il (l ++ "2_crd" : ArrayVar ℕ)) (dense_vl $ l ++ "_vals")) 70 | 71 | def csr_mat (l dim i : String) : lvl ℕ (lvl ℕ (MemLoc α)) := 0 |> 72 | (with_values (dense_il dim i) (interval_vl $ l ++ "2_pos")) ⊚ 73 | (with_values (sparse_il $ l ++ "2_crd") (dense_vl $ l ++ "_vals")) 74 | 75 | def dense_vec (l : String) (d₁ : E ℕ) (i : String) : lvl ℕ (MemLoc α) := (0 : E ℕ) |> 76 | (with_values (dense_il d₁ i) $ dense_vl $ l ++ "_vals") 77 | 78 | def sparse_vec (l : String) : lvl ℕ (MemLoc α) := 79 | (interval_vl $ l ++ "1_pos").value 0 |> 80 | (with_values (sparse_il $ l ++ "1_crd") $ dense_vl $ l ++ "_vals") 81 | 82 | def dense_mat (d₁ d₂ : E ℕ) : lvl ℕ (lvl ℕ (MemLoc ℕ)) := (0 : E ℕ) |> 83 | (with_values (dense_il d₁ "i1") implicit_vl) ⊚ 84 | (with_values (dense_il d₂ "i2") $ dense_vl "values") 85 | 86 | def cube_lvl (d₁ d₂ d₃ : E ℕ) := 0 |> 87 | (with_values (dense_il d₁ "i1") implicit_vl) ⊚ 88 | (with_values (dense_il d₂ "i2") implicit_vl) ⊚ 89 | (with_values (dense_il d₃ "i3") $ dense_vl "values") 90 | --def sparse_vec : lvl ℕ (MemLoc α) := ⟨("size" : Var ℕ), (0 : E ℕ)⟩ & 91 | -- (with_values (sparse_il ("A1_crd" : Var ℕ)) (dense_vl "A_vals")) 92 | 93 | def tcsr (l : String) : lvl ℕ (lvl ℕ (lvl ℕ (MemLoc α))) := 94 | (interval_vl $ l ++ "1_pos").value 0 |> 95 | (with_values (sparse_il (l ++ "1_crd" : ArrayVar ℕ)) (interval_vl $ l ++ "2_pos")) ⊚ 96 | (with_values (sparse_il (l ++ "2_crd" : ArrayVar ℕ)) (interval_vl $ l ++ 
"3_pos")) ⊚ 97 | (with_values (sparse_il (l ++ "3_crd" : ArrayVar ℕ)) (dense_vl $ l ++ "_vals")) 98 | 99 | def dss (l dim i : String) : lvl ℕ (lvl ℕ (lvl ℕ (MemLoc α))) := 0 |> 100 | (with_values (dense_il dim i) (interval_vl $ l ++ "2_pos")) ⊚ 101 | (with_values (sparse_il (l ++ "2_crd" : ArrayVar ℕ)) (interval_vl $ l ++ "3_pos")) ⊚ 102 | (with_values (sparse_il (l ++ "3_crd" : ArrayVar ℕ)) (dense_vl $ l ++ "_vals")) 103 | 104 | #exit 105 | --todo 106 | inductive LevelType | s | d 107 | def trieType' (α : Type) : List LevelType → Type _ 108 | | [] => α → α 109 | | _ :: xs => α → lvl ℕ (trieType' α xs) 110 | 111 | def trieType (α : Type) : List LevelType → Type _ 112 | | [] => MemLoc α 113 | | _ :: xs => lvl ℕ (trieType α xs) 114 | 115 | #check @trieType 116 | def object (l : String) : (t : List LevelType) → MemLoc ℕ → trieType α t 117 | | [] => ⟨"no", 0⟩ 118 | | [x] => λ e => 119 | (with_values (sparse_il (l ++ "1_crd" : ArrayVar ℕ)) (dense_vl $ l ++ "_vals")) 120 | -------------------------------------------------------------------------------- /archive/src/verification/semantics/dense.lean: -------------------------------------------------------------------------------- 1 | import verification.semantics.skip_stream 2 | 3 | /-! 4 | # Dense Vectors as Indexed Streams (sanity check) 5 | 6 | In this file, we show that `denseVec` (modelling dense vectors as a stream) 7 | satisfies the stream conditions (it is strictly lawful); therefore our conditions are not vacuous. 8 | A similar thing can be done for sparse vectors. 9 | 10 | This is mostly a lot of tedious casework. TODO: can we automate this to an SMT solver? 11 | 12 | ## Definitions: 13 | We define `Stream.denseVec vals`, which takes a vector `vals` and constructs 14 | an always-ready stream that outputs the elements of `vals`. 
The state of `denseVec` 15 | is considered to be `fin (n + 1)`, the natural numbers `0 ≤ q ≤ n`, where `q : fin (n + 1)` is the terminated state 16 | 17 | ## Main results: 18 | - `Stream.denseVec_eval`: Evaluating from a state `q : fin (n + 1)` results in 19 | emitting `vals[q:]` at the appropriate indices. 20 | - Corollary (`Stream.denseVec_eval_start`): Starting from `q = 0` produces the whole vector. 21 | - `is_strict_lawful (Stream.denseVec vals)`: The stream associated with a dense vector is strictly lawful 22 | -/ 23 | 24 | 25 | 26 | variables {α : Type*} 27 | 28 | def Stream.denseVec {n : ℕ} (vals : fin n → α) : Stream (fin n) α := 29 | { σ := fin (n + 1), 30 | valid := λ i, ↑i < n, 31 | ready := λ i, ↑i < n, 32 | skip := λ i hi j, max i (cond j.2 j.1.succ (fin.cast_le n.le_succ j.1)), 33 | index := λ i hi, i.cast_lt hi, 34 | value := λ i hi, vals (i.cast_lt hi), } 35 | 36 | section 37 | local attribute [reducible] Stream.denseVec 38 | 39 | instance {n} (vals : fin n → α) : is_bounded (Stream.denseVec vals) := 40 | ⟨⟨(>), finite.preorder.well_founded_gt, λ i hi j, begin 41 | simp [Stream.to_order, hi], 42 | rcases j with ⟨j, (b|b)⟩, 43 | { rw prod.lex.lt_iff'', cases j, cases i, simp, rw [or_iff_not_imp_left, not_lt], tauto, }, 44 | { rw prod.lex.lt_iff'', cases i, cases j, simp [@lt_iff_not_le _ _ tt, imp_false, ← lt_iff_le_and_ne, nat.lt_succ_iff, nat.succ_le_iff], exact le_or_lt _ _, }, 45 | end⟩⟩ 46 | 47 | variables [add_zero_class α] 48 | 49 | lemma fin.cast_lt_le_iff {m n : ℕ} (a b : fin n) (h₁ : ↑a < m) (h₂ : ↑b < m) : 50 | a.cast_lt h₁ ≤ b.cast_lt h₂ ↔ a ≤ b := 51 | by { cases a, cases b, simp, } 52 | 53 | lemma Stream.denseVec.eq_n_of_invalid {n : ℕ} {vals : fin n → α} {q : fin (n + 1)} 54 | (hq : ¬(Stream.denseVec vals).valid q) : ↑q = n := eq_of_le_of_not_lt (nat.lt_succ_iff.mp q.prop) hq 55 | 56 | lemma Stream.denseVec_eval {n : ℕ} (vals : fin n → α) (q : fin (n + 1)) : 57 | ⇑((Stream.denseVec vals).eval q) = (λ j, if (fin.cast_succ j) < q then 0 
else vals j) := 58 | begin 59 | refine @well_founded.induction _ _ (Stream.denseVec vals).wf _ q _, 60 | clear q, intros q ih, 61 | by_cases hq : (Stream.denseVec vals).valid q, swap, 62 | { replace hq : ↑q = n := Stream.denseVec.eq_n_of_invalid hq, 63 | rw [Stream.eval_invalid], swap, { exact hq.not_lt, }, 64 | ext j, 65 | have : (fin.cast_succ j) < q, { rw fin.lt_iff_coe_lt_coe, rw hq, simp, }, 66 | simp only [this, finsupp.coe_zero, pi.zero_apply, if_true], }, 67 | { rw [Stream.eval_valid, Stream.eval₀, dif_pos]; try { exact hq }, 68 | ext j, 69 | rw [finsupp.add_apply, ih _ ((Stream.denseVec vals).next_wf q hq)], dsimp only, 70 | rcases lt_trichotomy (fin.cast_succ j) q with (h|h|h), 71 | { rw [if_pos, add_zero], swap, { simp [Stream.next_val hq], left, assumption, }, 72 | rw [if_pos], swap, { assumption, }, 73 | rw finsupp.single_apply_eq_zero, intro h', exfalso, refine h.not_le _, 74 | simp [h'], }, 75 | { have : fin.cast_lt q hq = j, { simp [← h], }, 76 | rw [if_pos, add_zero], swap, { simp [Stream.next_val hq, (show ↑q < n, from hq)], right, rw this, exact fin.cast_succ_lt_succ _, }, 77 | rw [if_neg], swap, { exact h.not_lt, }, 78 | simp [this], }, 79 | { rw [if_neg], swap, { simp [Stream.next_val hq, (show ↑q < n, from hq), not_or_distrib], revert h, cases j, cases q, simp [nat.succ_le_iff], intro h, exact ⟨h.le, h⟩, }, 80 | rw [if_neg], swap, { exact h.le.not_lt, }, 81 | rw [finsupp.single_apply_eq_zero.mpr, zero_add], 82 | rintro rfl, exfalso, simpa using h, } }, 83 | end 84 | 85 | @[simp] lemma Stream.denseVec_eval_start {n : ℕ} (vals : fin n → α) : 86 | ⇑((Stream.denseVec vals).eval 0) = vals := 87 | by { rw Stream.denseVec_eval, ext j, simp, } 88 | 89 | instance {n} (vals : fin n → α) : is_lawful (Stream.denseVec vals) := 90 | { mono := Stream.is_monotonic_iff.mpr $ λ q hq i hq', by { dsimp, rw fin.cast_lt_le_iff, exact le_max_left _ _, }, 91 | skip_spec := λ q hq i j hj, begin 92 | simp only [Stream.denseVec_eval], 93 | by_cases h₁ : fin.cast_succ j 
< q, { simp [h₁], }, 94 | rw [if_neg, if_neg], { assumption, }, 95 | rw [lt_max_iff, not_or_distrib], 96 | refine ⟨h₁, _⟩, rw [not_lt], 97 | rcases i with ⟨i, (b|b)⟩; simp only [cond], 98 | { rw [fin.cast_succ, fin.cast_add, order_embedding.le_iff_le], 99 | simpa using hj, }, 100 | { rw ← fin.cast_succ_lt_iff_succ_le, 101 | simpa [prod.lex.le_iff'] using hj, }, 102 | end } 103 | 104 | lemma Stream.denseVec_index'_mono {n} (vals : fin n → α) : strict_mono (Stream.denseVec vals).index' := 105 | λ q₁ q₂, begin 106 | simp only [Stream.index'], split_ifs with h₁ h₂ h₂, 107 | { cases q₁, cases q₂, simp, }, { simp only [with_top.coe_lt_top], exact λ _, trivial, }, 108 | { simp only [not_top_lt, not_lt, fin.lt_iff_coe_lt_coe, Stream.denseVec.eq_n_of_invalid h₁, imp_false], exact nat.lt_succ_iff.mp q₂.prop, }, 109 | { simp [Stream.denseVec.eq_n_of_invalid h₁, Stream.denseVec.eq_n_of_invalid h₂, fin.lt_iff_coe_lt_coe], }, 110 | end 111 | 112 | instance {n} (vals : fin n → α) : is_strict_lawful (Stream.denseVec vals) := 113 | ⟨⟨Stream.mono _, λ q hq i hi hr, ne_of_lt begin 114 | apply Stream.denseVec_index'_mono, 115 | rcases i with ⟨i, b⟩, dsimp only, 116 | suffices : q < cond b i.succ (fin.cast_succ i), 117 | { rw max_eq_right, { exact this, }, exact this.le, }, 118 | simp only [Stream.to_order, (show ↑q < n, from hq), to_bool_true_eq_tt] at hi, 119 | cases i, cases q, cases b; simpa [nat.lt_succ_iff] using hi, 120 | end⟩⟩ 121 | 122 | end 123 | 124 | -------------------------------------------------------------------------------- /bench-sqlite.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "common.h" 11 | #include "sqlite3.h" 12 | 13 | namespace { 14 | 15 | std::string_view DB_PREFIX = "db:"; 16 | 17 | std::pair ParseArg(std::string_view s) { 18 | int reps = -1; 19 | bool found = false; 20 | 21 | for (int i = 0; i < 
s.size(); ++i) { 22 | if ('0' <= s[i] && s[i] <= '9') { 23 | found = true; 24 | } else if (s[i] == ':' && found) { 25 | reps = std::stoul(std::string(s.substr(0, i))); 26 | s.remove_prefix(i + 1); 27 | break; 28 | } else { 29 | break; 30 | } 31 | } 32 | 33 | return {reps, s}; 34 | } 35 | 36 | void PrintRow(sqlite3_stmt* stmt) { 37 | int num_cols = sqlite3_column_count(stmt); 38 | 39 | for (int i = 0; i < num_cols; ++i) { 40 | if (i > 0) { 41 | std::cout << '|'; 42 | } 43 | std::cout << sqlite3_column_text(stmt, i); 44 | } 45 | std::cout << std::endl; 46 | } 47 | 48 | } // namespace 49 | 50 | #define CHECK_ERR(c, msg) \ 51 | do { \ 52 | int tmp = (c); \ 53 | if (tmp) { \ 54 | std::cerr << "Error while " << msg << ": " << sqlite3_errstr(tmp) \ 55 | << std::endl; \ 56 | exit(1); \ 57 | } \ 58 | } while (false) 59 | 60 | int main(int argc, char* argv[]) { 61 | char* zErrMsg = 0; 62 | int rc = SQLITE_OK; 63 | 64 | sqlite3* db; 65 | CHECK_ERR(sqlite3_initialize(), "initializing"); 66 | 67 | int argi = 1; 68 | if (argi < argc && std::string_view(argv[argi]).starts_with(DB_PREFIX)) { 69 | std::string_view db_file(argv[1]); 70 | db_file.remove_prefix(DB_PREFIX.size()); 71 | CHECK_ERR(sqlite3_open(db_file.data(), &db), "opening DB"); 72 | } else { 73 | CHECK_ERR(sqlite3_open(":memory:", &db), "opening in-memory DB"); 74 | } 75 | 76 | int query_idx = 1; 77 | 78 | for (; argi < argc; ++argi) { 79 | auto [reps, file] = ParseArg(argv[argi]); 80 | 81 | std::string filename(file); 82 | std::ifstream f(filename); 83 | std::stringstream file_ss; 84 | file_ss << f.rdbuf(); 85 | std::string sql_str = std::move(file_ss).str(); 86 | std::string_view sql_sv = sql_str; 87 | 88 | std::stringstream str_ss; 89 | str_ss << "q" << query_idx; 90 | std::string str_i = std::move(str_ss).str(); 91 | 92 | if (reps > 0) { 93 | sqlite3_stmt* stmt = nullptr; 94 | time( 95 | [&]() { 96 | do { 97 | const char* remaining = nullptr; 98 | CHECK_ERR(sqlite3_prepare_v2(db, sql_sv.data(), sql_sv.size(), 99 | 
&stmt, &remaining), 100 | "preparing " << filename); 101 | if (remaining) { 102 | sql_sv.remove_prefix(remaining - sql_sv.data()); 103 | } else { 104 | sql_sv.remove_prefix(sql_sv.size()); 105 | } 106 | sql_sv.remove_prefix(std::min( 107 | sql_sv.find_first_not_of(" \t\r\v\n"), sql_sv.size())); 108 | } while (!stmt && !sql_sv.empty()); 109 | return 0; 110 | }, 111 | (str_i + " prep").c_str(), 1); 112 | if (!stmt) { 113 | continue; 114 | } 115 | time( 116 | [&]() { 117 | int res = sqlite3_step(stmt); 118 | for (; res != SQLITE_DONE; res = sqlite3_step(stmt)) { 119 | if (res == SQLITE_ROW) { 120 | PrintRow(stmt); 121 | } else { 122 | CHECK_ERR(res, "running " << filename); 123 | } 124 | } 125 | CHECK_ERR(sqlite3_reset(stmt), "resetting " << filename); 126 | return res; 127 | }, 128 | str_i.c_str(), reps); 129 | 130 | CHECK_ERR(sqlite3_finalize(stmt), "finalizing " << filename); 131 | 132 | while (!sql_sv.empty()) { 133 | const char* remaining = nullptr; 134 | CHECK_ERR(sqlite3_prepare_v2(db, sql_sv.data(), sql_sv.size(), &stmt, 135 | &remaining), 136 | "preparing " << filename); 137 | if (remaining) { 138 | sql_sv.remove_prefix(remaining - sql_sv.data()); 139 | } else { 140 | sql_sv.remove_prefix(sql_sv.size()); 141 | } 142 | 143 | if (stmt) { 144 | std::cerr << "Error while executing " << file 145 | << ": more than one query in file" << std::endl; 146 | CHECK_ERR(sqlite3_finalize(stmt), "finalizing " << filename); 147 | return 1; 148 | } 149 | } 150 | } else { 151 | time( 152 | [&]() { 153 | do { 154 | sqlite3_stmt* stmt = nullptr; 155 | const char* remaining = nullptr; 156 | CHECK_ERR(sqlite3_prepare_v2(db, sql_sv.data(), sql_sv.size(), 157 | &stmt, &remaining), 158 | "preparing " << filename); 159 | if (remaining) { 160 | sql_sv.remove_prefix(remaining - sql_sv.data()); 161 | } else { 162 | sql_sv.remove_prefix(sql_sv.size()); 163 | } 164 | sql_sv.remove_prefix(std::min( 165 | sql_sv.find_first_not_of(" \t\r\v\n"), sql_sv.size())); 166 | 167 | int res = 
sqlite3_step(stmt); 168 | for (; res != SQLITE_DONE; res = sqlite3_step(stmt)) { 169 | if (res == SQLITE_ROW) { 170 | PrintRow(stmt); 171 | } else { 172 | CHECK_ERR(res, "running " << filename); 173 | } 174 | } 175 | 176 | CHECK_ERR(sqlite3_finalize(stmt), "finalizing " << filename); 177 | } while (!sql_sv.empty()); 178 | return 0; 179 | }, 180 | str_i.c_str(), 1); 181 | 182 | } 183 | ++query_idx; 184 | } 185 | } 186 | --------------------------------------------------------------------------------