├── .gitignore ├── .vscode └── launch.json ├── LICENSE ├── README.md ├── deploy.md ├── pyproject.toml ├── requirements.txt └── src ├── __about__.py ├── app.py ├── codex_prompt.txt ├── database.py ├── demo_search.py ├── embed_mathlib ├── .gitattributes ├── count_tokens.py ├── embed_mathlib.py ├── embeddings_to_numpy.py └── np_embeddings.npy ├── parse_docgen ├── .gitattributes ├── docgen_export_with_formal_statement.jsonl └── parse.py ├── templates └── main.html └── web.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Elastic Beanstalk Files 132 | .elasticbeanstalk/* 133 | !.elasticbeanstalk/*.cfg.yml 134 | !.elasticbeanstalk/*.global.yml 135 | 136 | cache -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
 4 |     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
 5 |     "version": "0.2.0",
 6 |     "configurations": [
 7 |         {
 8 |             "name": "Python: Current File",
 9 |             "type": "python",
10 |             "request": "launch",
11 |             "program": "${file}",
12 |             "console": "integratedTerminal",
13 |             "justMyCode": true
14 |         }
15 |     ]
16 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2023 Zhangir Azerbayev
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | **Note:** Because this repo relied on the now-retired `code-davinci-002`, the search demo no longer works out of the box. It is straightforward to replace the Codex call with the chat completions API.
 2 | # mathlib-semantic-search
 3 | Semantic search for Lean mathlib using OpenAI's [embeddings API](https://openai.com/blog/new-and-improved-embedding-model/) and [Faiss](https://github.com/facebookresearch/faiss).
 4 | 
 5 | Tested with Python 3.10.4. Packages are listed in `requirements.txt`.
 6 | 
 7 | ## Running the demo
 8 | Save an OpenAI API key to the environment variable `OPENAI_API_KEY`. Then `cd` into `src` and run `demo_search.py`. The mathlib embeddings are pre-computed, but note that embedding a query will cost you `$0.0004 / 1K tokens`.
 9 | 
10 | ## Recomputing mathlib embeddings
11 | This step is documented only for reproducibility; it is not needed to run the demo.
12 | 
13 | Pull [leanprover-community/doc-gen](https://github.com/leanprover-community/doc-gen) and run the `gen_docs` script. When generating the docs, we used mathlib commit hash `06d0adfa76594f304b4650d098273d4366ede61b`. Move the generated `export.json` to `src/parse_docgen/docgen_export.json`.
14 | 
15 | Then, in the `src/parse_docgen` directory, run `python parse.py`. Then `cd` into the `src/embed_mathlib` directory and run `python embed_mathlib.py` followed by `python embeddings_to_numpy.py`.
16 | 
17 | ## How `demo_search.py` works
18 | 
19 | At a high level, the script embeds the query using the embeddings API, then uses Faiss to do a fast kNN search against precomputed embeddings of mathlib declarations.
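
Condensed, the retrieval step looks roughly like this (a sketch of what `AppState` in `src/app.py` does, not the exact code; it assumes you are running from `src` as in the demo instructions, the precomputed `np_embeddings.npy` and the jsonl file are on disk, and `OPENAI_API_KEY` is set):

```python
import faiss
import ndjson
import numpy as np
import openai

# precomputed ada-002 embeddings of mathlib declarations, plus the declarations themselves
embeddings = np.load("embed_mathlib/np_embeddings.npy").astype("float32")
with open("parse_docgen/docgen_export_with_formal_statement.jsonl") as f:
    docs = ndjson.load(f)

# exact (flat) L2 nearest-neighbour index over the 1536-dimensional vectors
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

def search(query: str, k: int = 10):
    # embed the query with the same model used for the declarations
    resp = openai.Embedding.create(input=[query], model="text-embedding-ada-002")
    q = np.array(resp["data"][0]["embedding"], dtype="float32")[None, :]
    _, idxs = index.search(q, k)  # distances and indices of the k nearest declarations
    return [docs[i] for i in np.squeeze(idxs).tolist()]
```
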
20 | 
21 | The search indexes into the file `src/parse_docgen/docgen_export_with_formal_statement.jsonl`, which is loaded into the `docs` variable in `src/app.py`. The entries of `docs` are in exactly the same format as the entries of the `export.json` generated by `gen_docs`, so in theory we should be able to reuse the `doc-gen` code that maps `export.json` entries to HTML.
22 | 
23 | Each entry of `docs` contains one additional field compared to `export.json`: the `formal_statement` field, which is a deparsed string representation of the theorem statement.
24 | 
--------------------------------------------------------------------------------
/deploy.md:
--------------------------------------------------------------------------------
 1 | This is what I did to set up a deployment instance on AWS.
 2 | I think eventually the answer is to make a Docker image.
 3 | Mostly following [this guide](https://www.digitalocean.com/community/tutorials/how-to-serve-flask-applications-with-gunicorn-and-nginx-on-ubuntu-18-04).
 4 | 
 5 | 1. Set up an EC2 instance (a t2.large running Ubuntu).
 6 | 
 7 | ```sh
 8 | # install stuff
 9 | sudo apt update
10 | sudo apt install python3-pip python3-dev build-essential libssl-dev libffi-dev python3-setuptools
11 | curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
12 | sudo apt-get install git-lfs
13 | 
14 | # nginx
15 | sudo apt install nginx
16 | sudo ufw allow 'Nginx Full'
17 | 
18 | # set up the repo
19 | git clone git@github.com:zhangir-azerbayev/mathlib-semantic-search.git
20 | cd mathlib-semantic-search
21 | pip install -r requirements.txt
22 | touch .env
23 | ```
24 | 
25 | 2. Make an `.env` file and put values for
26 |    - `AWS_ACCESS_KEY_ID`
27 |    - `AWS_SECRET_ACCESS_KEY`
28 |    - `OPENAI_API_KEY`
29 | 
30 | 3. Splat this in `/etc/systemd/system/mathlib-search.service`:
31 | ```ini
32 | [Unit]
33 | Description=Gunicorn instance of mathlib-semantic-search
34 | After=network.target
35 | 
36 | [Service]
37 | User=ubuntu
38 | Group=www-data
39 | WorkingDirectory=/home/ubuntu/mathlib-semantic-search
40 | ExecStart=/home/ubuntu/.local/bin/gunicorn --workers 1 --bind unix:mathlib-search.sock -m 007 src.web:app --timeout 120
41 | 
42 | [Install]
43 | WantedBy=multi-user.target
44 | ```
45 | 4. Put this in `/etc/nginx/sites-available/mathlib-search`:
46 | ```conf
47 | server {
48 |     listen 80;
49 |     server_name mathlib-search.edayers.com;
50 | 
51 |     location / {
52 |         include proxy_params;
53 |         proxy_pass http://unix:/home/ubuntu/mathlib-semantic-search/mathlib-search.sock;
54 |     }
55 | }
56 | ```
57 | 5. `sudo ln -s /etc/nginx/sites-available/mathlib-search /etc/nginx/sites-enabled`
58 | 6. This is a hack: go to `/etc/nginx/nginx.conf` and replace the user with `root`. I spent about three hours trying to get this working with the `www-data` user, but it could not read from `mathlib-search.sock` no matter what I tried. This is a security issue, but at least it works.
59 | 7. Start everything:
60 | ```sh
61 | sudo systemctl enable --now mathlib-search
62 | sudo systemctl enable --now nginx
63 | ```
64 | 
65 | 8. Allocate an Elastic IP.
66 | 9. Associate it with your EC2 instance.
67 | 10. Go to your DNS provider and add an A record pointing at it.
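
At this point you can sanity-check the deployment. The commands below are a sketch; the hostname is the `server_name` from the nginx config above and will differ for your own setup.

```sh
# is gunicorn running and has it created the socket?
sudo systemctl status mathlib-search
ls -l /home/ubuntu/mathlib-semantic-search/mathlib-search.sock

# is the nginx config valid and proxying to the socket?
sudo nginx -t
curl -i http://mathlib-search.edayers.com/
```
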
68 | 69 | [todo] add TLS -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "mathlib-semantic-search" 7 | description = 'Semantic search for mathlib' 8 | readme = "README.md" 9 | requires-python = ">=3.9" 10 | license = "MIT" 11 | keywords = [] 12 | authors = [ 13 | { name = "Zhangir Azerbayev", email = "59542043+zhangir-azerbayev@users.noreply.github.com"}, 14 | { name = "E.W.Ayers", email = "edward.ayers@outlook.com" }, 15 | ] 16 | dependencies = [ 17 | "faiss_cpu==1.7.2", 18 | "ndjson==0.3.1", 19 | "numpy==1.23.0", 20 | "openai==0.20.0", 21 | "tqdm==4.64.0", 22 | "streamlit", 23 | "boto3", 24 | "fastapi", 25 | "uvicorn[standard]", 26 | "Flask" 27 | ] 28 | dynamic = ["version"] 29 | 30 | [project.urls] 31 | Documentation = "https://github.com/unknown/mathlib-semantic-search#readme" 32 | Issues = "https://github.com/unknown/mathlib-semantic-search/issues" 33 | Source = "https://github.com/unknown/mathlib-semantic-search" 34 | 35 | [tool.hatch.version] 36 | path = "src/__about__.py" 37 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | faiss_cpu==1.7.2 2 | ndjson==0.3.1 3 | numpy==1.23.0 4 | openai==0.20.0 5 | tqdm==4.64.0 6 | backoff==2.2.1 7 | boto3 8 | Flask 9 | gunicorn 10 | python-dotenv 11 | -------------------------------------------------------------------------------- /src/__about__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | -------------------------------------------------------------------------------- /src/app.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, Literal, Optional, TypedDict, Union 3 | from uuid import uuid4 4 | import ndjson 5 | import numpy as np 6 | from datetime import datetime 7 | import faiss 8 | import openai 9 | import time 10 | from src.database import DB 11 | 12 | ExprStr = Union[ 13 | str, tuple[Literal['c', 'n'], "ExprStr"], tuple[Literal['n'], "ExprStr", "ExprStr"] 14 | ] 15 | 16 | class Arg(TypedDict): 17 | arg : ExprStr 18 | implicit : bool 19 | 20 | class Result(TypedDict): 21 | kind : str 22 | line : int 23 | name : str 24 | doc_string : str 25 | formal_statement : str 26 | is_meta : bool 27 | type: ExprStr 28 | filename : str 29 | 30 | args : list[Arg] 31 | attributes : list[str] 32 | constructors : list 33 | equations : list 34 | structure_fields : list 35 | noncomputable_reason : Optional[str] 36 | 37 | 38 | def url_of_entry(x : Result): 39 | """ Hackily recover the URL from file and name. 
""" 40 | try: 41 | # https://leanprover-community.github.io/mathlib_docs/topology/instances/ennreal.html#metric_space_emetric_ball 42 | name = x['name'] # metric_space_emetric_ball 43 | file = x['filename'] # "/data/lily/zaa7/duplicates/doc-gen/_target/deps/mathlib/src/topology/instances/ennreal.lean" 44 | _, f = file.split('mathlib/src/') 45 | p, ext = f.split('.') 46 | return f"https://leanprover-community.github.io/mathlib_docs/{p}.html#{name}" 47 | except Exception: 48 | return "#" 49 | 50 | 51 | @dataclass 52 | class SearchResult: 53 | query: str 54 | results: list[Result] 55 | fake_answer: Optional[str] = field(default =None) 56 | 57 | class AppState: 58 | docs: list 59 | database: faiss.IndexFlatL2 60 | K: int 61 | cache: dict[str, SearchResult] 62 | """ Number of results to return """ 63 | 64 | def __init__( 65 | self, 66 | *, 67 | docs_path="./src/parse_docgen/docgen_export_with_formal_statement.jsonl", 68 | vecs_path="./src/embed_mathlib/np_embeddings.npy", 69 | D=1536, # dimensionality of embedding 70 | K=10, # number of results to retrieve 71 | ): 72 | self.cache = {} 73 | self.K = K 74 | self.db = DB() 75 | print(f"loading docs from {docs_path}") 76 | with open(docs_path) as f: 77 | self.docs = ndjson.load(f) 78 | 79 | print(f"loading embeddings from {vecs_path}") 80 | embeddings = np.load(vecs_path).astype("float32") 81 | 82 | # sanity checks 83 | assert D == embeddings.shape[1] 84 | assert embeddings.shape[0] == len(self.docs) 85 | 86 | print(f"Found {len(self.docs)} mathlib declarations") 87 | 88 | print("creating fast kNN database...") 89 | self.database = faiss.IndexFlatL2(D) 90 | self.database.add(embeddings) # type: ignore 91 | 92 | print("\n" + "#" * 10, "MATHLIB SEMANTIC SEARCH", "#" * 10 + "\n") 93 | 94 | def upvote(self, name : str, query : str): 95 | self.db.put({ 96 | "kind" : "mathlib-semantic-search/vote", 97 | "name" : name, 98 | "timestamp" : datetime.now().isoformat(), 99 | "id" : uuid4().hex, 100 | "query" : query, 101 | }) 102 | 103 | def search(self, query: str, K=None, gen_fake_answer : bool = False) -> SearchResult: 104 | if query in self.cache: 105 | return self.cache[query] 106 | fake_ans = None 107 | if gen_fake_answer: 108 | few_shot = open('./src/codex_prompt.txt').read().strip() 109 | codex_prompt = few_shot + " " + query + "\n" 110 | 111 | print("###PROMPT: \n", codex_prompt) 112 | 113 | out = openai.Completion.create( 114 | engine="code-davinci-002", 115 | prompt=codex_prompt, 116 | max_tokens=512, 117 | n=1, 118 | temperature=0, 119 | stop=":=", 120 | ) 121 | 122 | print(out) 123 | 124 | fake_ans = out["choices"][0]["text"] # type: ignore 125 | query = f"/-- {query} -/\n" + fake_ans 126 | print("###QUERY: \n", query) 127 | K = K or self.K 128 | start_time = time.time() 129 | 130 | responses: Any = openai.Embedding.create( 131 | input=[query], model="text-embedding-ada-002" 132 | ) 133 | 134 | query_vec = np.expand_dims( 135 | np.array(responses["data"][0]["embedding"]).astype("float32"), axis=0 136 | ) 137 | 138 | _, idxs_np = self.database.search(query_vec, K) # type: ignore 139 | 140 | idxs = np.squeeze(idxs_np).tolist() 141 | 142 | results = [Result(**self.docs[i]) for i in idxs] 143 | 144 | end_time = time.time() 145 | 146 | print(f"Retrieved {K} results in {end_time - start_time} seconds") 147 | result = SearchResult( 148 | query = query, results = results, fake_answer = fake_ans 149 | ) 150 | self.cache[query] = result 151 | return result 152 | 153 | 154 | _cur = None 155 | @classmethod 156 | def current(cls): 157 | if cls._cur is None: 
158 | cls._cur = cls() 159 | return cls._cur 160 | -------------------------------------------------------------------------------- /src/codex_prompt.txt: -------------------------------------------------------------------------------- 1 | The following is a natural language request for a theorem or definition. Give the corresponding theorem or definition in Lean mathlib. 2 | Request: The triangle inequality. 3 | theorem abs_sum_leq_sum_abs (n : ℕ) (f : ℕ → ℂ) : 4 | abs (∑ i in finset.range n, f i) ≤ ∑ i in finset.range n, abs (f i) := 5 | 6 | The following is a natural language request for a theorem or definition. Give the corresponding theorem or definition in Lean mathlib. 7 | Request: Sylow subgroup. 8 | structure sylow (p : ℕ) (G : Type*) [group G] : Type* := 9 | 10 | The following is a natural language request for a theorem or definition. Give the corresponding theorem or definition in Lean mathlib. 11 | Description: What happens when we take the quotient of a group by its center? 12 | theorem comm_group_of_cycle_center_quotient {G H : Type*} [group G] [group H] 13 | [is_cyclic H] (f : G →* H) (hf : f.ker ≤ center G) : 14 | comm_group G := 15 | 16 | The following is a natural language request for a theorem or definition. Give the corresponding theorem or definition in Lean mathlib. 17 | Request: Definition of the Krull Topology. 18 | def krull_topology (K : Type*) (L : Type*) [field K] [field L] [algebra K L] : 19 | topological_space (L ≃ₐ[K] L) := 20 | 21 | The following is a natural language request for a theorem or definition. Give the corresponding theorem or definition in Lean mathlib. 22 | Request: What can I use to prove a complex-valued function is constant? 23 | theorem exists_eq_const_of_bounded {E : Type u} [normed_group E] 24 | [normed_space ℂ E] {F : Type v} [normed_group F] [normed_space ℂ F] 25 | {f : E → F} (hf : differentiable ℂ f) (hb : metric.bounded (set.range f)) : 26 | ∃ (c : F), f = function.const E c := 27 | 28 | The following is a natural language request for a theorem or definition. Give the corresponding theorem or definition in Lean mathlib. 29 | Request: If I know `a^(p-1) = 1 mod p`, how can I prove `nat.prime p`? 30 | theorem lucas_primality (p : ℕ) (a : zmod p) (ha : a ^ (p - 1) = 1) 31 | (hd : ∀ (q : ℕ), nat.prime q → q ∣ p - 1 → a ^ ((p - 1) / q) ≠ 1) : 32 | nat.prime p := 33 | 34 | The following is a natural language request for a theorem or definition. Give the corresponding theorem or definition in Lean mathlib. 
35 | Request: 36 | -------------------------------------------------------------------------------- /src/database.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field, fields, is_dataclass 2 | from typing import Any, Literal, Optional 3 | import boto3 4 | from boto3.dynamodb.conditions import Key, Attr 5 | import os 6 | from uuid import uuid4 7 | import logging 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | class DB: 12 | def __init__(self): 13 | region = os.environ.get('AWS_REGION', "us-east-1") 14 | 15 | 16 | self.db: Any = boto3.resource("dynamodb", region_name = region) 17 | self.table_name = os.getenv("LEAN_CHAT_TABLE_NAME", "lean-chat") 18 | self.table = self.db.Table(self.table_name) 19 | 20 | def put(self, obj): 21 | def to_field(v): 22 | t = type(v) 23 | if t is not str: 24 | print('warning: check how dynamo expects non-string values') 25 | return v 26 | if is_dataclass(obj): 27 | item = {f.name: to_field(getattr(obj, f.name)) for f in fields(obj)} 28 | elif isinstance(obj, dict): 29 | item = {k: to_field(v) for k,v in obj.items()} 30 | print(item) 31 | else: 32 | raise TypeError(f"unsupported type {type(obj)}") 33 | try: 34 | self.table.put_item(Item=item ) 35 | return True 36 | except Exception as e: 37 | logger.error(f"Database error: {e}") 38 | return False 39 | -------------------------------------------------------------------------------- /src/demo_search.py: -------------------------------------------------------------------------------- 1 | from app import AppState 2 | from embed_mathlib.embed_mathlib import text_of_entry 3 | 4 | if __name__ == "__main__": 5 | app = AppState() 6 | 7 | while True: 8 | query = input("\n\nInput search query: ") 9 | 10 | print("searching...") 11 | 12 | results = app.search(query, gen_fake_answer=True) 13 | 14 | for i, x in enumerate(results.results): 15 | print(f"RESULT {i}: ") 16 | print(text_of_entry(x)) 17 | -------------------------------------------------------------------------------- /src/embed_mathlib/.gitattributes: -------------------------------------------------------------------------------- 1 | np_embeddings.npy filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /src/embed_mathlib/count_tokens.py: -------------------------------------------------------------------------------- 1 | import ndjson 2 | import json 3 | import sys 4 | import os 5 | 6 | from tqdm import tqdm 7 | import numpy as np 8 | import openai 9 | 10 | from transformers import AutoTokenizer 11 | 12 | IN_DIR = "../docgen_parse/docgen_export_with_formal_statement.jsonl" 13 | 14 | tokenizer = AutoTokenizer.from_pretrained("gpt2") 15 | 16 | with open(IN_DIR) as f: 17 | data = ndjson.load(f) 18 | 19 | total = 0 20 | for x in tqdm(data): 21 | text = ( 22 | "/-- " + x["doc_string"] + " -/" + "\n" + x["formal_statement"] 23 | if x["doc_string"] 24 | else x["formal_statement"] 25 | ) 26 | 27 | count = len(tokenizer(text)['input_ids']) 28 | 29 | total += count 30 | 31 | print(total) 32 | 33 | 34 | -------------------------------------------------------------------------------- /src/embed_mathlib/embed_mathlib.py: -------------------------------------------------------------------------------- 1 | import ndjson 2 | import json 3 | import sys 4 | import os 5 | 6 | from tqdm import tqdm 7 | import openai 8 | from dataclasses import dataclass, field 9 | from uuid import uuid4 10 | from typing import Optional, Literal 11 | 12 
| 13 | def batch_loader(seq, size): 14 | """ 15 | Iterator that takes in a list `seq` and returns 16 | chunks of size `size` 17 | """ 18 | return [seq[pos : pos + size] for pos in range(0, len(seq), size)] 19 | 20 | 21 | def text_of_entry(x): 22 | return ( 23 | "/-- " + x["doc_string"] + " -/" + "\n" + x["formal_statement"] 24 | if x["doc_string"] 25 | else x["formal_statement"] 26 | ) 27 | 28 | 29 | def main(): 30 | READ_DIR = "../parse_docgen/docgen_export_with_formal_statement.jsonl" 31 | OUT_DIR = "./embeddings.jsonl" 32 | 33 | if os.path.isfile(OUT_DIR): 34 | raise AssertionError(f"{OUT_DIR} is already a file") 35 | 36 | print("loading docgen data...") 37 | with open(READ_DIR) as f: 38 | data = ndjson.load(f) 39 | 40 | print("creating embeddings") 41 | for batch in tqdm(batch_loader(data, 100)): 42 | texts = [text_of_entry(x) for x in batch] 43 | 44 | responses = openai.Embedding.create( 45 | input=texts, 46 | model="text-embedding-ada-002", 47 | ) 48 | 49 | log = [] 50 | for entry, response in zip(batch, responses["data"]): 51 | to_log = {"name": entry["name"], "embedding": response["embedding"]} 52 | log.append(to_log) 53 | 54 | with open(OUT_DIR, "a+") as f: 55 | jsonstr = ndjson.dumps(log) 56 | f.write(jsonstr + "\n") 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /src/embed_mathlib/embeddings_to_numpy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json, ndjson 3 | from tqdm import tqdm 4 | 5 | 6 | print("loading docs...") 7 | with open("../parse_docgen/docgen_export_with_formal_statement.jsonl") as f: 8 | docs = ndjson.load(f) 9 | 10 | print("loading embeddings...") 11 | embeddings = [] 12 | with open("embeddings.jsonl") as f: 13 | for i, line in tqdm(enumerate(f)): 14 | entry = json.loads(line) 15 | 16 | # check for alignment 17 | assert entry["name"] == docs[i]["name"] 18 | 19 | embeddings.append(np.array(entry["embedding"]).astype('float32')) 20 | 21 | embeddings = np.stack(embeddings, axis=0).astype('float32') 22 | 23 | print("saving array...") 24 | np.save("np_embeddings.npy", embeddings) 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /src/embed_mathlib/np_embeddings.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3a2e0afc36e2fcc72e7f0956a76e7c6cd20378161bd12d1755a0227c4b0fcbbd 3 | size 899039360 4 | -------------------------------------------------------------------------------- /src/parse_docgen/.gitattributes: -------------------------------------------------------------------------------- 1 | docgen_export_with_formal_statement.jsonl filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /src/parse_docgen/docgen_export_with_formal_statement.jsonl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:00554198bc1b2454e58413bc57079a852b575d1a1f0bf37e85fc6aaf527d79f1 3 | size 270535507 4 | -------------------------------------------------------------------------------- /src/parse_docgen/parse.py: -------------------------------------------------------------------------------- 1 | import json 2 | import ndjson 3 | import sys 4 | import re 5 | from tqdm import tqdm 6 | 7 | PATH_DOCGEN_EXPORT = 
"docgen_export.json" 8 | OUT_PATH = "docgen_export_with_formal_statement.jsonl" 9 | ALLOWED_KINDS = ["theorem", "def", "structure"] 10 | 11 | def merge_typestars_of_binders(binders): 12 | binders = [re.sub(r": Type u_[0-9]", ": Type*", x) for x in binders] 13 | 14 | for i in range(len(binders)-1): 15 | if re.search(r": Type\*", binders[i]) and re.search(r": Type\*", binders[i+1]) and binders[i][0]==binders[i+1][0]: 16 | #print("left: ", binders[i][:binders[i].index(":")]) 17 | #print("right: ", binders[i+1][1:]) 18 | 19 | # try and except wrapper for debugging 20 | try: 21 | binders[i+1] = binders[i][:binders[i].index(":")] + binders[i+1][1:] 22 | except: 23 | print(f"something went wrong, problem at index {i}") 24 | print(binders) 25 | binders[i] = "" 26 | 27 | return [x for x in binders if x != ""] 28 | 29 | def assemble_statement(kind, nm, binders, tp): 30 | binders = merge_typestars_of_binders(binders) 31 | 32 | statement = kind + " " + nm 33 | for binder in binders: 34 | sc = statement + " " + binder 35 | if len(sc[sc.rfind("\n")+1:]) > 80: 36 | statement += "\n\t" + binder 37 | else: 38 | statement += " " + binder 39 | 40 | statement += " :\n\t" + tp 41 | 42 | return statement 43 | 44 | def process_ue_string(arg: str): 45 | #print(arg) 46 | arg = arg.replace("\n", " ") 47 | #print("ARG BEFORE PROCESSING: ", repr(arg)) 48 | arg = re.sub(r"\ue000(.*?)\ue001", "", arg) 49 | arg = re.sub(r"\ue002", "", arg) 50 | #print("ARG: ", repr(arg)) 51 | return arg 52 | 53 | def parse_single_arg(arg): 54 | if arg == "c": 55 | return "" 56 | elif isinstance(arg, str): 57 | return process_ue_string(arg) 58 | else: 59 | assert isinstance(arg, list) 60 | if arg[0] == "n": 61 | return parse_single_arg(arg[1:]) 62 | else: 63 | return "".join([parse_single_arg(x) for x in arg]) 64 | 65 | 66 | def main(): 67 | print("loading docgen export...") 68 | with open(PATH_DOCGEN_EXPORT) as f: 69 | db = json.load(f) 70 | 71 | log = [] 72 | for x in tqdm(db["decls"]): 73 | # not memory efficient but that's ok 74 | if x["kind"] in ALLOWED_KINDS: 75 | 76 | list_of_args = [y["arg"] for y in x["args"]] 77 | binders = [parse_single_arg(y) for y in list_of_args] 78 | processed_tp = parse_single_arg(x["type"]) 79 | 80 | statement = assemble_statement(x["kind"], x["name"], binders, processed_tp) 81 | 82 | # print(statement + "\n") 83 | 84 | log.append({ 85 | **x, 86 | "formal_statement": statement, 87 | }) 88 | 89 | 90 | with open(OUT_PATH, "w") as f: 91 | ndjson.dump(log, f) 92 | 93 | if __name__=="__main__": 94 | main() 95 | -------------------------------------------------------------------------------- /src/templates/main.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Mathlib Semantic Search 5 |
<!-- NOTE: the original markup of this template was garbled in this export (most HTML tags
     are missing). The structure below is a minimal reconstruction that keeps the surviving
     text and Jinja blocks and is consistent with how src/web.py renders this template; the
     tag names and attributes are guesses, not the original file. -->
<body>
  <h1>Mathlib Semantic Search!</h1>
  <p>Search mathlib with natural language</p>

  <form action="/" method="get">
    <input type="text" name="query" value="{{ query }}">
    <button type="submit">Search</button>
  </form>

  {% if fake_answer %}
    <pre>{{ fake_answer }}</pre>
    <p>Using the embeddings from the fake answer we get:</p>
  {% endif %}

  {% if results %}
    <ol>
      {% for x in results %}
      <li>
        <a href="{{ x.url }}">{{ x.name }}</a>
        <pre>{{ x.text }}</pre>
      </li>
      {% endfor %}
    </ol>
  {% endif %}
</body>
</html>
--------------------------------------------------------------------------------
/src/web.py:
--------------------------------------------------------------------------------
 1 | from dotenv import load_dotenv
 2 | load_dotenv()
 3 | 
 4 | from flask import Flask, request, render_template
 5 | from src.app import AppState, url_of_entry
 6 | from src.embed_mathlib.embed_mathlib import text_of_entry
 7 | 
 8 | app = Flask(__name__)
 9 | 
10 | app.config['TEMPLATES_AUTO_RELOAD'] = True
11 | 
12 | @app.route("/")
13 | def index():
14 |     query = request.args.get('query', None)
15 |     print(f"query: {query}")
16 |     if query is None:
17 |         return render_template('main.html', query=query or "", results = None)
18 | 
19 |     search_result = AppState.current().search(query, gen_fake_answer = True)
20 |     results = [
21 |         dict(
22 |             text = text_of_entry(r),
23 |             name = r['name'],
24 |             url = url_of_entry(r))
25 |         for r in search_result.results
26 |     ]
27 |     return render_template('main.html', query = query, results = results, fake_answer = search_result.fake_answer)
28 | 
29 | 
30 | @app.post("/upvote/")
31 | def upvote():
32 |     query = request.args.get('query', None)
33 |     name = request.args.get('name', None)
34 |     assert name is not None
35 |     assert query is not None
36 |     print(f'upvoting {name} for {query}')
37 |     name = request.args['name']
38 |     AppState.current().upvote(name, query)
39 | 
40 |     return "Success."
--------------------------------------------------------------------------------
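
A note on the Codex dependency: as the README says, `code-davinci-002` has been retired, so the `gen_fake_answer` branch of `AppState.search` in `src/app.py` no longer works as written. Below is a minimal sketch of how that call could be ported to the chat completions API. Assumptions: the `openai` package is upgraded to a version that exposes `ChatCompletion` (newer than the `0.20.0` pinned in `pyproject.toml`), and the `gpt-3.5-turbo` model name is used; this is illustrative, not a tested drop-in patch.

```python
import openai

def fake_formal_statement(few_shot: str, query: str) -> str:
    # same few-shot prompt as src/codex_prompt.txt, sent as a chat message instead
    out = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",  # assumed replacement model; code-davinci-002 is retired
        messages=[{"role": "user", "content": few_shot + " " + query + "\n"}],
        max_tokens=512,
        temperature=0,
        stop=":=",
    )
    return out["choices"][0]["message"]["content"]
```

The returned string plays the role of `fake_ans` in `AppState.search`; everything downstream (prepending `/-- {query} -/` and embedding the combined text) stays the same.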