├── lean-toolchain ├── .gitignore ├── scripts ├── format_cpp.sh ├── build.sh ├── build_example.sh ├── unpickle_premises.py ├── validate_retrieval.py └── convert_t5encoder_to_ct2.py ├── .dockerignore ├── ModelCheckpointManager.lean ├── LeanCopilot.lean ├── python ├── external_models │ ├── __init__.py │ ├── claude_runner.py │ ├── gemini_runner.py │ ├── oai_runner.py │ ├── vllm_runner.py │ ├── external_parser.py │ └── hf_runner.py ├── README.md ├── server.py └── models.py ├── LeanCopilot ├── Models.lean ├── Models │ ├── Generic.lean │ ├── Interface.lean │ ├── Builtin.lean │ ├── Native.lean │ ├── External.lean │ ├── Registry.lean │ ├── ByT5.lean │ └── FFI.lean ├── Options.lean ├── LlmAesop.lean ├── Frontend.lean └── Tactics.lean ├── LeanCopilotTests ├── PremiseSelection.lean ├── ProofSearch.lean ├── TacticSuggestion.lean └── ModelAPIs.lean ├── Dockerfile ├── lake-manifest.json ├── ModelCheckpointManager ├── Main.lean ├── Url.lean └── Download.lean ├── LICENSE ├── .github └── workflows │ └── ci.yml ├── external_model_api.yaml ├── README.md ├── cpp ├── ct2.cpp └── npy.hpp └── lakefile.lean /lean-toolchain: -------------------------------------------------------------------------------- 1 | leanprover/lean4:v4.27.0-rc1 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .lake 2 | .vscode 3 | *.olean 4 | **/__pycache__ 5 | */.DS_Store 6 | -------------------------------------------------------------------------------- /scripts/format_cpp.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/sh 2 | clang-format --style Google -i cpp/*.cpp 3 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | *.txt 2 | .vscode 3 | tmp 4 | .lake 5 | clang+llvm* 6 | **/__pycache__ 7 | */.DS_Store 8 | -------------------------------------------------------------------------------- /ModelCheckpointManager.lean: -------------------------------------------------------------------------------- 1 | import ModelCheckpointManager.Url 2 | import ModelCheckpointManager.Download 3 | -------------------------------------------------------------------------------- /LeanCopilot.lean: -------------------------------------------------------------------------------- 1 | import LeanCopilot.Frontend 2 | import LeanCopilot.LlmAesop 3 | import LeanCopilot.Models 4 | import LeanCopilot.Options 5 | import LeanCopilot.Tactics 6 | -------------------------------------------------------------------------------- /python/external_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .oai_runner import OpenAIRunner 2 | from .hf_runner import HFTacticGenerator 3 | from .vllm_runner import VLLMTacticGenerator 4 | from .claude_runner import ClaudeRunner 5 | from .gemini_runner import GeminiRunner 6 | -------------------------------------------------------------------------------- /LeanCopilot/Models.lean: -------------------------------------------------------------------------------- 1 | import LeanCopilot.Models.Builtin 2 | import LeanCopilot.Models.ByT5 3 | import LeanCopilot.Models.Native 4 | import LeanCopilot.Models.External 5 | import LeanCopilot.Models.Generic 6 | import LeanCopilot.Models.FFI 7 | import LeanCopilot.Models.Interface 8 | import LeanCopilot.Models.Registry 9 | 
-------------------------------------------------------------------------------- /LeanCopilotTests/PremiseSelection.lean: -------------------------------------------------------------------------------- 1 | import LeanCopilot 2 | 3 | 4 | example (a b c : Nat) : a + b + c = a + c + b := by 5 | select_premises 6 | sorry 7 | 8 | 9 | set_option LeanCopilot.select_premises.k 4 10 | 11 | example (a b c : Nat) : a + b + c = a + c + b := by 12 | select_premises 13 | sorry 14 | -------------------------------------------------------------------------------- /scripts/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This script demonstrates how to build LeanCopilot in GitHub Codespace. 4 | # 1. Launch a codespace for LeanCopilot. 5 | # 2. Run `source scripts/build.sh`. 6 | 7 | # Set up elan. 8 | curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | bash -s -- -y 9 | source $HOME/.elan/env 10 | 11 | # Build the project. 12 | lake build 13 | git lfs install 14 | lake exe LeanCopilot/download 15 | lake build LeanCopilotTests 16 | -------------------------------------------------------------------------------- /LeanCopilot/Models/Generic.lean: -------------------------------------------------------------------------------- 1 | import LeanCopilot.Models.Interface 2 | 3 | set_option autoImplicit false 4 | 5 | namespace LeanCopilot 6 | 7 | 8 | structure GenericGenerator where 9 | generate : String → String → IO (Array (String × Float)) 10 | 11 | 12 | instance : TextToText GenericGenerator := ⟨GenericGenerator.generate⟩ 13 | 14 | 15 | structure GenericEncoder where 16 | encode : String → IO FloatArray 17 | 18 | 19 | instance : TextToVec GenericEncoder := ⟨GenericEncoder.encode⟩ 20 | 21 | 22 | end LeanCopilot 23 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:latest 2 | 3 | WORKDIR /LeanCopilot 4 | COPY . . 5 | 6 | # Install dependencies. 7 | RUN apt-get update && apt-get install -y curl wget git git-lfs cmake clang lld libc++-dev 8 | RUN git lfs update --force 9 | RUN git lfs install 10 | 11 | # Install elan. 12 | ENV ELAN_HOME="/.elan" 13 | ENV PATH="${ELAN_HOME}/bin:${PATH}" 14 | RUN curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | bash -s -- -y 15 | 16 | # Build the Lean project. 17 | RUN lake build 18 | RUN lake exe LeanCopilot/download 19 | RUN lake build LeanCopilotTests 20 | -------------------------------------------------------------------------------- /scripts/build_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This script demonstrates how to build a repo that depends on Lean Copilot in GitHub Codespace. 4 | # 1. Launch a codespace for LeanCopilot. 5 | # 2. Run `source scripts/build_example.sh`. 6 | 7 | # Set up elan. 8 | curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | bash -s -- -y 9 | source $HOME/.elan/env 10 | 11 | # Set up lean4-example. 12 | cd /workspaces 13 | git clone https://github.com/yangky11/lean4-example 14 | cd lean4-example 15 | git checkout LeanCopilot-demo 16 | 17 | # Build lean4-example. 
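# (Note: Git LFS must be set up before `lake exe LeanCopilot/download`, since the model checkpoints are cloned from Hugging Face via Git LFS.)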
18 | git lfs install 19 | lake exe LeanCopilot/download 20 | lake build 21 | -------------------------------------------------------------------------------- /LeanCopilotTests/ProofSearch.lean: -------------------------------------------------------------------------------- 1 | import Lean 2 | import LeanCopilot 3 | 4 | open Lean Meta LeanCopilot 5 | 6 | 7 | /- 8 | ## Basic Usage 9 | -/ 10 | 11 | example (a b c : Nat) : a + b + c = c + b + a := by 12 | search_proof 13 | 14 | 15 | /- 16 | ## Advanced Usage 17 | -/ 18 | 19 | 20 | example (a b c : Nat) : a + b + c = c + b + a := by 21 | try aesop? 22 | sorry 23 | 24 | 25 | #configure_llm_aesop 26 | 27 | 28 | example (a b c : Nat) : a + b + c = c + b + a := by 29 | aesop? 30 | 31 | 32 | set_option trace.aesop true 33 | 34 | 35 | example (a b c : Nat) : a + b + c = c + b + a := by 36 | aesop? (config := { maxRuleApplications := 2 }) 37 | -------------------------------------------------------------------------------- /LeanCopilot/Models/Interface.lean: -------------------------------------------------------------------------------- 1 | set_option autoImplicit false 2 | 3 | namespace LeanCopilot 4 | 5 | 6 | class TextToText (τ : Type) where 7 | generate (model : τ) (input : String) (targetPrefix : String) : IO $ Array (String × Float) 8 | 9 | 10 | class TextToVec (τ : Type) where 11 | encode : τ → String → IO FloatArray 12 | 13 | 14 | def generate {τ : Type} [TextToText τ] (model : τ) (input : String) (targetPrefix : String := "") : IO $ Array (String × Float) := 15 | TextToText.generate model input targetPrefix 16 | 17 | 18 | def encode {τ : Type} [TextToVec τ] (model : τ) (input : String) : IO FloatArray := 19 | TextToVec.encode model input 20 | 21 | 22 | end LeanCopilot 23 | -------------------------------------------------------------------------------- /LeanCopilot/Models/Builtin.lean: -------------------------------------------------------------------------------- 1 | import ModelCheckpointManager 2 | import LeanCopilot.Models.ByT5 3 | 4 | set_option autoImplicit false 5 | 6 | namespace LeanCopilot.Builtin 7 | 8 | 9 | def generator : NativeGenerator := { 10 | url := Url.parse! "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-tacgen-byt5-small" 11 | tokenizer := ByT5.tokenizer 12 | params := { 13 | numReturnSequences := 32 14 | } 15 | } 16 | 17 | 18 | def encoder : NativeEncoder := { 19 | url := Url.parse! "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-retriever-byt5-small" 20 | tokenizer := ByT5.tokenizer 21 | } 22 | 23 | 24 | def premisesUrl := Url.parse! 
"https://huggingface.co/kaiyuy/premise-embeddings-leandojo-lean4-retriever-byt5-small" 25 | 26 | 27 | end LeanCopilot.Builtin 28 | -------------------------------------------------------------------------------- /lake-manifest.json: -------------------------------------------------------------------------------- 1 | {"version": "1.1.0", 2 | "packagesDir": ".lake/packages", 3 | "packages": 4 | [{"url": "https://github.com/leanprover-community/aesop", 5 | "type": "git", 6 | "subDir": null, 7 | "scope": "", 8 | "rev": "fa78cf032194308a950a264ed87b422a2a7c1c6c", 9 | "name": "aesop", 10 | "manifestFile": "lake-manifest.json", 11 | "inputRev": "master", 12 | "inherited": false, 13 | "configFile": "lakefile.toml"}, 14 | {"url": "https://github.com/leanprover-community/batteries.git", 15 | "type": "git", 16 | "subDir": null, 17 | "scope": "", 18 | "rev": "6cae843edf5b3abc871c557614eaffdcb4492d89", 19 | "name": "batteries", 20 | "manifestFile": "lake-manifest.json", 21 | "inputRev": "main", 22 | "inherited": false, 23 | "configFile": "lakefile.toml"}], 24 | "name": "LeanCopilot", 25 | "lakeDir": ".lake"} 26 | -------------------------------------------------------------------------------- /ModelCheckpointManager/Main.lean: -------------------------------------------------------------------------------- 1 | import ModelCheckpointManager.Url 2 | import ModelCheckpointManager.Download 3 | 4 | open LeanCopilot 5 | 6 | 7 | def builtinModelUrls : List String := [ 8 | "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-tacgen-byt5-small", 9 | "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-retriever-byt5-small", 10 | "https://huggingface.co/kaiyuy/premise-embeddings-leandojo-lean4-retriever-byt5-small", 11 | "https://huggingface.co/kaiyuy/ct2-byt5-small" 12 | ] 13 | 14 | 15 | def main (args : List String) : IO Unit := do 16 | let mut tasks := #[] 17 | let urls := Url.parse! <$> (if args.isEmpty then builtinModelUrls else args) 18 | 19 | for url in urls do 20 | tasks := tasks.push $ ← IO.asTask $ downloadUnlessUpToDate url 21 | 22 | for t in tasks do 23 | match ← IO.wait t with 24 | | Except.error e => throw e 25 | | Except.ok _ => pure () 26 | 27 | println! "Done!" 28 | -------------------------------------------------------------------------------- /scripts/unpickle_premises.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import numpy as np 4 | 5 | 6 | # `indexed_corpurs.pickle` is produced by `retrieval/index.py` in [ReProver](https://github.com/lean-dojo/ReProver). 
7 | indexed_corpus = pickle.load(open("indexed_corpus.pickle", "rb")) 8 | 9 | embeddings_tensor = indexed_corpus.embeddings 10 | embeddings_array = embeddings_tensor.numpy() 11 | embeddings_array_64 = embeddings_array.astype(np.float64) 12 | 13 | np.save("embeddings.npy", embeddings_array_64) 14 | print("Embeddings saved to embeddings.npy") 15 | 16 | all_premises = indexed_corpus.corpus.all_premises 17 | 18 | premise_dict = { 19 | index: {"full_name": premise.full_name, "path": premise.path, "code": premise.code} 20 | for index, premise in enumerate(all_premises) 21 | } 22 | 23 | file_name = "dictionary.json" 24 | json.dump(premise_dict, open(file_name, "wt"), indent=4) 25 | print(f"Dictionary saved to {file_name}") 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 LeanDojo Team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /scripts/validate_retrieval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | from transformers import AutoTokenizer, T5EncoderModel 5 | 6 | 7 | tokenizer = AutoTokenizer.from_pretrained("kaiyuy/leandojo-lean4-retriever-byt5-small") 8 | model = T5EncoderModel.from_pretrained("kaiyuy/leandojo-lean4-retriever-byt5-small") 9 | 10 | 11 | premise_embeddings_np = np.load("embeddings.npy") 12 | premise_embeddings = torch.from_numpy(premise_embeddings_np).float() 13 | 14 | # state = "n: Nat\n⊢ Nat.gcd n n = n" 15 | state = "a b c : Nat\n⊢ a + b + c = a + c + b" 16 | 17 | 18 | @torch.no_grad() 19 | def encode(s: str) -> torch.Tensor: 20 | """Encode texts into feature vectors.""" 21 | s = [s] 22 | should_squeeze = True 23 | tokenized_s = tokenizer(s, return_tensors="pt", padding=True) 24 | hidden_state = model(tokenized_s.input_ids).last_hidden_state 25 | lens = tokenized_s.attention_mask.sum(dim=1) 26 | features = (hidden_state * tokenized_s.attention_mask.unsqueeze(2)).sum( 27 | dim=1 28 | ) / lens.unsqueeze(1) 29 | if should_squeeze: 30 | features = features.squeeze() 31 | return features 32 | 33 | 34 | k = 16 35 | state_embedding = encode(state) 36 | probs = torch.matmul(premise_embeddings, state_embedding) 37 | topK = torch.topk(probs, k).indices.tolist() 38 | print(topK) 39 | 40 | 41 | with open("dictionary.json", "r") as f: 42 | dictionary = json.load(f) 43 | 44 | for i in topK: 45 | print(dictionary[str(i)]) 46 | -------------------------------------------------------------------------------- /python/external_models/claude_runner.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import os 3 | 4 | try: 5 | from anthropic import Anthropic 6 | except ImportError: 7 | pass 8 | from .external_parser import * 9 | 10 | 11 | class ClaudeRunner(Generator, Transformer): 12 | client = Anthropic(api_key=os.getenv("ANTHROPIC_KEY")) 13 | 14 | def __init__(self, **args): 15 | self.client_kwargs: dict[str, str] = { 16 | "model": args["model"], 17 | "temperature": args["temperature"], 18 | "max_tokens": args["max_tokens"], 19 | "top_p": args["top_p"], 20 | } 21 | self.name = self.client_kwargs["model"] 22 | 23 | def generate(self, input: str, target_prefix: str = "") -> List[Tuple[str, float]]: 24 | prompt = pre_process_input(self.name, input + target_prefix) 25 | 26 | # Claude 3 models are served via the Messages API; the legacy Completions API does not support them. 27 | response = self.client.messages.create( 28 | messages=[{"role": "user", "content": prompt}], 29 | **self.client_kwargs, 30 | ) 31 | content = response.content[0].text 32 | 33 | results = [ 34 | (post_process_output(self.name, content), 1.0) 35 | ] # Currently Claude only supports one output. 36 | return choices_dedup(results) 37 | 38 | 39 | if __name__ == "__main__": 40 | generation_kwargs = { 41 | "model": "claude-3-opus-20240229", 42 | "temperature": 0.9, 43 | "max_tokens": 1024, 44 | "top_p": 0.9, 45 | } 46 | 47 | model = ClaudeRunner(**generation_kwargs) 48 | print(model.generate("n : ℕ\n⊢ gcd n n = n")) 49 | -------------------------------------------------------------------------------- /python/README.md: -------------------------------------------------------------------------------- 1 | Python Server for External Models 2 | ================================= 3 | 4 | This folder contains code that enables running some of the leading general-purpose or math-specific models, including the OpenAI, Claude, and Gemini APIs as well as local models served via Hugging Face or vLLM.
It is also fairly easy to adapt the existing code and run other external models you would like to bring. 5 | 6 | ## Requirements 7 | 8 | The setup steps are pretty simple. The script below is sufficient to run all external models already supported in this folder. If you only want to run a subset of them, you may not need all packages in the last step of pip installation. 9 | 10 | ```bash 11 | conda create --name lean-copilot python=3.10 numpy 12 | conda activate lean-copilot 13 | pip install torch --index-url https://download.pytorch.org/whl/cu121 # Depending on whether you have CUDA and, if so, your CUDA version; see https://pytorch.org/. 14 | pip install fastapi uvicorn loguru transformers openai anthropic google.generativeai vllm 15 | ``` 16 | 17 | ## Running the Server 18 | 19 | ```bash 20 | uvicorn server:app --port 23337 21 | ``` 22 | 23 | After the server is up and running, you can go to `LeanCopilotTests/ModelAPIs.lean` to try your external models out! 24 | 25 | ## Contributions 26 | 27 | We welcome contributions. If you think it would be beneficial to add some other external models, or if you would like to make other contributions regarding the external model support in Lean Copilot, please feel free to open a PR. The main entry point is this `python` folder as well as the `ModelAPIs.lean` file under `LeanCopilotTests`. 28 | 29 | We use [`black`](https://pypi.org/project/black/) to format code in this folder. 30 | -------------------------------------------------------------------------------- /LeanCopilot/Options.lean: -------------------------------------------------------------------------------- 1 | import Lean 2 | import LeanCopilot.Models 3 | 4 | set_option autoImplicit false 5 | 6 | open Lean 7 | 8 | namespace LeanCopilot 9 | 10 | section 11 | 12 | 13 | variable {m : Type → Type} [Monad m] [MonadOptions m] [MonadEnv m] [MonadLift IO m] 14 | 15 | 16 | register_option LeanCopilot.verbose : Bool := { 17 | defValue := false 18 | descr := "Whether to log various debugging information." 19 | } 20 | 21 | 22 | def isVerbose : m Bool := do 23 | match LeanCopilot.verbose.get? (← getOptions) with 24 | | some true => return true 25 | | _ => return false 26 | 27 | 28 | namespace SuggestTactics 29 | 30 | 31 | register_option LeanCopilot.suggest_tactics.check : Bool := { 32 | defValue := true 33 | descr := "Whether to run the generated tactics." 34 | } 35 | 36 | def checkTactics : CoreM Bool := do 37 | match LeanCopilot.suggest_tactics.check.get? (← getOptions) with 38 | | some false => return false 39 | | _ => return true 40 | 41 | 42 | register_option LeanCopilot.suggest_tactics.model : String := { 43 | defValue := Builtin.generator.name 44 | } 45 | 46 | 47 | def getGeneratorName : m String := do 48 | match LeanCopilot.suggest_tactics.model.get? (← getOptions) with 49 | | some n => return n 50 | | _ => return Builtin.generator.name 51 | 52 | 53 | end SuggestTactics 54 | 55 | 56 | namespace SelectPremises 57 | 58 | 59 | register_option LeanCopilot.select_premises.k : Nat := { 60 | defValue := 16 61 | } 62 | 63 | 64 | def getNumPremises : m Nat := do 65 | match LeanCopilot.select_premises.k.get?
(← getOptions) with 66 | | some k => return k 67 | | _ => return 16 68 | 69 | 70 | end SelectPremises 71 | 72 | end 73 | 74 | end LeanCopilot 75 | -------------------------------------------------------------------------------- /LeanCopilot/LlmAesop.lean: -------------------------------------------------------------------------------- 1 | import LeanCopilot.Tactics 2 | import LeanCopilot.Options 3 | import Batteries.Data.String.Basic 4 | import Aesop 5 | 6 | set_option autoImplicit false 7 | 8 | open Lean Meta Elab Term Tactic 9 | 10 | namespace LeanCopilot 11 | 12 | 13 | def tacGen : Aesop.TacGen := fun (mvarId : MVarId) => do 14 | let state ← ppTacticState [mvarId] 15 | let nm ← SuggestTactics.getGeneratorName 16 | let model ← getGenerator nm 17 | let suggestions ← generate model state "" 18 | -- A temporary workaround to prevent the tactic from using the current theorem. 19 | -- TODO: Use a more principled way, e.g., see `Lean4Repl.lean` in `LeanDojo`. 20 | if let some declName := (← liftM (m := MetaM) <| Term.TermElabM.run getDeclName?).1 then 21 | let theoremName := match declName.toString with 22 | | "_example" => "" 23 | | n => n.splitOn "." |>.getLast! 24 | let theoremNameMatcher := String.Matcher.ofString theoremName 25 | let filteredSuggestions := suggestions.filterMap fun ((t, s) : String × Float) => 26 | let isAesop := t == "aesop" 27 | let isSelfReference := ¬ (theoremName == "") ∧ (theoremNameMatcher.find? t |>.isSome) 28 | if isSelfReference ∨ isAesop then none else some (t, s) 29 | return filteredSuggestions 30 | else 31 | let filteredSuggestions := suggestions.filterMap fun ((t, s) : String × Float) => 32 | let isAesop := t == "aesop" 33 | if isAesop then none else some (t, s) 34 | return filteredSuggestions 35 | 36 | 37 | macro "#configure_llm_aesop" : command => `(@[aesop 100%] def tacGen := LeanCopilot.tacGen) 38 | 39 | 40 | macro "search_proof" : tactic => `(tactic| aesop? (add 100% tacGen)) 41 | 42 | 43 | end LeanCopilot 44 | -------------------------------------------------------------------------------- /ModelCheckpointManager/Url.lean: -------------------------------------------------------------------------------- 1 | open System (FilePath) 2 | 3 | set_option autoImplicit false 4 | 5 | namespace LeanCopilot 6 | 7 | 8 | structure Url where 9 | protocol : String 10 | hostname : String 11 | path : FilePath 12 | deriving Inhabited, Repr 13 | 14 | 15 | namespace Url 16 | 17 | def isValid (url : Url) : Bool := 18 | ¬ url.protocol.isEmpty ∧ ¬ url.hostname.isEmpty ∧ ¬ url.path.toString.isEmpty ∧ url.path.isRelative ∧ url.path.fileName.isSome 19 | 20 | 21 | def toString (url : Url) : String := 22 | assert! isValid url 23 | s!"{url.protocol}://{url.hostname}/{url.path}" 24 | 25 | 26 | instance : ToString Url := ⟨toString⟩ 27 | 28 | 29 | def parse (s : String) : Option Url := 30 | let parts := s.splitOn "://" 31 | if h : parts.length != 2 then 32 | none 33 | else 34 | have : parts.length > 1 := by 35 | by_cases h' : parts.length = 2 36 | · rw [h'] 37 | apply Nat.lt_succ_of_le 38 | simp 39 | · simp_all 40 | have : parts.length > 0 := by 41 | apply Nat.lt_of_succ_lt 42 | assumption 43 | let protocol := parts[0] 44 | match parts[1].splitOn "/" with 45 | | hostname :: path => 46 | let path := FilePath.mk $ "/".intercalate path 47 | let url : Url := ⟨protocol, hostname, path⟩ 48 | if url.isValid then 49 | some url 50 | else 51 | none 52 | | _ => none 53 | 54 | 55 | def parse! (s : String) : Url := 56 | match parse s with 57 | | some url => url 58 | | none => panic! 
"Invalid url: {s}" 59 | 60 | 61 | def name! (url : Url) : String := 62 | url.path.fileName.get! 63 | 64 | 65 | private def url₁ := parse! "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-tacgen-byt5-small" 66 | private def url₂ := parse! "https://huggingface.co/bert-base-uncased" 67 | 68 | #eval url₁ 69 | #eval url₂ 70 | 71 | #eval url₁.name! 72 | #eval url₂.name! 73 | 74 | 75 | end Url 76 | 77 | end LeanCopilot 78 | -------------------------------------------------------------------------------- /LeanCopilotTests/TacticSuggestion.lean: -------------------------------------------------------------------------------- 1 | import LeanCopilot 2 | 3 | 4 | /- 5 | ## Basic Usage 6 | -/ 7 | 8 | 9 | example (a b c : Nat) : a + b + c = a + c + b := by 10 | suggest_tactics 11 | 12 | 13 | -- You may provide a prefix to constrain the generated tactics. 14 | example (a b c : Nat) : a + b + c = a + c + b := by 15 | suggest_tactics "rw" 16 | 17 | 18 | /- 19 | ## Advanced Usage 20 | -/ 21 | 22 | 23 | open Lean Meta LeanCopilot 24 | 25 | 26 | set_option LeanCopilot.verbose true in 27 | example (a b c : Nat) : a + b + c = a + c + b := by 28 | suggest_tactics 29 | 30 | 31 | set_option LeanCopilot.suggest_tactics.check false in 32 | example (a b c : Nat) : a + b + c = a + c + b := by 33 | suggest_tactics 34 | sorry 35 | 36 | 37 | /- 38 | ### Configure Generation Parameters 39 | -/ 40 | 41 | def params := {Builtin.generator.params with 42 | numReturnSequences := 4 43 | minLength := 100 44 | lengthPenalty := 1.0 45 | temperature := 0.5 46 | } 47 | 48 | def updatedModel := {Builtin.generator with params := params} 49 | 50 | #eval getModelRegistry 51 | #eval registerGenerator "updatedModel" (.native updatedModel) 52 | #eval getModelRegistry 53 | 54 | 55 | set_option LeanCopilot.suggest_tactics.model "updatedModel" in 56 | example (a b c : Nat) : a + b + c = a + c + b := by 57 | try suggest_tactics 58 | try sorry 59 | 60 | 61 | /- 62 | ### Bring Your Own Model 63 | 64 | 1. Make sure the model is up and running, e.g., by going to ./python and running `uvicorn server:app --port 23337`. 65 | 2. Uncomment the code below. 
66 | -/ 67 | 68 | 69 | /- 70 | def myModel : ExternalGenerator := { 71 | name := "wellecks/llmstep-mathlib4-pythia2.8b" 72 | host := "localhost" 73 | port := 23337 74 | } 75 | 76 | 77 | #eval registerGenerator "wellecks/llmstep-mathlib4-pythia2.8b" (.external myModel) 78 | 79 | 80 | set_option LeanCopilot.suggest_tactics.check false in 81 | set_option LeanCopilot.suggest_tactics.model "wellecks/llmstep-mathlib4-pythia2.8b" in 82 | example (a b c : Nat) : a + b + c = a + c + b := by 83 | suggest_tactics 84 | 85 | -/ 86 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | - stable 8 | 9 | push: 10 | branches: 11 | - main 12 | - stable 13 | 14 | jobs: 15 | build-release: 16 | runs-on: ${{ matrix.os }} 17 | strategy: 18 | matrix: 19 | os: [macos-latest, ubuntu-latest] 20 | name: BuildRelease 21 | steps: 22 | - name: Checkout project 23 | uses: actions/checkout@v5 24 | with: 25 | fetch-depth: 0 26 | - name: Install Git LFS 27 | run: | 28 | git lfs update --force 29 | git lfs install 30 | - name: Set up elan 31 | run: curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh -s -- -y 32 | - name: Add .lake/build/lib to PATH 33 | shell: bash 34 | run: | 35 | echo "$GITHUB_WORKSPACE/.lake/build/lib" >> $GITHUB_PATH 36 | - name: Build project 37 | run: ~/.elan/bin/lake build 38 | - name: Download model 39 | run: | 40 | ~/.elan/bin/lake exe LeanCopilot/download 41 | - name: Build tests 42 | run: ~/.elan/bin/lake build LeanCopilotTests 43 | 44 | build-beta: 45 | runs-on: windows-latest 46 | name: BuildBeta 47 | continue-on-error: true 48 | steps: 49 | - name: Checkout project 50 | uses: actions/checkout@v5 51 | with: 52 | fetch-depth: 0 53 | - name: Install Git LFS 54 | run: | 55 | git lfs update --force 56 | git lfs install 57 | - name: Set up elan 58 | run: curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh -s -- -y 59 | - name: Add .lake/build/lib to PATH 60 | shell: bash 61 | run: | 62 | echo "$GITHUB_WORKSPACE/.lake/build/lib" >> $GITHUB_PATH 63 | - name: Build project 64 | run: ~/.elan/bin/lake build 65 | - name: Download model 66 | run: | 67 | ~/.elan/bin/lake exe LeanCopilot/download 68 | - name: Build tests 69 | run: ~/.elan/bin/lake build LeanCopilotTests 70 | -------------------------------------------------------------------------------- /LeanCopilot/Models/Native.lean: -------------------------------------------------------------------------------- 1 | import ModelCheckpointManager 2 | 3 | set_option autoImplicit false 4 | 5 | open System (FilePath) 6 | 7 | namespace LeanCopilot 8 | 9 | 10 | inductive Device where 11 | | cpu 12 | | cuda 13 | | auto 14 | deriving Repr 15 | 16 | 17 | instance : Inhabited Device where 18 | default := .auto 19 | 20 | 21 | def Device.toString : Device → String 22 | | Device.cpu => "cpu" 23 | | Device.cuda => "cuda" 24 | | Device.auto => "auto" 25 | 26 | instance : ToString Device := ⟨Device.toString⟩ 27 | 28 | 29 | inductive ComputeType where 30 | | default 31 | | auto 32 | | int8 33 | | int8_float32 34 | | int8_float16 35 | | int8_bfloat16 36 | | int16 37 | | float16 38 | | bfloat16 39 | | float32 40 | deriving Repr 41 | 42 | 43 | def ComputeType.toString : ComputeType → String 44 | | ComputeType.default => "default" 45 | | ComputeType.auto => "auto" 46 | | ComputeType.int8 => "int8" 47 | | 
ComputeType.int8_float32 => "int8_float32" 48 | | ComputeType.int8_float16 => "int8_float16" 49 | | ComputeType.int8_bfloat16 => "int8_bfloat16" 50 | | ComputeType.int16 => "int16" 51 | | ComputeType.float16 => "float16" 52 | | ComputeType.bfloat16 => "bfloat16" 53 | | ComputeType.float32 => "float32" 54 | 55 | 56 | instance : ToString ComputeType := ⟨ComputeType.toString⟩ 57 | 58 | 59 | structure Tokenizer where 60 | tokenize : String → Array String 61 | detokenize : Array String → String 62 | eosToken : String 63 | 64 | 65 | structure NativeModel where 66 | url : Url 67 | device : Device := .auto 68 | deviceIndex : Array UInt64 := #[0] 69 | computeType : ComputeType := .default 70 | tokenizer : Tokenizer 71 | 72 | 73 | def NativeModel.name (model : NativeModel) : String := model.url.name! 74 | 75 | 76 | def NativeModel.path (model : NativeModel) : IO FilePath := 77 | getModelDir model.url 78 | 79 | 80 | structure BeamSearchParams where 81 | numReturnSequences : UInt64 82 | beamSize : UInt64 := numReturnSequences 83 | minLength : UInt64 := 1 84 | maxLength : UInt64 := 1024 85 | lengthPenalty : Float := 0.0 86 | patience : Float := 2.0 87 | temperature : Float := 1.0 88 | deriving Repr 89 | 90 | 91 | structure NativeGenerator extends NativeModel where 92 | params : BeamSearchParams 93 | 94 | 95 | structure NativeEncoder extends NativeModel 96 | 97 | 98 | end LeanCopilot 99 | -------------------------------------------------------------------------------- /external_model_api.yaml: -------------------------------------------------------------------------------- 1 | paths: 2 | /generate: 3 | post: 4 | requestBody: 5 | required: true 6 | content: 7 | application/json: 8 | schema: 9 | $ref: '#/components/schemas/GeneratorRequest' 10 | responses: 11 | "200": 12 | description: OK 13 | content: 14 | application/json: 15 | schema: 16 | $ref: '#/components/schemas/GeneratorResponse' 17 | 18 | /encode: 19 | post: 20 | requestBody: 21 | required: true 22 | content: 23 | application/json: 24 | schema: 25 | $ref: '#/components/schemas/EncoderRequest' 26 | responses: 27 | "200": 28 | description: OK 29 | content: 30 | application/json: 31 | schema: 32 | $ref: '#/components/schemas/EncoderResponse' 33 | 34 | components: 35 | schemas: 36 | GeneratorRequest: 37 | type: object 38 | properties: 39 | name: 40 | type: string 41 | description: Model name 42 | input: 43 | type: string 44 | description: Input to the generator 45 | prefix: 46 | type: string 47 | description: Prefix for constraining the output (only supported by some models) 48 | 49 | Generation: 50 | type: object 51 | properties: 52 | output: 53 | type: string 54 | description: Generator's output 55 | score: 56 | type: number 57 | description: Generator's output score 58 | 59 | GeneratorResponse: 60 | type: object 61 | properties: 62 | outputs: 63 | type: array 64 | items: 65 | $ref: '#/components/schemas/Generation' 66 | description: Multiple outputs from the generator, each with a score 67 | 68 | EncoderRequest: 69 | type: object 70 | properties: 71 | name: 72 | type: string 73 | description: Model name 74 | input: 75 | type: string 76 | description: Input to the encoder 77 | 78 | EncoderResponse: 79 | type: object 80 | properties: 81 | outputs: 82 | type: array 83 | items: 84 | type: number 85 | description: Vector embedding produced by the encoder 86 | -------------------------------------------------------------------------------- /python/external_models/gemini_runner.py:
-------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import os 3 | 4 | try: 5 | import google.generativeai as genai 6 | from google.generativeai import GenerationConfig 7 | except ImportError: 8 | pass 9 | from .external_parser import * 10 | 11 | 12 | class GeminiRunner(Generator, Transformer): 13 | genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) # configure() returns None; the model client is created in __init__. 14 | safety_settings = [ 15 | { 16 | "category": "HARM_CATEGORY_HARASSMENT", 17 | "threshold": "BLOCK_NONE", 18 | }, 19 | { 20 | "category": "HARM_CATEGORY_HATE_SPEECH", 21 | "threshold": "BLOCK_NONE", 22 | }, 23 | { 24 | "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", 25 | "threshold": "BLOCK_NONE", 26 | }, 27 | { 28 | "category": "HARM_CATEGORY_DANGEROUS_CONTENT", 29 | "threshold": "BLOCK_NONE", 30 | }, 31 | ] 32 | 33 | def __init__(self, **args): 34 | self.client_kwargs: dict[str, str] = { 35 | "model": args["model"], 36 | "temperature": args["temperature"], 37 | "max_tokens": args["max_tokens"], 38 | "top_p": args["top_p"], 39 | } 40 | self.name = self.client_kwargs["model"] 41 | self.client = genai.GenerativeModel(args["model"]) 42 | self.generation_config = GenerationConfig( 43 | candidate_count=1, 44 | max_output_tokens=args["max_tokens"], 45 | temperature=args["temperature"], 46 | top_p=args["top_p"], 47 | ) 48 | 49 | def generate(self, input: str, target_prefix: str = "") -> List[Tuple[str, float]]: 50 | prompt = pre_process_input(self.name, input + target_prefix) 51 | 52 | response = self.client.generate_content( 53 | prompt, 54 | generation_config=self.generation_config, 55 | safety_settings=GeminiRunner.safety_settings, 56 | ) 57 | 58 | results = [ 59 | (post_process_output(self.name, response.text), 1.0) 60 | ] # Currently Gemini only supports one output.
61 | return choices_dedup(results) 62 | 63 | 64 | if __name__ == "__main__": 65 | generation_kwargs = { 66 | "model": "gemini-1.0-pro", 67 | "temperature": 0.9, 68 | "max_tokens": 1024, 69 | "top_p": 0.9, 70 | } 71 | 72 | model = GeminiRunner(**generation_kwargs) 73 | print(model.generate("n : ℕ\n⊢ gcd n n = n")) 74 | -------------------------------------------------------------------------------- /LeanCopilot/Models/External.lean: -------------------------------------------------------------------------------- 1 | import Lean 2 | import LeanCopilot.Models.Interface 3 | 4 | set_option autoImplicit false 5 | 6 | open Lean 7 | 8 | namespace LeanCopilot 9 | 10 | 11 | structure ExternalModel where 12 | name : String 13 | host : String := "localhost" 14 | port : UInt16 := 23337 15 | deriving Inhabited, Repr 16 | 17 | 18 | structure ExternalGenerator extends ExternalModel 19 | deriving Repr 20 | 21 | 22 | structure GeneratorRequest where 23 | name : String 24 | input : String 25 | «prefix» : String 26 | deriving ToJson 27 | 28 | 29 | structure Generation where 30 | output : String 31 | score : Float 32 | deriving FromJson 33 | 34 | 35 | structure GeneratorResponse where 36 | outputs : Array Generation 37 | deriving FromJson 38 | 39 | 40 | structure EncoderRequest where 41 | name : String 42 | input : String 43 | deriving ToJson 44 | 45 | 46 | structure EncoderResponse where 47 | outputs : Array Float 48 | deriving FromJson 49 | 50 | 51 | def send {α β : Type} [ToJson α] [FromJson β] (req : α) (url : String) : IO β := do 52 | let reqStr := (toJson req).pretty 99999999999999999 53 | let out ← IO.Process.output { 54 | cmd := "curl" 55 | args := #["-X", "POST", url, "-H", "accept: application/json", "-H", "Content-Type: application/json", "-d", reqStr] 56 | } 57 | if out.exitCode != 0 then 58 | throw $ IO.userError s!"Request failed. Please check if the server is up at `{url}`." 59 | let some json := Json.parse out.stdout |>.toOption 60 | | throw $ IO.userError "Failed to parse response" 61 | let some res := (fromJson?
json : Except String β) |>.toOption 62 | | throw $ IO.userError "Failed to parse response" 63 | return res 64 | 65 | 66 | def ExternalGenerator.generate (model : ExternalGenerator) (input : String) (targetPrefix : String) : IO $ Array (String × Float) := do 67 | let url := s!"http://{model.host}:{model.port}/generate" 68 | let req : GeneratorRequest := { 69 | name := model.name, 70 | input := input, 71 | «prefix» := targetPrefix 72 | } 73 | let res : GeneratorResponse ← send req url 74 | return res.outputs.map fun g => (g.output, g.score) 75 | 76 | 77 | instance : TextToText ExternalGenerator := ⟨ExternalGenerator.generate⟩ 78 | 79 | 80 | structure ExternalEncoder extends ExternalModel 81 | deriving Repr 82 | 83 | 84 | def ExternalEncoder.encode (model : ExternalEncoder) (input : String) : IO FloatArray := do 85 | let url := s!"http://{model.host}:{model.port}/encode" 86 | let req : EncoderRequest := { 87 | name := model.name, 88 | input := input, 89 | } 90 | let res : EncoderResponse ← send req url 91 | return FloatArray.mk res.outputs 92 | 93 | 94 | instance : TextToVec ExternalEncoder := ⟨ExternalEncoder.encode⟩ 95 | 96 | 97 | end LeanCopilot 98 | -------------------------------------------------------------------------------- /python/external_models/oai_runner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List, Tuple 3 | import os 4 | 5 | import openai 6 | from openai import OpenAI 7 | from .external_parser import * 8 | 9 | 10 | class OpenAIRunner(Generator, Transformer): 11 | client = OpenAI( 12 | api_key=os.getenv("OPENAI_API_KEY"), 13 | ) 14 | 15 | def __init__(self, **args): 16 | self.client_kwargs: dict[str, str] = { 17 | "model": args["model"], 18 | "temperature": args["temperature"], 19 | "max_tokens": args["max_tokens"], 20 | "top_p": args["top_p"], 21 | "frequency_penalty": 0, 22 | "presence_penalty": 0, 23 | "n": args["num_return_sequences"], 24 | "timeout": args["openai_timeout"], 25 | # "stop": args.stop, # stop is only used for base models currently 26 | } 27 | self.name = self.client_kwargs["model"] 28 | 29 | def generate(self, input: str, target_prefix: str = "") -> List[Tuple[str, float]]: 30 | prompt = pre_process_input(self.name, input + target_prefix) 31 | prompt = [ 32 | {"role": "user", "content": f"{prompt}"}, 33 | ] 34 | try: 35 | response = OpenAIRunner.client.chat.completions.create( 36 | messages=prompt, 37 | logprobs=True, 38 | **self.client_kwargs, 39 | ) 40 | except ( 41 | openai.APIError, 42 | openai.RateLimitError, 43 | openai.InternalServerError, 44 | openai.OpenAIError, 45 | openai.APIStatusError, 46 | openai.APITimeoutError, 47 | # (InternalServerError is already listed above) 48 | openai.APIConnectionError, 49 | ) as e: 50 | print("Exception: ", repr(e)) 51 | print("Consider reducing the number of parallel processes.") 52 | return OpenAIRunner.generate(self, input, target_prefix) 53 | except Exception as e: 54 | print(f"Failed to run the model for {prompt}!") 55 | print("Exception: ", repr(e)) 56 | raise e 57 | 58 | results = [ 59 | ( 60 | post_process_output(self.name, c.message.content), 61 | np.exp(-np.mean([token.logprob for token in c.logprobs.content])), 62 | ) 63 | for c in response.choices 64 | ] 65 | return choices_dedup(results) 66 | 67 | 68 | if __name__ == "__main__": 69 | generation_kwargs = { 70 | "model": "gpt-4-turbo-preview", 71 | "temperature": 0.9, 72 | "max_tokens": 1024, 73 | "top_p": 0.9, 74 | "frequency_penalty": 0, 75 | "presence_penalty": 0,
76 | "num_return_sequences": 16, 77 | "openai_timeout": 45, 78 | # "stop": args.stop, # stop is only used for base models currently 79 | } 80 | 81 | model = OpenAIRunner(**generation_kwargs) 82 | print(model.generate("n : ℕ\n⊢ gcd n n = n")) 83 | -------------------------------------------------------------------------------- /python/external_models/vllm_runner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from loguru import logger 4 | from typing import List, Tuple 5 | from transformers import AutoTokenizer 6 | 7 | try: 8 | from vllm import LLM, SamplingParams 9 | except ImportError as e: 10 | print("Cannot import vllm") 11 | pass 12 | from .external_parser import * 13 | 14 | 15 | class VLLMTacticGenerator(Generator, Transformer): 16 | def __init__(self, **args) -> None: 17 | self.name = args["model"] 18 | self.llm = LLM( 19 | model=self.name, 20 | tokenizer=self.name, 21 | tensor_parallel_size=args["tensor_parallel_size"], 22 | enforce_eager=True, 23 | max_model_len=4096, 24 | disable_custom_all_reduce=False, 25 | trust_remote_code=True, 26 | ) 27 | self.sampling_params = SamplingParams( 28 | n=args["n"], 29 | max_tokens=args["max_tokens"], 30 | temperature=args["temperature"], 31 | top_p=args["top_p"], 32 | frequency_penalty=0, 33 | presence_penalty=0, 34 | logprobs=0, 35 | prompt_logprobs=0, 36 | ) 37 | 38 | self.tokenizer = AutoTokenizer.from_pretrained( 39 | self.name, trust_remote_code=True 40 | ) 41 | device = args["device"] 42 | if device == "auto": 43 | device = get_cuda_if_available() 44 | else: 45 | device = torch.device(device) 46 | logger.info(f"Loading {self.name} on {device}") 47 | 48 | def generate(self, input: str, target_prefix: str = "") -> List[Tuple[str, float]]: 49 | prompt = input + target_prefix 50 | '''prompt= 'Here is a theorom you need to prove in Lean:\n'+prompt+'\nNow you should suggest one line tactic in lean code:' 51 | prompt = f"""<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n""" 52 | ''' 53 | prompt = pre_process_input(self.name, prompt) 54 | 55 | vllm_outputs = self.llm.generate(prompt, self.sampling_params) 56 | result = [] 57 | for output in vllm_outputs[0].outputs: # bsz=1 for now 58 | out = output.text.split("<|im_end|>")[0] 59 | result.append( 60 | (post_process_output(self.name, out), np.exp(output.cumulative_logprob)) 61 | ) 62 | 63 | result = choices_dedup(result) 64 | return result 65 | 66 | 67 | if __name__ == "__main__": 68 | generation_kwargs = { 69 | "model": "AI-MO/Kimina-Prover-Preview-Distill-7B", 70 | "tensor_parallel_size": 1, 71 | "temperature": 0.6, 72 | "max_tokens": 1024, 73 | "top_p": 0.9, 74 | "length_penalty": 0, 75 | "n": 32, 76 | "do_sample": True, 77 | "output_scores": True, 78 | "output_logits": False, 79 | "return_dict_in_generate": True, 80 | "device": "auto", 81 | } 82 | model = VLLMTacticGenerator(**generation_kwargs) 83 | print(model.generate("n : ℕ\n⊢ gcd n n = n")) 84 | -------------------------------------------------------------------------------- /python/server.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from fastapi import FastAPI 3 | from pydantic import BaseModel 4 | 5 | from models import * 6 | from external_models import * 7 | 8 | app = FastAPI() 9 | 10 | models = { 11 | "gpt4": OpenAIRunner( 12 | model="gpt-4-turbo-preview", 13 | temperature=0.9, 14 | max_tokens=1024, 15 | top_p=0.9, 16 | frequency_penalty=0, 17 | presence_penalty=0, 18 | 
num_return_sequences=16, 19 | openai_timeout=45, 20 | ), 21 | "InternLM": VLLMTacticGenerator( 22 | model="internlm/internlm2-math-plus-1_8b", 23 | tensor_parallel_size=2, 24 | temperature=0.6, 25 | max_tokens=1024, 26 | top_p=0.9, 27 | length_penalty=0, 28 | n=32, 29 | do_sample=True, 30 | output_scores=True, 31 | output_logits=False, 32 | return_dict_in_generate=True, 33 | device="auto", 34 | ), 35 | "kimina": VLLMTacticGenerator( 36 | model="AI-MO/Kimina-Prover-Preview-Distill-7B", 37 | tensor_parallel_size=1, 38 | temperature=0.6, 39 | max_tokens=1024, 40 | top_p=0.9, 41 | length_penalty=0, 42 | n=32, 43 | do_sample=True, 44 | output_scores=True, 45 | output_logits=False, 46 | return_dict_in_generate=True, 47 | device="auto", 48 | ), 49 | "wellecks/llmstep-mathlib4-pythia2.8b": PythiaTacticGenerator( 50 | num_return_sequences=32, max_length=1024, device="auto" 51 | ), 52 | "t5-small": EncoderDecoderTransformer( 53 | "t5-small", num_return_sequences=3, max_length=1024 54 | ), 55 | "kaiyuy/leandojo-lean4-tacgen-byt5-small": EncoderDecoderTransformer( 56 | "kaiyuy/leandojo-lean4-tacgen-byt5-small", 57 | num_return_sequences=32, 58 | max_length=1024, 59 | ), 60 | "kaiyuy/leandojo-lean4-retriever-byt5-small": EncoderOnlyTransformer( 61 | "kaiyuy/leandojo-lean4-retriever-byt5-small" 62 | ), 63 | } 64 | 65 | 66 | class GeneratorRequest(BaseModel): 67 | name: str 68 | input: str 69 | prefix: Optional[str] 70 | 71 | 72 | class Generation(BaseModel): 73 | output: str 74 | score: float 75 | 76 | 77 | class GeneratorResponse(BaseModel): 78 | outputs: List[Generation] 79 | 80 | 81 | class EncoderRequest(BaseModel): 82 | name: str 83 | input: str 84 | 85 | 86 | class EncoderResponse(BaseModel): 87 | outputs: List[float] 88 | 89 | 90 | @app.post("/generate") 91 | async def generate(req: GeneratorRequest) -> GeneratorResponse: 92 | model = models[req.name] 93 | target_prefix = req.prefix if req.prefix is not None else "" 94 | outputs = model.generate(req.input, target_prefix) 95 | return GeneratorResponse( 96 | outputs=[Generation(output=out[0], score=out[1]) for out in outputs] 97 | ) 98 | 99 | 100 | @app.post("/encode") 101 | async def encode(req: EncoderRequest) -> EncoderResponse: 102 | model = models[req.name] 103 | feature = model.encode(req.input) 104 | return EncoderResponse(outputs=feature.tolist()) 105 | -------------------------------------------------------------------------------- /ModelCheckpointManager/Download.lean: -------------------------------------------------------------------------------- 1 | import ModelCheckpointManager.Url 2 | 3 | set_option autoImplicit false 4 | 5 | open System (FilePath) 6 | 7 | namespace LeanCopilot 8 | 9 | inductive SupportedOS where 10 | | linux 11 | | macos 12 | | windows 13 | deriving Inhabited, BEq 14 | 15 | def getOS! : SupportedOS := 16 | if System.Platform.isWindows then 17 | .windows 18 | else if System.Platform.isOSX then 19 | .macos 20 | else 21 | .linux 22 | 23 | def ensureDirExists (dir : FilePath) : IO Unit := do 24 | if ¬ (← dir.pathExists) then 25 | IO.FS.createDirAll dir 26 | 27 | 28 | def getHomeDir : IO FilePath := do 29 | let home := if getOS! == .windows then "USERPROFILE" else "HOME" 30 | let some dir ← IO.getEnv home | throw $ IO.userError s!"Cannot find the ${home} environment variable." 
31 | return dir 32 | 33 | 34 | def getDefaultCacheDir : IO FilePath := do 35 | return (← getHomeDir) / ".cache/lean_copilot/models" 36 | 37 | 38 | def getCacheDir : IO FilePath := do 39 | let defaultCacheDir ← getDefaultCacheDir 40 | let dir := match ← IO.getEnv "LEAN_COPILOT_CACHE_DIR" with 41 | | some dir => (dir : FilePath) 42 | | none => defaultCacheDir 43 | ensureDirExists dir 44 | return dir.normalize 45 | 46 | 47 | def getModelDir (url : Url) : IO FilePath := do 48 | return (← getCacheDir) / url.hostname / url.path |>.normalize 49 | 50 | 51 | def isUpToDate (url : Url) : IO Bool := do 52 | let dir := ← getModelDir url 53 | if ¬ (← dir.pathExists) then 54 | return false 55 | 56 | let _ ← IO.Process.run { 57 | cmd := "git" 58 | args := #["fetch", "--quiet", "--all"] 59 | cwd := dir 60 | } 61 | 62 | let branch := (← IO.Process.run { 63 | cmd := "git" 64 | args := #["symbolic-ref", "refs/remotes/origin/HEAD","--short"] 65 | cwd := dir 66 | }).trim 67 | 68 | let hasRemoteChange := (← IO.Process.run { 69 | cmd := "git" 70 | args := #["diff", (branch.splitOn "/")[1]!, branch, "--shortstat"] 71 | cwd := dir 72 | }).trim != "" 73 | 74 | return ¬hasRemoteChange 75 | 76 | 77 | def initGitLFS : IO Unit := do 78 | let proc ← IO.Process.output { 79 | cmd := "git" 80 | args := #["lfs", "install"] 81 | } 82 | if proc.exitCode != 0 then 83 | throw $ IO.userError "Failed to initialize Git LFS. Please install it." 84 | 85 | 86 | def downloadUnlessUpToDate (url : Url) : IO Unit := do 87 | let dir := ← getModelDir url 88 | if ← isUpToDate url then 89 | println! s!"The model is available at {dir}" 90 | return 91 | 92 | println! s!"Downloading the model into {dir}" 93 | if ← dir.pathExists then 94 | IO.FS.removeDirAll dir 95 | let some parentDir := dir.parent | unreachable! 96 | IO.FS.createDirAll parentDir 97 | 98 | initGitLFS 99 | let proc ← IO.Process.output { 100 | cmd := "git" 101 | args := #["clone", toString url] 102 | cwd := parentDir 103 | } 104 | if proc.exitCode != 0 then 105 | throw $ IO.userError s!"Failed to download the model. You can download it manually from {url} and store it in `{dir}/`. See https://huggingface.co/docs/hub/models-downloading for details." 106 | 107 | 108 | end LeanCopilot 109 | -------------------------------------------------------------------------------- /python/external_models/external_parser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import List, Tuple 4 | from abc import ABC, abstractmethod 5 | 6 | 7 | def get_cuda_if_available(): 8 | return torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | def pre_process_input(model_name, input): 12 | if model_name == "internlm/internlm2-math-plus-1_8b" or model_name == "AI-MO/Kimina-Prover-Preview-Distill-7B": 13 | prompt = ( 14 | "My LEAN 4 state is:\n```lean\n" 15 | + input 16 | + "```\nPlease predict a possible tactic to help me prove the theorem." 
17 | ) 18 | prompt = f"""<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n""" 19 | elif model_name == "gpt-3.5-turbo" or model_name == "gpt-4-turbo-preview": 20 | prompt = ( 21 | "Here is a theorem you need to prove in Lean:\n" 22 | + input 23 | + "\nNow you should suggest one line tactic in lean code:" 24 | ) 25 | elif "gemini" in model_name or "claude" in model_name: 26 | prompt = ( 27 | "Here is a theorem you need to prove in Lean:\n" 28 | + input 29 | + "\nNow you should suggest one line tactic in lean code:" 30 | ) 31 | else: 32 | raise NotImplementedError(f"External model '{model_name}' not supported") 33 | return prompt 34 | 35 | 36 | def post_process_output(model_name, output): 37 | if model_name == "internlm/internlm2-math-plus-1_8b": 38 | result = ( 39 | output.split("assistant")[-1] 40 | .split("lean")[-1] 41 | .split("```")[0] 42 | .split("\n")[1] 43 | ) 44 | elif model_name == "AI-MO/Kimina-Prover-Preview-Distill-7B": 45 | result = ( 46 | output.split("assistant")[-1] 47 | .split("lean")[-1] 48 | .split("```")[0] 49 | .split("\n")[-2] 50 | .lstrip() 51 | ) 52 | elif model_name == "gpt-3.5-turbo" or model_name == "gpt-4-turbo-preview": 53 | result = output.split("lean")[-1].split("```")[0].split("\n")[1] 54 | elif "gemini" in model_name or "claude" in model_name: 55 | result = output.split("lean")[-1].split("```")[0].split("\n")[1] 56 | else: 57 | raise NotImplementedError(f"External model '{model_name}' not supported") 58 | return result 59 | 60 | 61 | def choices_dedup(output_list: List[tuple[str, float]]) -> List[tuple[str, float]]: 62 | unique_data = {} 63 | for item in output_list: 64 | if item[0] not in unique_data or item[1] > unique_data[item[0]]: 65 | unique_data[item[0]] = item[1] 66 | sorted_data = sorted(unique_data.items(), key=lambda x: x[1], reverse=True) 67 | return sorted_data 68 | 69 | 70 | class Generator(ABC): 71 | @abstractmethod 72 | def generate(self, input: str, target_prefix: str = "") -> List[Tuple[str, float]]: 73 | pass 74 | 75 | 76 | class Encoder(ABC): 77 | @abstractmethod 78 | def encode(self, input: str) -> np.ndarray: 79 | pass 80 | 81 | 82 | class Transformer: 83 | def cuda(self) -> None: 84 | self.model.cuda() 85 | 86 | def cpu(self) -> None: 87 | self.model.cpu() 88 | 89 | @property 90 | def device(self) -> torch.device: 91 | return self.model.device 92 | -------------------------------------------------------------------------------- /LeanCopilotTests/ModelAPIs.lean: -------------------------------------------------------------------------------- 1 | import LeanCopilot 2 | 3 | open LeanCopilot 4 | 5 | #eval cudaAvailable 6 | 7 | /-- 8 | ReProver's tactic generator in CT2 format. 9 | -/ 10 | def reprover : NativeGenerator := { 11 | url := Url.parse! "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-tacgen-byt5-small" 12 | tokenizer := ByT5.tokenizer 13 | params := { 14 | numReturnSequences := 1 15 | } 16 | } 17 | 18 | #eval generate reprover "n : ℕ\n⊢ gcd n n = n" 19 | 20 | def reprover' : NativeGenerator := {reprover with 21 | device := .cpu 22 | computeType := .float32 23 | params := {numReturnSequences := 4} 24 | } 25 | 26 | #eval generate reprover' "n : ℕ\n⊢ gcd n n = n" 27 | 28 | 29 | /-- 30 | The original ByT5 checkpoint in CT2 format. 31 | -/ 32 | def byt5 : NativeGenerator := { 33 | url := Url.parse! "https://huggingface.co/kaiyuy/ct2-byt5-small" 34 | tokenizer := ByT5.tokenizer 35 | params := { 36 | numReturnSequences := 1 37 | } 38 | } 39 | 40 | #eval generate byt5 "Hello, world!" 
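-- `generate` also accepts an optional target prefix (see `LeanCopilot/Models/Interface.lean`); a small illustrative example, constraining the suggestions to start with "rw":
#eval generate reprover' "n : ℕ\n⊢ gcd n n = n" "rw"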
41 | 42 | 43 | /-- 44 | ReProver's retriever encoder in CT2 format. 45 | -/ 46 | def reproverEncoder : NativeEncoder := { 47 | url := Url.parse! "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-retriever-byt5-small" 48 | tokenizer := ByT5.tokenizer 49 | } 50 | 51 | #eval encode reproverEncoder "n : ℕ\n⊢ gcd n n = n" 52 | 53 | 54 | /-- 55 | Arbitrary generator you can define. 56 | -/ 57 | def dummyGenerator : GenericGenerator where 58 | generate _ _ := return #[⟨"Hello, world!", 0.5⟩, ("Hi!", 0.3)] 59 | 60 | #eval generate dummyGenerator "n : ℕ\n⊢ gcd n n = n" 61 | 62 | 63 | /-- 64 | Arbitrary encoder you can define. 65 | -/ 66 | def dummyEncoder : GenericEncoder where 67 | encode _ := return FloatArray.mk #[1, 2, 3] 68 | 69 | #eval encode dummyEncoder "Hi!" 70 | 71 | /- 72 | External Models 73 | 74 | 1. Make sure the model is up and running, e.g., by going to ./python and running `uvicorn server:app --port 23337`. 75 | 2. Uncomment the code below. 76 | -/ 77 | 78 | /- 79 | /-- 80 | https://huggingface.co/wellecks/llmstep-mathlib4-pythia2.8b 81 | -/ 82 | def pythia : ExternalGenerator := { 83 | name := "wellecks/llmstep-mathlib4-pythia2.8b" 84 | host := "localhost" 85 | port := 23337 86 | } 87 | 88 | #eval generate pythia "n : ℕ\n⊢ gcd n n = n" 89 | 90 | 91 | /-- 92 | ReProver's retriever encoder as an external model. 93 | -/ 94 | def reproverExternalEncoder : ExternalEncoder := { 95 | name := "kaiyuy/leandojo-lean4-retriever-byt5-small" 96 | host := "localhost" 97 | port := 23337 98 | } 99 | 100 | -- Go to ./python and run `uvicorn server:app --port 23337` 101 | #eval encode reproverExternalEncoder "n : ℕ\n⊢ gcd n n = n" 102 | 103 | /-- 104 | General-purpose LLM APIs: OpenAI, Claude, etc. 105 | -/ 106 | def gpt4 : ExternalGenerator := { 107 | name := "gpt4" 108 | host := "localhost" 109 | port := 23337 110 | } 111 | 112 | #eval generate gpt4 "n : ℕ\n⊢ gcd n n = n" 113 | 114 | /-- 115 | Math LLMs: InternLM, DeepSeekMath, etc. 116 | -/ 117 | def internLM : ExternalGenerator := { 118 | name := "InternLM" 119 | host := "localhost" 120 | port := 23337 121 | } 122 | 123 | #eval generate internLM "n : ℕ\n⊢ gcd n n = n" 124 | 125 | -/ 126 | 127 | /- 128 | def kimina : ExternalGenerator := { 129 | name := "kimina" 130 | host := "localhost" 131 | port := 23337 132 | } 133 | 134 | #eval generate kimina "n : ℕ\n⊢ gcd n n = n" 135 | -/ 136 | -------------------------------------------------------------------------------- /python/external_models/hf_runner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from loguru import logger 3 | from typing import List, Tuple 4 | from transformers import ( 5 | AutoModelForCausalLM, 6 | AutoTokenizer, 7 | ) 8 | from .external_parser import * 9 | 10 | 11 | class HFTacticGenerator(Generator, Transformer): 12 | def __init__(self, **args) -> None: 13 | self.name = args["model"] 14 | self.tokenizer = AutoTokenizer.from_pretrained( 15 | self.name, trust_remote_code=True 16 | ) 17 | device = args["device"] 18 | if device == "auto": 19 | device = get_cuda_if_available() 20 | else: 21 | device = torch.device(device) 22 | logger.info(f"Loading {self.name} on {device}") 23 | self.model = AutoModelForCausalLM.from_pretrained( 24 | self.name, trust_remote_code=True 25 | ).to(device) 26 | 27 | self.generation_args: dict[str, str] = { 28 | "do_sample": args["do_sample"], 29 | "temperature": args["temperature"], # chat default is 0.8.
30 | "max_new_tokens": args["max_new_tokens"], 31 | "top_p": args["top_p"], # chat default is 0.8. 32 | "num_return_sequences": args["num_return_sequences"], 33 | "output_scores": args["output_scores"], 34 | "output_logits": args["output_logits"], 35 | "return_dict_in_generate": args["return_dict_in_generate"], 36 | } 37 | 38 | def generate(self, input: str, target_prefix: str = "") -> List[Tuple[str, float]]: 39 | prompt = input + target_prefix 40 | '''prompt= 'Here is a theorom you need to prove in Lean:\n'+prompt+'\nNow you should suggest one line tactic in lean code:' 41 | prompt = f"""<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n""" 42 | ''' 43 | prompt = pre_process_input(self.name, prompt) 44 | 45 | self.model = self.model.eval() 46 | 47 | tokenized_input = self.tokenizer(prompt, return_tensors="pt") 48 | eos_token_id = [ 49 | self.tokenizer.eos_token_id, 50 | self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0], 51 | ] 52 | outputs = self.model.generate( 53 | tokenized_input.input_ids.to(self.device), 54 | eos_token_id=eos_token_id, 55 | **self.generation_args, 56 | ) 57 | response = self.tokenizer.batch_decode( 58 | outputs["sequences"], skip_special_tokens=True 59 | ) 60 | 61 | result = [] 62 | index = 0 63 | for out, score in zip(response, outputs.scores): 64 | out = post_process_output(self.name, out) 65 | result.append((out, score[index].exp().sum().log().cpu().item())) 66 | index += 1 67 | result = choices_dedup(result) 68 | return result 69 | 70 | 71 | if __name__ == "__main__": 72 | generation_kwargs = { 73 | "model": "internlm/internlm2-math-plus-1_8b", 74 | "temperature": 0.6, 75 | "max_new_tokens": 1024, 76 | "top_p": 0.9, 77 | "length_penalty": 0, 78 | "num_return_sequences": 64, 79 | "do_sample": True, 80 | "output_scores": True, 81 | "output_logits": False, 82 | "return_dict_in_generate": True, 83 | "device": "auto", 84 | } 85 | model = HFTacticGenerator(**generation_kwargs) 86 | model.cuda() 87 | print(model.generate("n : ℕ\n⊢ gcd n n = n")) 88 | -------------------------------------------------------------------------------- /LeanCopilot/Models/Registry.lean: -------------------------------------------------------------------------------- 1 | import Lean 2 | import Batteries.Data.HashMap 3 | import LeanCopilot.Models.Native 4 | import LeanCopilot.Models.External 5 | import LeanCopilot.Models.Generic 6 | import LeanCopilot.Models.Builtin 7 | import LeanCopilot.Models.FFI 8 | 9 | set_option autoImplicit false 10 | 11 | open Batteries 12 | 13 | namespace LeanCopilot 14 | 15 | 16 | inductive Generator where 17 | | native : NativeGenerator → Generator 18 | | external : ExternalGenerator → Generator 19 | | generic : GenericGenerator → Generator 20 | 21 | 22 | instance : TextToText Generator where 23 | generate (model : Generator) (input : String) (targetPrefix : String) := 24 | match model with 25 | | .native ng => ng.generate input targetPrefix 26 | | .external eg => eg.generate input targetPrefix 27 | | .generic gg => gg.generate input targetPrefix 28 | 29 | 30 | inductive Encoder where 31 | | native : NativeEncoder → Encoder 32 | | external : ExternalEncoder → Encoder 33 | | generic : GenericEncoder → Encoder 34 | 35 | 36 | instance : TextToVec Encoder where 37 | encode (model : Encoder) (input : String) := 38 | match model with 39 | | .native ne => ne.encode input 40 | | .external ee => ee.encode input 41 | | .generic ge => ge.encode input 42 | 43 | 44 | instance {α β : Type} [BEq α] [Hashable α] [Repr α] [Repr β] : Repr (Std.HashMap α β) where 45 | 
reprPrec hm n := reprPrec hm.toList n 46 | 47 | 48 | structure ModelRegistry where 49 | generators : Std.HashMap String Generator := 50 | Std.HashMap.ofList [(Builtin.generator.name, .native Builtin.generator)] 51 | encoders : Std.HashMap String Encoder := 52 | Std.HashMap.ofList [(Builtin.encoder.name, .native Builtin.encoder)] 53 | 54 | 55 | namespace ModelRegistry 56 | 57 | 58 | def generatorNames (mr : ModelRegistry) : List String := 59 | mr.generators.toList.map (·.1) 60 | 61 | 62 | def encoderNames (mr : ModelRegistry) : List String := 63 | mr.encoders.toList.map (·.1) 64 | 65 | 66 | def modelNames (mr : ModelRegistry) : List String := 67 | mr.generatorNames ++ mr.encoderNames 68 | 69 | 70 | end ModelRegistry 71 | 72 | 73 | instance : Repr ModelRegistry where 74 | reprPrec mr n := reprPrec mr.modelNames n 75 | 76 | 77 | instance : Inhabited ModelRegistry where 78 | default := {} 79 | 80 | 81 | initialize modelRegistryRef : IO.Ref ModelRegistry ← IO.mkRef default 82 | 83 | 84 | def getModelRegistry : IO ModelRegistry := modelRegistryRef.get 85 | 86 | 87 | def getGenerator (name : String) : Lean.CoreM Generator := do 88 | let mr ← getModelRegistry 89 | match mr.generators[name]? with 90 | | some (.native model) => 91 | if ¬(← isUpToDate model.url) then 92 | Lean.logWarning s!"The local model {model.name} is not up to date. You may want to run `lake exe LeanCopilot/download` to re-download it." 93 | return .native model 94 | | some descr => return descr 95 | | none => throwError s!"unknown generator: {name}" 96 | 97 | 98 | def getEncoder (name : String) : IO Encoder := do 99 | let mr ← getModelRegistry 100 | match mr.encoders[name]? with 101 | | some descr => return descr 102 | | none => throw $ IO.userError s!"unknown encoder: {name}" 103 | 104 | 105 | def registerGenerator (name : String) (model : Generator) := do 106 | let mr ← getModelRegistry 107 | modelRegistryRef.modify fun _ => 108 | {mr with generators := mr.generators.insert name model} 109 | 110 | 111 | end LeanCopilot 112 | -------------------------------------------------------------------------------- /LeanCopilot/Frontend.lean: -------------------------------------------------------------------------------- 1 | /- This frontend is developed partly based on `mathlib4/Mathlib/Tactic/Hint.lean` -/ 2 | import Lean 3 | import LeanCopilot.Options 4 | import Lean.Meta.Tactic.TryThis 5 | import Batteries.Data.MLList.Basic 6 | import Batteries.Control.Nondet.Basic 7 | 8 | open Lean Parser Elab Tactic 9 | 10 | 11 | set_option autoImplicit false 12 | 13 | 14 | open Lean.Meta.Tactic.TryThis in 15 | /-- 16 | Construct a suggestion for a tactic. 17 | * Check the passed `MessageLog` for an info message beginning with "Try this: ". 18 | * If found, use that as the suggestion. 19 | * Otherwise use the provided syntax. 20 | * Also, look for remaining goals and pretty print them after the suggestion. 21 | -/ 22 | def suggestion (tac : String) (msgs : MessageLog := {}) : TacticM Suggestion := do 23 | -- TODO `addExactSuggestion` has an option to construct `postInfo?` 24 | -- Factor that out so we can use it here instead of copying and pasting? 25 | let goals ← getGoals 26 | let postInfo? ← if goals.isEmpty then pure none else 27 | let mut str := "\nRemaining subgoals:" 28 | for g in goals do 29 | let goalType ← instantiateMVars (← g.getType) 30 | let e ← g.withContext do (PrettyPrinter.ppExpr goalType) 31 | str := str ++ Format.pretty ("\n⊢ " ++ e) 32 | pure (some str) 33 | let msg? ← msgs.toList.findM? 
fun m => do pure <| 34 | m.severity == MessageSeverity.information && (← m.data.toString).startsWith "Try this: " 35 | let suggestion ← match msg? with 36 | | some m => pure <| SuggestionText.string (((← m.data.toString).drop 10).takeWhile (· != '\n')) 37 | | none => pure <| SuggestionText.string tac 38 | return { suggestion, postInfo? } 39 | 40 | 41 | /-- Run a tactic, returning any new messages rather than adding them to the message log. -/ 42 | def withMessageLog (t : TacticM Unit) : TacticM MessageLog := do 43 | let initMsgs ← modifyGetThe Core.State fun st => (st.messages, { st with messages := {} }) 44 | t 45 | modifyGetThe Core.State fun st => (st.messages, { st with messages := initMsgs }) 46 | 47 | 48 | /-- 49 | Run a tactic, but revert any changes to info trees. 50 | We use this to inhibit the creation of widgets by subsidiary tactics. 51 | -/ 52 | def withoutInfoTrees (t : TacticM Unit) : TacticM Unit := do 53 | let trees := (← getInfoState).trees 54 | t 55 | modifyInfoState fun s => { s with trees } 56 | 57 | 58 | open Lean.Meta.Tactic.TryThis in 59 | def hint (stx : Syntax) (tacStrs : Array String) (check : Bool) : TacticM Unit := do 60 | if check then 61 | let tacStxs ← tacStrs.filterMapM fun tstr : String => do match runParserCategory (← getEnv) `tactic tstr with 62 | | Except.error _ => return none 63 | | Except.ok stx => return some (tstr, stx) 64 | let tacs := Nondet.ofList tacStxs.toList 65 | let results := tacs.filterMapM fun t : (String × Syntax) => do 66 | if let some msgs ← observing? (withMessageLog (withoutInfoTrees (evalTactic t.2))) then 67 | return some (← getGoals, ← suggestion t.1 msgs) 68 | else 69 | return none 70 | let results ← (results.toMLList.takeUpToFirst fun r => r.1.1.isEmpty).asArray 71 | let results := results.qsort (·.1.1.length < ·.1.1.length) 72 | addSuggestions stx (results.map (·.1.2)) 73 | match results.find? (·.1.1.isEmpty) with 74 | | some r => 75 | setMCtx r.2.term.meta.meta.mctx 76 | | none => admitGoal (← getMainGoal) 77 | else 78 | let tacsNoCheck : Array Suggestion := tacStrs.map fun tac => { suggestion := SuggestionText.string tac } 79 | addSuggestions stx tacsNoCheck 80 | -------------------------------------------------------------------------------- /LeanCopilot/Models/ByT5.lean: -------------------------------------------------------------------------------- 1 | /- 2 | ByT5 tokenization implemented in Lean. 
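Each of the 256 byte values gets its own token; the `vocab` table below spells out
the byte-to-token mapping used by `tokenize` and `detokenize`.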
3 | -/ 4 | import LeanCopilot.Models.Native 5 | 6 | set_option autoImplicit false 7 | 8 | namespace LeanCopilot.ByT5 9 | 10 | 11 | def vocab : Array String := #[ 12 | "\u0000", 13 | "\u0001", 14 | "\u0002", 15 | "\u0003", 16 | "\u0004", 17 | "\u0005", 18 | "\u0006", 19 | "\u0007", 20 | "\\b", 21 | "\t", 22 | "\n", 23 | "\u000b", 24 | "\\f", 25 | "\r", 26 | "\u000e", 27 | "\u000f", 28 | "\u0010", 29 | "\u0011", 30 | "\u0012", 31 | "\u0013", 32 | "\u0014", 33 | "\u0015", 34 | "\u0016", 35 | "\u0017", 36 | "\u0018", 37 | "\u0019", 38 | "\u001a", 39 | "\u001b", 40 | "\u001c", 41 | "\u001d", 42 | "\u001e", 43 | "\u001f", 44 | " ", 45 | "!", 46 | "\"", 47 | "#", 48 | "$", 49 | "%", 50 | "&", 51 | "'", 52 | "(", 53 | ")", 54 | "*", 55 | "+", 56 | ",", 57 | "-", 58 | ".", 59 | "/", 60 | "0", 61 | "1", 62 | "2", 63 | "3", 64 | "4", 65 | "5", 66 | "6", 67 | "7", 68 | "8", 69 | "9", 70 | ":", 71 | ";", 72 | "<", 73 | "=", 74 | ">", 75 | "?", 76 | "@", 77 | "A", 78 | "B", 79 | "C", 80 | "D", 81 | "E", 82 | "F", 83 | "G", 84 | "H", 85 | "I", 86 | "J", 87 | "K", 88 | "L", 89 | "M", 90 | "N", 91 | "O", 92 | "P", 93 | "Q", 94 | "R", 95 | "S", 96 | "T", 97 | "U", 98 | "V", 99 | "W", 100 | "X", 101 | "Y", 102 | "Z", 103 | "[", 104 | "\\", 105 | "]", 106 | "^", 107 | "_", 108 | "`", 109 | "a", 110 | "b", 111 | "c", 112 | "d", 113 | "e", 114 | "f", 115 | "g", 116 | "h", 117 | "i", 118 | "j", 119 | "k", 120 | "l", 121 | "m", 122 | "n", 123 | "o", 124 | "p", 125 | "q", 126 | "r", 127 | "s", 128 | "t", 129 | "u", 130 | "v", 131 | "w", 132 | "x", 133 | "y", 134 | "z", 135 | "{", 136 | "|", 137 | "}", 138 | "~", 139 | "\u007f", 140 | "\u0080", 141 | "\u0081", 142 | "\u0082", 143 | "\u0083", 144 | "\u0084", 145 | "\u0085", 146 | "\u0086", 147 | "\u0087", 148 | "\u0088", 149 | "\u0089", 150 | "\u008a", 151 | "\u008b", 152 | "\u008c", 153 | "\u008d", 154 | "\u008e", 155 | "\u008f", 156 | "\u0090", 157 | "\u0091", 158 | "\u0092", 159 | "\u0093", 160 | "\u0094", 161 | "\u0095", 162 | "\u0096", 163 | "\u0097", 164 | "\u0098", 165 | "\u0099", 166 | "\u009a", 167 | "\u009b", 168 | "\u009c", 169 | "\u009d", 170 | "\u009e", 171 | "\u009f", 172 | "\u00a0", 173 | "\u00a1", 174 | "\u00a2", 175 | "\u00a3", 176 | "\u00a4", 177 | "\u00a5", 178 | "\u00a6", 179 | "\u00a7", 180 | "\u00a8", 181 | "\u00a9", 182 | "\u00aa", 183 | "\u00ab", 184 | "\u00ac", 185 | "\u00ad", 186 | "\u00ae", 187 | "\u00af", 188 | "\u00b0", 189 | "\u00b1", 190 | "\u00b2", 191 | "\u00b3", 192 | "\u00b4", 193 | "\u00b5", 194 | "\u00b6", 195 | "\u00b7", 196 | "\u00b8", 197 | "\u00b9", 198 | "\u00ba", 199 | "\u00bb", 200 | "\u00bc", 201 | "\u00bd", 202 | "\u00be", 203 | "\u00bf", 204 | "\u00c0", 205 | "\u00c1", 206 | "\u00c2", 207 | "\u00c3", 208 | "\u00c4", 209 | "\u00c5", 210 | "\u00c6", 211 | "\u00c7", 212 | "\u00c8", 213 | "\u00c9", 214 | "\u00ca", 215 | "\u00cb", 216 | "\u00cc", 217 | "\u00cd", 218 | "\u00ce", 219 | "\u00cf", 220 | "\u00d0", 221 | "\u00d1", 222 | "\u00d2", 223 | "\u00d3", 224 | "\u00d4", 225 | "\u00d5", 226 | "\u00d6", 227 | "\u00d7", 228 | "\u00d8", 229 | "\u00d9", 230 | "\u00da", 231 | "\u00db", 232 | "\u00dc", 233 | "\u00dd", 234 | "\u00de", 235 | "\u00df", 236 | "\u00e0", 237 | "\u00e1", 238 | "\u00e2", 239 | "\u00e3", 240 | "\u00e4", 241 | "\u00e5", 242 | "\u00e6", 243 | "\u00e7", 244 | "\u00e8", 245 | "\u00e9", 246 | "\u00ea", 247 | "\u00eb", 248 | "\u00ec", 249 | "\u00ed", 250 | "\u00ee", 251 | "\u00ef", 252 | "\u00f0", 253 | "\u00f1", 254 | "\u00f2", 255 | "\u00f3", 256 | "\u00f4", 257 | "\u00f5", 258 | "\u00f6", 259 | "\u00f7", 260 
| "\u00f8", 261 | "\u00f9", 262 | "\u00fa", 263 | "\u00fb", 264 | "\u00fc", 265 | "\u00fd", 266 | "\u00fe", 267 | "\u00ff" 268 | ] 269 | 270 | 271 | private def byteToToken (b : UInt8) : String := 272 | vocab[b.toNat]! 273 | 274 | 275 | private def tokenToByte! (t : String) : UInt8 := 276 | vocab.findIdx? (· = t) |>.get! |>.toUInt8 277 | 278 | 279 | def tokenize (text : String) : Array String := 280 | (byteToToken <$> text.toUTF8.toList).toArray 281 | 282 | 283 | def detokenize (tokens : Array String) : String := 284 | match (String.fromUTF8? ⟨tokens.map tokenToByte!⟩) with 285 | | some s => s 286 | | none => "" 287 | 288 | 289 | def eosToken := "" 290 | 291 | 292 | def tokenizer : Tokenizer := { 293 | tokenize := tokenize, 294 | detokenize := detokenize, 295 | eosToken := eosToken 296 | } 297 | 298 | 299 | end LeanCopilot.ByT5 300 | -------------------------------------------------------------------------------- /scripts/convert_t5encoder_to_ct2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from transformers import AutoTokenizer 3 | import ctranslate2 4 | from ctranslate2.converters.transformers import ( 5 | TransformersConverter, 6 | ModelLoader, 7 | _MODEL_LOADERS, 8 | _SUPPORTED_ACTIVATIONS, 9 | ) 10 | from ctranslate2.specs import transformer_spec, common_spec 11 | import ctranslate2.converters.utils as utils 12 | 13 | 14 | class T5EncoderLoader(ModelLoader): 15 | @property 16 | def architecture_name(self): 17 | return "T5EncoderModel" 18 | 19 | def get_model_spec(self, model): 20 | encoder_spec = transformer_spec.TransformerEncoderSpec( 21 | model.config.num_layers, 22 | model.config.num_heads, 23 | pre_norm=True, 24 | activation=_SUPPORTED_ACTIVATIONS[model.config.dense_act_fn], 25 | ffn_glu=model.config.is_gated_act, 26 | relative_attention_bias=True, 27 | rms_norm=True, 28 | ) 29 | spec = transformer_spec.TransformerEncoderModelSpec(encoder_spec) 30 | self.set_stack(spec.encoder, model.encoder) 31 | return spec 32 | 33 | def get_vocabulary(self, model, tokenizer): 34 | tokens = super().get_vocabulary(model, tokenizer) 35 | 36 | extra_ids = model.config.vocab_size - len(tokens) 37 | for i in range(extra_ids): 38 | tokens.append("" % i) 39 | 40 | return tokens 41 | 42 | def set_vocabulary(self, spec, tokens): 43 | spec.register_vocabulary(tokens) 44 | 45 | def set_config(self, config, model, tokenizer): 46 | config.bos_token = tokenizer.pad_token 47 | config.eos_token = tokenizer.eos_token 48 | config.unk_token = tokenizer.unk_token 49 | 50 | def set_stack(self, spec, module): 51 | self.set_layer_norm(spec.layer_norm, module.final_layer_norm) 52 | self.set_embeddings( 53 | spec.embeddings[0] 54 | if isinstance(spec.embeddings, list) 55 | else spec.embeddings, 56 | module.embed_tokens, 57 | ) 58 | 59 | spec.scale_embeddings = False 60 | 61 | for i, (layer_spec, block) in enumerate(zip(spec.layer, module.block)): 62 | self.set_self_attention(layer_spec.self_attention, block.layer[0]) 63 | 64 | if i > 0: 65 | # Reuse relative attention bias from the first layer. 
66 | first_self_attention = spec.layer[0].self_attention 67 | layer_spec.self_attention.relative_attention_bias = ( 68 | first_self_attention.relative_attention_bias 69 | ) 70 | layer_spec.self_attention.relative_attention_max_distance = ( 71 | first_self_attention.relative_attention_max_distance 72 | ) 73 | 74 | self.set_ffn(layer_spec.ffn, block.layer[-1]) 75 | 76 | def set_ffn(self, spec, module): 77 | if hasattr(spec, "linear_0_noact"): 78 | self.set_linear(spec.linear_0, module.DenseReluDense.wi_0) 79 | self.set_linear(spec.linear_0_noact, module.DenseReluDense.wi_1) 80 | else: 81 | self.set_linear(spec.linear_0, module.DenseReluDense.wi) 82 | 83 | self.set_linear(spec.linear_1, module.DenseReluDense.wo) 84 | self.set_layer_norm(spec.layer_norm, module.layer_norm) 85 | 86 | def set_self_attention(self, spec, module): 87 | self.set_attention(spec, module.SelfAttention, self_attention=True) 88 | self.set_layer_norm(spec.layer_norm, module.layer_norm) 89 | 90 | def set_attention(self, spec, attention, self_attention=False): 91 | spec.queries_scale = 1.0 92 | 93 | split_layers = [common_spec.LinearSpec() for _ in range(3)] 94 | self.set_linear(split_layers[0], attention.q) 95 | self.set_linear(split_layers[1], attention.k) 96 | self.set_linear(split_layers[2], attention.v) 97 | 98 | if self_attention: 99 | utils.fuse_linear(spec.linear[0], split_layers) 100 | else: 101 | utils.fuse_linear(spec.linear[0], split_layers[:1]) 102 | utils.fuse_linear(spec.linear[1], split_layers[1:]) 103 | 104 | self.set_linear(spec.linear[-1], attention.o) 105 | 106 | if attention.has_relative_attention_bias: 107 | spec.relative_attention_bias = attention.relative_attention_bias.weight 108 | spec.relative_attention_max_distance = np.dtype("int32").type( 109 | attention.relative_attention_max_distance 110 | ) 111 | 112 | def set_layer_norm(self, spec, layer_norm): 113 | spec.gamma = layer_norm.weight 114 | 115 | 116 | _MODEL_LOADERS["T5Config"] = T5EncoderLoader() 117 | 118 | converter = TransformersConverter("kaiyuy/leandojo-lean4-retriever-byt5-small") 119 | converter.convert("ct2-leandojo-lean4-retriever-byt5-small", force=True) 120 | 121 | encoder = ctranslate2.Encoder("ct2-leandojo-lean4-retriever-byt5-small") 122 | state = "n : ℕ\n⊢ gcd n n = n" 123 | tokenizer = AutoTokenizer.from_pretrained("kaiyuy/leandojo-lean4-retriever-byt5-small") 124 | output = encoder.forward_batch( 125 | [tokenizer.convert_ids_to_tokens(tokenizer.encode(state))] 126 | ) 127 | feature = np.array(output.last_hidden_state).mean(axis=1) 128 | -------------------------------------------------------------------------------- /LeanCopilot/Tactics.lean: -------------------------------------------------------------------------------- 1 | import Lean 2 | import LeanCopilot.Options 3 | import LeanCopilot.Frontend 4 | import Aesop.Util.Basic 5 | import Batteries.Data.String.Basic 6 | import Batteries.Data.String.Matcher 7 | 8 | open Lean Meta Parser Elab Term Tactic 9 | 10 | 11 | set_option autoImplicit false 12 | 13 | 14 | namespace LeanCopilot 15 | 16 | /-- 17 | Pretty-print a list of goals. 18 | -/ 19 | def ppTacticState : List MVarId → MetaM String 20 | | [] => return "no goals" 21 | | [g] => return (← Meta.ppGoal g).pretty 22 | | goals => 23 | return (← goals.foldlM (init := "") (fun a b => do return s!"{a}\n\n{(← Meta.ppGoal b).pretty}")).trim 24 | 25 | 26 | /-- 27 | Pretty-print the current tactic state. 
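This string is what `suggest_tactics` and `select_premises` send to the underlying models.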
28 | -/ 29 | def getPpTacticState : TacticM String := do 30 | let goals ← getUnsolvedGoals 31 | ppTacticState goals 32 | 33 | 34 | open SuggestTactics in 35 | /-- 36 | Generate a list of tactic suggestions. 37 | -/ 38 | def suggestTactics (targetPrefix : String) : TacticM (Array (String × Float)) := do 39 | let state ← getPpTacticState 40 | let nm ← getGeneratorName 41 | let model ← getGenerator nm 42 | let suggestions ← generate model state targetPrefix 43 | -- A temporary workaround to prevent the tactic from using the current theorem. 44 | -- TODO: Use a more principled way, e.g., see `Lean4Repl.lean` in `LeanDojo`. 45 | if let some declName ← getDeclName? then 46 | let theoremName := match declName.toString with 47 | | "_example" => "" 48 | | n => n.splitOn "." |>.getLast! 49 | let theoremNameMatcher := String.Matcher.ofString theoremName 50 | if ← isVerbose then 51 | logInfo s!"State:\n{state}" 52 | logInfo s!"Theorem name:\n{theoremName}" 53 | let filteredSuggestions := suggestions.filterMap fun ((t, s) : String × Float) => 54 | let isAesop := t == "aesop" 55 | let isSelfReference := ¬ (theoremName == "") ∧ (theoremNameMatcher.find? t |>.isSome) 56 | if isSelfReference ∨ isAesop then none else some (t, s) 57 | return filteredSuggestions 58 | else 59 | let filteredSuggestions := suggestions.filterMap fun ((t, s) : String × Float) => 60 | let isAesop := t == "aesop" 61 | if isAesop then none else some (t, s) 62 | return filteredSuggestions 63 | 64 | 65 | /-- 66 | Information of a premise. 67 | -/ 68 | structure PremiseInfo where 69 | name : String 70 | path : String 71 | code : String 72 | score : Float 73 | 74 | 75 | /-- 76 | Annotate a premise with its type, doc string, import module path, and definition code. 77 | -/ 78 | private def annotatePremise (pi : PremiseInfo) : MetaM String := do 79 | let declName := pi.name.toName 80 | try 81 | let info ← getConstInfo declName 82 | let premise_type ← Meta.ppExpr info.type 83 | let some doc_str ← findDocString? (← getEnv) declName 84 | | return s!"{pi.name} : {premise_type}\n" 85 | return s!"{pi.name} : {premise_type}\n```doc\n{doc_str}\n```\n" 86 | catch _ => return s!"{pi.name} needs to be imported from `{pi.path}`.\n```code\n{pi.code}\n```\n" 87 | 88 | 89 | /-- 90 | Retrieve a list of premises given a query. 91 | -/ 92 | def retrieve (input : String) : TacticM (Array PremiseInfo) := do 93 | if ¬ (← premiseEmbeddingsInitialized) ∧ ¬ (← initPremiseEmbeddings .auto) then 94 | throwError "Cannot initialize premise embeddings" 95 | 96 | if ¬ (← premiseDictionaryInitialized) ∧ ¬ (← initPremiseDictionary) then 97 | throwError "Cannot initialize premise dictionary" 98 | 99 | let k ← SelectPremises.getNumPremises 100 | let query ← encode Builtin.encoder input 101 | 102 | let rawPremiseInfo := FFI.retrieve query k.toUInt64 103 | let premiseInfo : Array PremiseInfo := rawPremiseInfo.map fun (name, path, code, score) => 104 | { name := name, path := path, code := code, score := score } 105 | return premiseInfo 106 | 107 | 108 | /-- 109 | Retrieve a list of premises using the current pretty-printed tactic state as the query. 
110 | -/ 111 | def selectPremises : TacticM (Array PremiseInfo) := do 112 | retrieve (← getPpTacticState) 113 | 114 | 115 | syntax "pp_state" : tactic 116 | syntax "suggest_tactics" : tactic 117 | syntax "suggest_tactics" str : tactic 118 | syntax "select_premises" : tactic 119 | 120 | 121 | macro_rules 122 | | `(tactic | suggest_tactics%$tac) => `(tactic | suggest_tactics%$tac "") 123 | 124 | 125 | elab_rules : tactic 126 | | `(tactic | pp_state) => do 127 | let state ← getPpTacticState 128 | logInfo state 129 | 130 | | `(tactic | suggest_tactics%$tac $pfx:str) => do 131 | let (tacticsWithScores, elapsed) ← Aesop.time $ suggestTactics pfx.getString 132 | if ← isVerbose then 133 | logInfo s!"{elapsed.printAsMillis} for generating {tacticsWithScores.size} tactics" 134 | let tactics := tacticsWithScores.map (·.1) 135 | if ← isVerbose then 136 | logInfo s!"Tactics: {tactics}" 137 | let range : Lean.Syntax.Range := { start := tac.getRange?.get!.start, stop := pfx.raw.getRange?.get!.stop } 138 | let ref := Syntax.ofRange range 139 | hint ref tactics (← SuggestTactics.checkTactics) 140 | 141 | | `(tactic | select_premises) => do 142 | let premisesWithInfoAndScores ← selectPremises 143 | let rankedPremisesWithInfoAndScores := premisesWithInfoAndScores.qsort (·.score > ·.score) 144 | let richPremises ← Meta.liftMetaM $ (rankedPremisesWithInfoAndScores.mapM annotatePremise) 145 | let richPremisesExpand := richPremises.foldl (init := "") (· ++ · ++ "\n") 146 | logInfo richPremisesExpand 147 | 148 | 149 | end LeanCopilot 150 | -------------------------------------------------------------------------------- /LeanCopilot/Models/FFI.lean: -------------------------------------------------------------------------------- 1 | import Lean 2 | import LeanCopilot.Models.Interface 3 | import LeanCopilot.Models.Native 4 | import LeanCopilot.Models.Builtin 5 | 6 | namespace LeanCopilot 7 | 8 | set_option autoImplicit false 9 | 10 | namespace FFI 11 | 12 | 13 | @[extern "is_generator_initialized"] 14 | opaque isGeneratorInitialized : (name : @& String) → Bool 15 | 16 | @[extern "is_encoder_initialized"] 17 | opaque isEncoderInitialized : (name : @& String) → Bool 18 | 19 | @[extern "init_generator"] 20 | opaque initGenerator (name : @& String) (modelPath : @& String) (computeType : @& String) (device : @& String) (deviceIndex : @& Array UInt64) : Bool 21 | 22 | @[extern "init_encoder"] 23 | opaque initEncoder (name : @& String) (modelPath : @& String) (computeType : @& String) (device : @& String) (deviceIndex : @& Array UInt64) : Bool 24 | 25 | @[extern "generate"] 26 | opaque generate (name : @& String) (inputTokens : @& Array String) (targetPrefixTokens : @& Array String) (numReturnSequences : UInt64) (beamSize : UInt64) 27 | (minLength : UInt64) (maxLength : UInt64) (lengthPenalty : Float) (patience : Float) (temperature : Float) 28 | : Array (Array String × Float) 29 | 30 | @[extern "encode"] 31 | opaque encode (name : @& String) (inputTokens : @& Array String) : FloatArray 32 | 33 | @[extern "init_premise_embeddings"] 34 | opaque initPremiseEmbeddings (path : @& String) (device : @& String) : Bool 35 | 36 | @[extern "premise_embeddings_initialized"] 37 | opaque premiseEmbeddingsInitialized : Unit → Bool 38 | 39 | @[extern "init_premise_dictionary"] 40 | opaque initPremiseDictionary (path : @& String) : Bool 41 | 42 | @[extern "premise_dictionary_initialized"] 43 | opaque premiseDictionaryInitialized : Unit → Bool 44 | 45 | @[extern "retrieve"] 46 | opaque retrieve (queryEmb : @& FloatArray) (k : UInt64) : 
Array (String × String × String × Float) 47 | 48 | @[extern "cuda_available"] 49 | opaque cudaAvailable : Unit → Bool 50 | 51 | 52 | end FFI 53 | 54 | 55 | def cudaAvailable : Bool := FFI.cudaAvailable () 56 | 57 | 58 | namespace NativeGenerator 59 | 60 | 61 | def generate (model : NativeGenerator) (input : String) (targetPrefix : String) : IO $ Array (String × Float) := do 62 | if ¬ FFI.isGeneratorInitialized model.name then 63 | let path ← model.path 64 | if ¬ (← path.pathExists) then 65 | throw $ IO.userError s!"Cannot find the model {model.name}. Please run `lake exe download {model.url}`." 66 | let device := model.device.toString 67 | let computeType := model.computeType.toString 68 | if ¬ (FFI.initGenerator model.name path.toString computeType device model.deviceIndex) then 69 | throw $ IO.userError s!"Failed to initialize model {model.name}" 70 | 71 | let tokenizer := model.tokenizer 72 | let inputTokens := tokenizer.tokenize input |>.push tokenizer.eosToken 73 | let targetPrefixTokens := tokenizer.tokenize targetPrefix 74 | let numReturnSequences := model.params.numReturnSequences 75 | let beamSize := model.params.beamSize 76 | let minLength := model.params.minLength 77 | let maxLength := model.params.maxLength 78 | let lengthPenalty := model.params.lengthPenalty 79 | let patience := model.params.patience 80 | let temperature := model.params.temperature 81 | let tokensWithScores := FFI.generate model.name inputTokens targetPrefixTokens numReturnSequences beamSize minLength maxLength lengthPenalty patience temperature 82 | 83 | return tokensWithScores.filterMap fun ((ts, s) : Array String × Float) => (tokenizer.detokenize ts, s) 84 | 85 | 86 | instance : TextToText NativeGenerator where 87 | generate := NativeGenerator.generate 88 | 89 | 90 | end NativeGenerator 91 | 92 | 93 | namespace NativeEncoder 94 | 95 | 96 | def encode (model : NativeEncoder) (input : String) : IO FloatArray := do 97 | if ¬ FFI.isEncoderInitialized model.name then 98 | let path ← model.path 99 | if ¬ (← path.pathExists) then 100 | throw $ IO.userError s!"Cannot find the model {model.name}. Please run `lake exe download {model.url}`." 101 | let device := model.device.toString 102 | let computeType := model.computeType.toString 103 | if ¬ (FFI.initEncoder model.name path.toString computeType device model.deviceIndex) then 104 | throw $ IO.userError s!"Failed to initialize model {model.name}" 105 | 106 | let tokenizer := model.tokenizer 107 | let inputTokens := tokenizer.tokenize input |>.push tokenizer.eosToken 108 | return FFI.encode model.name inputTokens 109 | 110 | 111 | instance : TextToVec NativeEncoder where 112 | encode := NativeEncoder.encode 113 | 114 | 115 | end NativeEncoder 116 | 117 | 118 | def premiseEmbeddingsInitialized : IO Bool := do 119 | return FFI.premiseEmbeddingsInitialized () 120 | 121 | 122 | def initPremiseEmbeddings (device : Device) : Lean.CoreM Bool := do 123 | let url := Builtin.premisesUrl 124 | if ¬(← isUpToDate url) then 125 | Lean.logWarning s!"The local premise embeddings are not up to date. You may want to run `lake exe LeanCopilot/download` to re-download it." 126 | let path := (← getModelDir url) / "embeddings.npy" 127 | if ¬ (← path.pathExists) then 128 | throwError s!"Please run `lake exe download {url}` to download premise embeddings." 
129 | return false 130 | return FFI.initPremiseEmbeddings path.toString device.toString 131 | 132 | 133 | def premiseDictionaryInitialized : IO Bool := do 134 | return FFI.premiseDictionaryInitialized () 135 | 136 | 137 | def initPremiseDictionary : IO Bool := do 138 | let path := (← getModelDir Builtin.premisesUrl) / "dictionary.json" 139 | if ¬ (← path.pathExists) then 140 | throw $ IO.userError s!"Please run `lake exe download {Builtin.premisesUrl}` to download the premise dictionary." 141 | return false 142 | return FFI.initPremiseDictionary path.toString 143 | 144 | 145 | end LeanCopilot 146 | -------------------------------------------------------------------------------- /python/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from loguru import logger 4 | from typing import List, Tuple 5 | from abc import ABC, abstractmethod 6 | from transformers import ( 7 | AutoModelForCausalLM, 8 | AutoModelForSeq2SeqLM, 9 | AutoTokenizer, 10 | AutoModelForTextEncoding, 11 | ) 12 | 13 | 14 | class Generator(ABC): 15 | @abstractmethod 16 | def generate(self, input: str, target_prefix: str = "") -> List[Tuple[str, float]]: 17 | pass 18 | 19 | 20 | class Encoder(ABC): 21 | @abstractmethod 22 | def encode(self, input: str) -> np.ndarray: 23 | pass 24 | 25 | 26 | class Transformer: 27 | def cuda(self) -> None: 28 | self.model.cuda() 29 | 30 | def cpu(self) -> None: 31 | self.model.cpu() 32 | 33 | @property 34 | def device(self) -> torch.device: 35 | return self.model.device 36 | 37 | 38 | def get_cuda_if_available(): 39 | return torch.device("cuda" if torch.cuda.is_available() else "cpu") 40 | 41 | 42 | class DecoderOnlyTransformer(Generator, Transformer): 43 | def __init__( 44 | self, 45 | name: str, 46 | num_return_sequences: int, 47 | max_length: int, 48 | length_penalty: float = 0.0, 49 | device: str = "cpu", 50 | ) -> None: 51 | self.tokenizer = AutoTokenizer.from_pretrained(name) 52 | if device == "auto": 53 | device = get_cuda_if_available() 54 | else: 55 | device = torch.device(device) 56 | logger.info(f"Loading {name} on {device}") 57 | self.model = AutoModelForCausalLM.from_pretrained(name).to(device) 58 | self.max_length = max_length 59 | self.num_return_sequences = num_return_sequences 60 | self.length_penalty = length_penalty 61 | 62 | def generate(self, input: str, target_prefix: str = "") -> List[Tuple[str, float]]: 63 | tokenized_input = self.tokenizer(input + target_prefix, return_tensors="pt") 64 | output = self.model.generate( 65 | tokenized_input.input_ids.to(self.device), 66 | max_length=self.max_length, 67 | num_beams=self.num_return_sequences, 68 | length_penalty=self.length_penalty, 69 | do_sample=False, 70 | num_return_sequences=self.num_return_sequences, 71 | early_stopping=False, 72 | return_dict_in_generate=True, 73 | output_scores=True, 74 | ) 75 | raw_outputs = self.tokenizer.batch_decode( 76 | output.sequences, skip_special_tokens=True 77 | ) 78 | outputs = [] 79 | 80 | for out, score in zip(raw_outputs, output.sequences_scores.exp()): 81 | assert out.startswith(input + target_prefix) 82 | outputs.append((out[len(input) :], score.item())) 83 | 84 | return outputs 85 | 86 | 87 | class PythiaTacticGenerator(DecoderOnlyTransformer): 88 | def __init__( 89 | self, 90 | num_return_sequences: int, 91 | max_length: int, 92 | length_penalty: float = 0.0, 93 | device: str = "cpu", 94 | ) -> None: 95 | super().__init__( 96 | "wellecks/llmstep-mathlib4-pythia2.8b", 97 | num_return_sequences, 98 | 
max_length,
99 |             length_penalty,
100 |             device,
101 |         )
102 | 
103 |     def generate(self, input: str, target_prefix: str = "") -> List[Tuple[str, float]]:
104 |         return super().generate(f"[GOAL]{input}[PROOFSTEP]{target_prefix}")
105 | 
106 | 
107 | class EncoderDecoderTransformer(Generator, Transformer):
108 |     def __init__(
109 |         self,
110 |         name: str,
111 |         num_return_sequences: int,
112 |         max_length: int,
113 |         length_penalty: float = 0.0,
114 |         device: str = "cpu",
115 |     ) -> None:
116 |         self.tokenizer = AutoTokenizer.from_pretrained(name)
117 |         if device == "auto":
118 |             device = get_cuda_if_available()
119 |         else:
120 |             device = torch.device(device)
121 |         logger.info(f"Loading {name} on {device}")
122 |         self.model = AutoModelForSeq2SeqLM.from_pretrained(name).to(device)
123 |         self.max_length = max_length
124 |         self.num_return_sequences = num_return_sequences
125 |         self.length_penalty = length_penalty
126 | 
127 |     def generate(self, input: str, target_prefix: str = "") -> List[Tuple[str, float]]:
128 |         assert (
129 |             target_prefix == ""
130 |         ), "target_prefix is not supported by encoder-decoder Transformer"
131 |         tokenized_input = self.tokenizer(input, return_tensors="pt")
132 |         output = self.model.generate(
133 |             tokenized_input.input_ids.to(self.device),
134 |             max_length=self.max_length,
135 |             num_beams=self.num_return_sequences,
136 |             length_penalty=self.length_penalty,
137 |             do_sample=False,
138 |             num_return_sequences=self.num_return_sequences,
139 |             early_stopping=False,
140 |             return_dict_in_generate=True,
141 |             output_scores=True,
142 |         )
143 |         raw_outputs = self.tokenizer.batch_decode(
144 |             output.sequences, skip_special_tokens=True
145 |         )
146 |         return list(zip(raw_outputs, output.sequences_scores.exp().tolist()))
147 | 
148 | 
149 | class EncoderOnlyTransformer(Encoder, Transformer):
150 |     def __init__(self, name: str, device: str = "cpu") -> None:
151 |         self.tokenizer = AutoTokenizer.from_pretrained(name)
152 |         if device == "auto":
153 |             device = get_cuda_if_available()
154 |         else:
155 |             device = torch.device(device)
156 |         logger.info(f"Loading {name} on {device}")
157 |         self.model = AutoModelForTextEncoding.from_pretrained(name).to(device)
158 | 
159 |     @torch.no_grad()
160 |     def encode(self, input: str) -> np.ndarray:
161 |         tokenized_input = self.tokenizer(input, return_tensors="pt")
162 |         hidden_state = self.model(
163 |             tokenized_input.input_ids.to(self.device)
164 |         ).last_hidden_state
165 |         feature = hidden_state.mean(dim=1).squeeze()
166 |         return feature.cpu().numpy()
167 | 
168 | 
169 | if __name__ == "__main__":
170 |     model = PythiaTacticGenerator(num_return_sequences=32, max_length=1024)
171 |     model.cuda()
172 |     print(model.generate("n : ℕ\n⊢ gcd n n = n"))
173 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Lean Copilot: LLMs as Copilots for Theorem Proving in Lean
2 | ==========================================================
3 | 
4 | 🚩**News**: [Our paper](https://arxiv.org/abs/2404.12534) is accepted to the International Conference on Neuro-symbolic Systems (NeuS), 2025. See you in Philadelphia!
5 | 
6 | Lean Copilot allows large language models (LLMs) to be used natively in Lean for proof automation, e.g., suggesting tactics/premises and searching for proofs. You can use our built-in models from [LeanDojo](https://leandojo.org/) or bring your own models that run either locally (w/ or w/o GPUs) or on the cloud.
7 | 
8 | 
9 | 
10 | ## Table of Contents
11 | 
12 | 1. 
[Requirements](#requirements) 13 | 1. [Using Lean Copilot in Your Project](#using-lean-copilot-in-your-project) 14 | 1. [Adding Lean Copilot as a Dependency](#adding-lean-copilot-as-a-dependency) 15 | 1. [Getting Started with Lean Copilot](#getting-started-with-lean-copilot) 16 | 1. [Tactic Suggestion](#tactic-suggestion) 17 | 1. [Proof Search](#proof-search) 18 | 1. [Premise Selection](#premise-selection) 19 | 1. [Advanced Usage](#advanced-usage) 20 | 1. [Tactic APIs](#tactic-apis) 21 | 1. [Model APIs](#model-apis) 22 | 1. [Bring Your Own Model](#bring-your-own-model) 23 | 1. [Caveats](#caveats) 24 | 1. [Getting in Touch](#getting-in-touch) 25 | 1. [Acknowledgements](#acknowledgements) 26 | 1. [Citation](#citation) 27 | 28 | ## Requirements 29 | 30 | * Supported platforms: Linux, macOS, Windows and [Windows WSL](https://learn.microsoft.com/en-us/windows/wsl/install). 31 | * [Git LFS](https://git-lfs.com/). 32 | * Optional (recommended if you have a [CUDA-enabled GPU](https://developer.nvidia.com/cuda-gpus)): CUDA and [cuDNN](https://developer.nvidia.com/cudnn). 33 | * Required for building Lean Copilot itself (rather than a downstream package): CMake >= 3.7 and a C++17 compatible compiler. 34 | 35 | ## Using Lean Copilot in Your Project 36 | 37 | :warning: Your project must use a Lean version of at least `lean4:v4.3.0-rc2`. 38 | 39 | ### Adding Lean Copilot as a Dependency 40 | 41 | 1. Add the package configuration option `moreLinkArgs := #["-L./.lake/packages/LeanCopilot/.lake/build/lib", "-lctranslate2"]` to lakefile.lean. For example, 42 | 43 | ```lean 44 | package «my-package» { 45 | moreLinkArgs := #[ 46 | "-L./.lake/packages/LeanCopilot/.lake/build/lib", 47 | "-lctranslate2" 48 | ] 49 | } 50 | ``` 51 | 52 | Alternatively, if your project uses lakefile.toml, it should include: 53 | 54 | ```toml 55 | moreLinkArgs = ["-L./.lake/packages/LeanCopilot/.lake/build/lib", "-lctranslate2"] 56 | ``` 57 | 58 | 2. Add the following line to lakefile.lean, including the quotation marks: 59 | 60 | ```lean 61 | require LeanCopilot from git "https://github.com/lean-dojo/LeanCopilot.git" @ "LEAN_COPILOT_VERSION" 62 | ``` 63 | 64 | For stable Lean versions (e.g., `v4.26.0`), set `LEAN_COPILOT_VERSION` to be that version. For the latest unstable Lean versions (e.g., `v4.27.0-rc1`), set `LEAN_COPILOT_VERSION` to `main`. In either case, make sure the version is compatible with other dependencies such as mathlib. If your project uses lakefile.toml instead of lakefile.lean, it should include: 65 | 66 | ```toml 67 | [[require]] 68 | name = "LeanCopilot" 69 | git = "https://github.com/lean-dojo/LeanCopilot.git" 70 | rev = "LEAN_COPILOT_VERSION" 71 | ``` 72 | 73 | 3. If you are using native Windows, add `/.lake/packages/LeanCopilot/.lake/build/lib` to your `Path` variable in Advanced System Settings > Environment Variables... > System variables. 74 | 75 | 4. Run `lake update LeanCopilot`. 76 | 77 | 5. Run `lake exe LeanCopilot/download` to download the built-in models from Hugging Face to `~/.cache/lean_copilot/`. 
*Alternatively*, you can download the models from Hugging Face manually from
78 | 
79 |    * [ct2-leandojo-lean4-tacgen-byt5-small](https://huggingface.co/kaiyuy/ct2-leandojo-lean4-tacgen-byt5-small)
80 |    * [ct2-leandojo-lean4-retriever-byt5-small](https://huggingface.co/kaiyuy/ct2-leandojo-lean4-retriever-byt5-small)
81 |    * [premise-embeddings-leandojo-lean4-retriever-byt5-small](https://huggingface.co/kaiyuy/premise-embeddings-leandojo-lean4-retriever-byt5-small)
82 |    * [ct2-byt5-small](https://huggingface.co/kaiyuy/ct2-byt5-small)
83 | 
84 | 6. Run `lake build`.
85 | 
86 | [Here](https://github.com/yangky11/lean4-example/blob/LeanCopilot-demo) is an example of a Lean package depending on Lean Copilot. If you have problems building the project, our [Dockerfile](./Dockerfile), [build.sh](scripts/build.sh), or [build_example.sh](scripts/build_example.sh) may be helpful.
87 | 
88 | ### Getting Started with Lean Copilot
89 | 
90 | #### Tactic Suggestion
91 | 
92 | After `import LeanCopilot`, you can use the tactic `suggest_tactics` to generate tactic suggestions. You can click on any of the suggested tactics to use it in the proof.
93 | 
94 | *(demo image: `suggest_tactics`)*
95 | 
96 | You can provide a prefix (e.g., `simp`) to constrain the generated tactics:
97 | 
98 | *(demo image: `suggest_tactics` with a `simp` prefix)*
99 | 
100 | #### Proof Search
101 | 
102 | The tactic `search_proof` combines LLM-generated tactics with [aesop](https://github.com/leanprover-community/aesop) to search for multi-tactic proofs. When a proof is found, you can click on it to insert it into the editor.
103 | 
104 | *(demo image: `search_proof`)*
105 | 
106 | #### Premise Selection
107 | 
108 | The `select_premises` tactic retrieves a list of potentially useful premises. Currently, it uses the retriever in [LeanDojo](https://leandojo.org/) to select premises from a fixed snapshot of Lean and [mathlib4](https://github.com/leanprover-community/mathlib4/tree/3ce43c18f614b76e161f911b75a3e1ef641620ff).
109 | 
110 | ![select_premises](https://github.com/lean-dojo/LeanCopilot/assets/114432581/2817663c-ba98-4a47-9ae9-5b8680b6265a)
111 | 
112 | #### Running LLMs
113 | 
114 | You can also run the inference of any LLMs in Lean, which can be used to build customized proof automation or other LLM-based applications (not limited to theorem proving). It's possible to run arbitrary models either locally or remotely (see [Bring Your Own Model](#bring-your-own-model)).
115 | 
116 | *(demo image: running LLMs in Lean)*
117 | 
118 | ## Advanced Usage
119 | 
120 | **This section is only for advanced users who would like to change the default behavior of `suggest_tactics`, `search_proof`, or `select_premises`, e.g., to use different models or hyperparameters.**
121 | 
122 | ### Tactic APIs
123 | 
124 | * Examples in [TacticSuggestion.lean](LeanCopilotTests/TacticSuggestion.lean) showcase how to configure `suggest_tactics`, e.g., to use different models or generate different numbers of tactics.
125 | * Examples in [ProofSearch.lean](LeanCopilotTests/ProofSearch.lean) showcase how to configure `search_proof` using options provided by [aesop](https://github.com/leanprover-community/aesop).
126 | * Examples in [PremiseSelection.lean](LeanCopilotTests/PremiseSelection.lean) showcase how to set the number of retrieved premises for `select_premises`.
127 | 
128 | ### Model APIs
129 | 
130 | **Examples in [ModelAPIs.lean](LeanCopilotTests/ModelAPIs.lean) showcase how to run the inference of different models and configure their parameters (temperature, beam size, etc.).**
131 | 
132 | Lean Copilot supports two kinds of models: generators and encoders.
Generators must implement the `TextToText` interface:
133 | 
134 | ```lean
135 | class TextToText (τ : Type) where
136 |   generate (model : τ) (input : String) (targetPrefix : String) : IO $ Array (String × Float)
137 | ```
138 | 
139 | * `input` is the input string.
140 | * `targetPrefix` is used to constrain the generator's output. `""` means no constraint.
141 | * `generate` should return an array of `String × Float`. Each `String` is an output from the model, and `Float` is the corresponding score.
142 | 
143 | We provide three types of generators:
144 | 
145 | * [`NativeGenerator`](LeanCopilot/Models/Native.lean) runs locally powered by [CTranslate2](https://github.com/OpenNMT/CTranslate2) and is linked to Lean using Foreign Function Interface (FFI).
146 | * [`ExternalGenerator`](LeanCopilot/Models/External.lean) is hosted either locally or remotely. See [Bring Your Own Model](#bring-your-own-model) for details.
147 | * [`GenericGenerator`](LeanCopilot/Models/Generic.lean) can be anything that implements the `generate` function in the `TextToText` typeclass.
148 | 
149 | Encoders must implement `TextToVec`:
150 | 
151 | ```lean
152 | class TextToVec (τ : Type) where
153 |   encode : τ → String → IO FloatArray
154 | ```
155 | 
156 | * `input` is the input string.
157 | * `encode` should return a vector embedding produced by the model.
158 | 
159 | Similar to generators, we have `NativeEncoder`, `ExternalEncoder`, and `GenericEncoder`.
160 | 
161 | ### Bring Your Own Model
162 | 
163 | In principle, it is possible to run any model using Lean Copilot through `ExternalGenerator` or `ExternalEncoder` (examples in [ModelAPIs.lean](LeanCopilotTests/ModelAPIs.lean)). To use a model, you need to wrap it properly to expose the APIs in [external_model_api.yaml](./external_model_api.yaml). As an example, we provide a [Python API server](./python) and use it to run a few models.
164 | 
165 | ## Caveats
166 | 
167 | * `select_premises` always retrieves the original form of a premise. For example, `Nat.add_left_comm` is a result of the theorem below. In this case, `select_premises` retrieves `Nat.mul_left_comm` instead of `Nat.add_left_comm`.
168 | 
169 | ```lean
170 | @[to_additive]
171 | theorem mul_left_comm : ∀ a b c : G, a * (b * c) = b * (a * c)
172 | ```
173 | 
174 | * In some cases, `search_proof` produces an erroneous proof with error messages like `fail to show termination for ...`. A temporary workaround is changing the theorem's name before applying `search_proof`. You can change it back after `search_proof` completes.
175 | 
176 | ## Getting in Touch
177 | 
178 | * For general questions and discussions, please use [GitHub Discussions](https://github.com/lean-dojo/LeanCopilot/discussions).
179 | * To report a potential bug, please open an issue. In the issue, please include your OS information, the exact steps to reproduce the error on **the latest stable version of Lean Copilot**, and complete logs, preferably in debug mode. **Important: If your issue cannot be reproduced easily, it is unlikely to receive help.**
180 | * Feature requests and contributions are warmly welcomed. Please feel free to start a [discussion](https://github.com/lean-dojo/LeanCopilot/discussions) or open a [pull request](https://github.com/lean-dojo/LeanCopilot/pulls).
181 | 
182 | ## Acknowledgements
183 | 
184 | * We thank Scott Morrison for suggestions on simplifying Lean Copilot's installation and Mac Malone for helping implement it. Both Scott and Mac work for the [Lean FRO](https://lean-fro.org/).
185 | * We thank Jannis Limperg for supporting our LLM-generated tactics in Aesop.
186 | 
187 | ## Citation
188 | 
189 | If you find our work useful, please consider citing [our paper](https://arxiv.org/abs/2404.12534):
190 | 
191 | ```BibTeX
192 | @article{song2024lean,
193 |   title={Lean copilot: Large language models as copilots for theorem proving in lean},
194 |   author={Song, Peiyang and Yang, Kaiyu and Anandkumar, Anima},
195 |   journal={arXiv preprint arXiv:2404.12534},
196 |   year={2024}
197 | }
198 | ```
199 | 
--------------------------------------------------------------------------------
/cpp/ct2.cpp:
--------------------------------------------------------------------------------
1 | #include <algorithm>  // NOTE: the bracketed header names on lines 1-14 were lost in extraction; this list is a best-effort reconstruction.
2 | #include <cassert>
3 | #include <cmath>
4 | #include <filesystem>
5 | #include <fstream>
6 | #include <memory>
7 | 
8 | #include <ctranslate2/devices.h>
9 | #include <ctranslate2/encoder.h>
10 | #include <ctranslate2/ops/ops.h>
11 | #include <ctranslate2/storage_view.h>
12 | #include <ctranslate2/translation.h>
13 | #include <ctranslate2/translator.h>
14 | #include <lean/lean.h>
15 | 
16 | #include "json.hpp"
17 | #include "npy.hpp"
18 | 
19 | using json = nlohmann::json;
20 | 
21 | std::map<std::string, std::unique_ptr<ctranslate2::Translator>> generators;
22 | std::map<std::string, std::unique_ptr<ctranslate2::Encoder>> encoders;
23 | 
24 | ctranslate2::StorageView *p_premise_embeddings = nullptr;
25 | json *p_premise_dictionary = nullptr;
26 | 
27 | // ifstream does not support directories on Windows
28 | inline bool exists(const std::string &path) {
29 |   return std::filesystem::exists(path);
30 | }
31 | 
32 | inline lean_obj_res lean_mk_pair(lean_obj_arg a, lean_obj_arg b) {
33 |   lean_object *r = lean_alloc_ctor(0, 2, 0);
34 |   lean_ctor_set(r, 0, a);
35 |   lean_ctor_set(r, 1, b);
36 |   return r;
37 | }
38 | 
39 | extern "C" uint8_t cuda_available(b_lean_obj_arg) {
40 |   return ctranslate2::str_to_device("auto") == ctranslate2::Device::CUDA;
41 | }
42 | 
43 | template <typename T>
44 | bool is_initialized_aux(const std::string &name);
45 | 
46 | template <>
47 | bool is_initialized_aux<ctranslate2::Translator>(const std::string &name) {
48 |   return generators.find(name) != generators.end();
49 | }
50 | 
51 | template <>
52 | bool is_initialized_aux<ctranslate2::Encoder>(const std::string &name) {
53 |   return encoders.find(name) != encoders.end();
54 | }
55 | 
56 | extern "C" uint8_t is_generator_initialized(b_lean_obj_arg _name) {
57 |   std::string name = std::string(lean_string_cstr(_name));
58 |   return is_initialized_aux<ctranslate2::Translator>(name);
59 | }
60 | 
61 | extern "C" uint8_t is_encoder_initialized(b_lean_obj_arg _name) {
62 |   std::string name = std::string(lean_string_cstr(_name));
63 |   return is_initialized_aux<ctranslate2::Encoder>(name);
64 | }
65 | 
66 | template <typename T>
67 | bool init_model(b_lean_obj_arg _name,          // String
68 |                 b_lean_obj_arg _model_path,    // String
69 |                 b_lean_obj_arg _compute_type,  // String
70 |                 b_lean_obj_arg _device,        // String
71 |                 b_lean_obj_arg _device_index,  // Array UInt64
72 |                 std::map<std::string, std::unique_ptr<T>> &models) {
73 |   std::string name = std::string(lean_string_cstr(_name));
74 |   if (is_initialized_aux<T>(name)) {
75 |     throw std::runtime_error(name + " already exists.");
76 |   }
77 | 
78 |   std::string model_path = std::string(lean_string_cstr(_model_path));
79 |   if (!exists(model_path)) {  // Cannot find the model.
80 |     return false;
81 |   }
82 | 
83 |   ctranslate2::Device device =
84 |       ctranslate2::str_to_device(lean_string_cstr(_device));
85 |   ctranslate2::ComputeType compute_type =
86 |       ctranslate2::str_to_compute_type(lean_string_cstr(_compute_type));
87 | 
88 |   std::vector<int> device_indices;
89 |   const lean_array_object *p_arr = lean_to_array(_device_index);
90 |   for (int i = 0; i < p_arr->m_size; i++) {
91 |     device_indices.push_back(lean_unbox_uint64(p_arr->m_data[i]));
92 |   }
93 | 
94 |   auto p_model =
95 |       std::make_unique<T>(model_path, device, compute_type, device_indices);
96 |   models.emplace(name, std::move(p_model));
97 |   return true;
98 | }
99 | 
100 | extern "C" uint8_t init_generator(
101 |     b_lean_obj_arg _name,          // String
102 |     b_lean_obj_arg _model_path,    // String
103 |     b_lean_obj_arg _compute_type,  // String
104 |     b_lean_obj_arg _device,        // String
105 |     b_lean_obj_arg _device_index) {  // Array UInt64
106 |   return init_model(_name, _model_path, _compute_type, _device, _device_index,
107 |                     generators);
108 | }
109 | 
110 | extern "C" uint8_t init_encoder(b_lean_obj_arg _name,          // String
111 |                                 b_lean_obj_arg _model_path,    // String
112 |                                 b_lean_obj_arg _compute_type,  // String
113 |                                 b_lean_obj_arg _device,        // String
114 |                                 b_lean_obj_arg _device_index) {  // Array UInt64
115 |   return init_model(_name, _model_path, _compute_type, _device, _device_index,
116 |                     encoders);
117 | }
118 | 
119 | inline std::vector<std::string> convert_tokens(b_lean_obj_arg _tokens) {
120 |   std::vector<std::string> tokens;
121 |   const lean_array_object *p_arr = lean_to_array(_tokens);
122 |   for (int i = 0; i < p_arr->m_size; i++) {
123 |     tokens.emplace_back(lean_string_cstr(p_arr->m_data[i]));
124 |   }
125 |   return tokens;
126 | }
127 | 
128 | extern "C" lean_obj_res generate(
129 |     b_lean_obj_arg _name,                  // String
130 |     b_lean_obj_arg _input_tokens,          // Array String
131 |     b_lean_obj_arg _target_prefix_tokens,  // Array String
132 |     uint64_t num_return_sequences,         // UInt64
133 |     uint64_t beam_size,                    // UInt64
134 |     uint64_t min_length,                   // UInt64
135 |     uint64_t max_length,                   // UInt64
136 |     double length_penalty,                 // Float
137 |     double patience,                       // Float
138 |     double temperature) {                  // Float
139 |   // Check the arguments.
140 |   std::string name = std::string(lean_string_cstr(_name));
141 |   if (!is_initialized_aux<ctranslate2::Translator>(name)) {
142 |     throw std::runtime_error(name + " hasn't been initialized.");
143 |   }
144 |   if (num_return_sequences <= 0) {
145 |     throw std::invalid_argument("num_return_sequences must be positive.");
146 |   }
147 |   if (beam_size <= 0) {
148 |     throw std::invalid_argument("beam_size must be positive.");
149 |   }
150 |   if (min_length < 0 || max_length < 0 || min_length > max_length) {
151 |     throw std::invalid_argument("Invalid min_length or max_length.");
152 |   }
153 |   if (patience < 1.0) {
154 |     throw std::invalid_argument("patience must be at least 1.0.");
155 |   }
156 |   if (temperature <= 0) {
157 |     throw std::invalid_argument("temperature must be positive.");
158 |   }
159 | 
160 |   // Set beam search's hyperparameters.
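  // Most of the options below mirror the FFI arguments; sampling is left at
  // its neutral settings (sampling_topk = 0, sampling_topp = 1.0), so decoding
  // is plain beam search, and return_scores stays on because each tactic is
  // handed back to Lean with a score attached.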
161 |   ctranslate2::TranslationOptions opts;
162 |   opts.num_hypotheses = num_return_sequences;
163 |   opts.beam_size = beam_size;
164 |   opts.patience = patience;
165 |   opts.length_penalty = length_penalty;
166 |   opts.min_decoding_length = min_length;
167 |   opts.max_decoding_length = max_length;
168 |   opts.sampling_temperature = temperature;
169 |   opts.sampling_topk = 0;
170 |   opts.sampling_topp = 1.0;
171 |   opts.max_input_length = 0;
172 |   opts.use_vmap = true;
173 |   opts.disable_unk = true;
174 |   opts.return_scores = true;
175 | 
176 |   // Get the input tokens ready.
177 |   std::vector<std::string> input_tokens = convert_tokens(_input_tokens);
178 |   std::vector<std::string> target_prefix_tokens =
179 |       convert_tokens(_target_prefix_tokens);
180 | 
181 |   // Generate tactics with beam search.
182 |   ctranslate2::TranslationResult results = generators.at(name)->translate_batch(
183 |       {input_tokens}, {target_prefix_tokens}, opts)[0];
184 |   assert(results.hypotheses.size() == num_return_sequences &&
185 |          results.scores.size() == num_return_sequences);
186 | 
187 |   // Return the output.
188 |   lean_object *output = lean_mk_empty_array();
189 | 
190 |   for (int i = 0; i < num_return_sequences; i++) {
191 |     int l = results.hypotheses[i].size();
192 | 
193 |     lean_object *tokens = lean_mk_empty_array();
194 |     for (int j = 0; j < l; j++) {
195 |       tokens = lean_array_push(
196 |           tokens, lean_mk_string(results.hypotheses[i][j].c_str()));
197 |     }
198 |     double score = std::exp(results.scores[i]);
199 |     assert(0.0 <= score && score <= 1.0);
200 |     output =
201 |         lean_array_push(output, lean_mk_pair(tokens, lean_box_float(score)));
202 |   }
203 | 
204 |   return output;
205 | }
206 | 
207 | extern "C" lean_obj_res encode(b_lean_obj_arg _name,            // String
208 |                                b_lean_obj_arg _input_tokens) {  // Array String
209 |   std::string name = std::string(lean_string_cstr(_name));
210 |   if (!is_initialized_aux<ctranslate2::Encoder>(name)) {
211 |     throw std::runtime_error(name + " hasn't been initialized.");
212 |   }
213 | 
214 |   std::vector<std::string> input_tokens = convert_tokens(_input_tokens);
215 |   ctranslate2::EncoderForwardOutput results =
216 |       encoders.at(name)->forward_batch_async({input_tokens}).get();
217 |   ctranslate2::StorageView hidden_state = results.last_hidden_state;
218 | 
219 |   assert(hidden_state.dim(0) == 1);
220 |   int l = hidden_state.dim(1);
221 |   int d = hidden_state.dim(2);
222 |   lean_object *arr = lean_mk_empty_float_array(lean_box(d));
223 | 
224 |   for (ctranslate2::dim_t i = 0; i < d; i++) {
225 |     double sum = 0.0;
226 |     for (ctranslate2::dim_t j = 0; j < l; j++) {
227 |       sum += hidden_state.scalar_at<float>({0, j, i});
228 |     }
229 |     lean_float_array_push(arr, sum / l);
230 |   }
231 | 
232 |   return arr;
233 | }
234 | 
235 | extern "C" uint8_t init_premise_embeddings(b_lean_obj_arg _path,      // String
236 |                                            b_lean_obj_arg _device) {  // String
237 |   std::string path = std::string(lean_string_cstr(_path));
238 |   if (!exists(path)) {
239 |     return false;
240 |   }
241 |   if (p_premise_embeddings != nullptr) {
242 |     delete p_premise_embeddings;
243 |   }
244 | 
245 |   // ctranslate2::Device device =
246 |   //     ctranslate2::str_to_device(lean_string_cstr(_device));
247 |   // TODO: We should remove this line when everything can work well on CUDA.
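  // For now the premise embeddings are therefore always placed on the CPU,
  // ignoring the _device argument, so the retrieval matmul and top-k in
  // retrieve() below also run on the CPU.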
248 |   ctranslate2::Device device = ctranslate2::Device::CPU;
249 | 
250 |   const auto &d = npy::read_npy<double>(path);
251 |   std::vector<double> data = d.data;
252 |   std::vector<unsigned long> shape = d.shape;
253 |   bool fortran_order = d.fortran_order;
254 | 
255 |   std::vector<float> data_f;
256 |   data_f.resize(data.size());
257 |   std::transform(data.begin(), data.end(), data_f.begin(),
258 |                  [](double d) { return static_cast<float>(d); });
259 | 
260 |   std::vector<int64_t> shape_i64;
261 |   shape_i64.resize(shape.size());
262 |   std::transform(shape.begin(), shape.end(), shape_i64.begin(),
263 |                  [](unsigned long ul) { return static_cast<int64_t>(ul); });
264 | 
265 |   p_premise_embeddings =
266 |       new ctranslate2::StorageView(shape_i64, data_f, device);
267 |   return true;
268 | }
269 | 
270 | inline bool premise_embeddings_initialized_aux() {
271 |   return p_premise_embeddings != nullptr;
272 | }
273 | 
274 | extern "C" uint8_t premise_embeddings_initialized(lean_object *) {
275 |   return premise_embeddings_initialized_aux();
276 | }
277 | 
278 | extern "C" uint8_t init_premise_dictionary(b_lean_obj_arg _path) {
279 |   std::string path = std::string(lean_string_cstr(_path));
280 |   if (!exists(path)) {
281 |     return false;
282 |   }
283 |   if (p_premise_dictionary != nullptr) {
284 |     delete p_premise_dictionary;
285 |   }
286 | 
287 |   std::ifstream f(path);
288 |   p_premise_dictionary = new json(json::parse(f));
289 | 
290 |   return true;
291 | }
292 | 
293 | inline bool premise_dictionary_initialized_aux() {
294 |   return p_premise_dictionary != nullptr;
295 | }
296 | 
297 | extern "C" uint8_t premise_dictionary_initialized(lean_object *) {
298 |   return premise_dictionary_initialized_aux();
299 | }
300 | 
301 | extern "C" lean_obj_res retrieve(b_lean_obj_arg _query_emb,  // FloatArray
302 |                                  uint64_t _k) {
303 |   // lean_object *arr
304 |   // assert(p_premise_embeddings && static_cast<int64_t>(p_arr->m_size) ==
305 |   //        p_premise_embeddings->dim(1));
306 | 
307 |   int64_t d = lean_unbox(lean_float_array_size(_query_emb));
308 |   std::vector<float> query_emb_data;
309 |   for (int i = 0; i < d; i++) {
310 |     query_emb_data.push_back(lean_float_array_uget(_query_emb, i));
311 |   }
312 | 
313 |   ctranslate2::Device device = p_premise_embeddings->device();
314 |   ctranslate2::StorageView query_emb =
315 |       ctranslate2::StorageView({d, 1}, query_emb_data, device);
316 | 
317 |   ctranslate2::ops::MatMul matmul(false, false, 1.0);
318 |   long int k = static_cast<long int>(_k);
319 |   ctranslate2::ops::TopK topk(k, -1);
320 | 
321 |   int num_premises = p_premise_embeddings->dim(0);
322 |   std::vector<ctranslate2::dim_t> probs_shape{num_premises, 1};
323 | 
324 |   ctranslate2::StorageView probs = ctranslate2::StorageView(
325 |       probs_shape, ctranslate2::DataType::FLOAT32, device);
326 |   matmul(*p_premise_embeddings, query_emb, probs);
327 |   probs.resize({num_premises});
328 | 
329 |   ctranslate2::StorageView topk_values =
330 |       ctranslate2::StorageView({k}, ctranslate2::DataType::FLOAT32, device);
331 |   ctranslate2::StorageView topk_indices =
332 |       ctranslate2::StorageView({k}, ctranslate2::DataType::INT32, device);
333 |   topk(probs, topk_values, topk_indices);
334 | 
335 |   lean_object *output = lean_mk_empty_array();
336 |   const int *p_topk_indices = topk_indices.data<int32_t>();
337 |   const float *p_topk_values = topk_values.data<float>();
338 | 
339 |   for (int i = 0; i < k; i++) {
340 |     int idx = p_topk_indices[i];
341 |     assert(0 <= idx && idx < num_premises);
342 |     // [NOTE]: This is where the server crash occurs on CUDA.
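    // Each top-k index is looked up in the JSON premise dictionary to recover
    // the premise's full name, file path, and source code.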
301 | extern "C" lean_obj_res retrieve(b_lean_obj_arg _query_emb,
302 |                                  uint64_t _k) {  // FloatArray
303 |   // lean_object *arr
304 |   // assert(p_premise_embeddings && static_cast<int64_t>(p_arr->m_size) ==
305 |   //        p_premise_embeddings->dim(1));
306 | 
307 |   int64_t d = lean_unbox(lean_float_array_size(_query_emb));
308 |   std::vector<float> query_emb_data;
309 |   for (int i = 0; i < d; i++) {
310 |     query_emb_data.push_back(lean_float_array_uget(_query_emb, i));
311 |   }
312 | 
313 |   ctranslate2::Device device = p_premise_embeddings->device();
314 |   ctranslate2::StorageView query_emb =
315 |       ctranslate2::StorageView({d, 1}, query_emb_data, device);
316 | 
317 |   ctranslate2::ops::MatMul matmul(false, false, 1.0);
318 |   long int k = static_cast<long int>(_k);
319 |   ctranslate2::ops::TopK topk(k, -1);
320 | 
321 |   int num_premises = p_premise_embeddings->dim(0);
322 |   std::vector<int64_t> probs_shape{num_premises, 1};
323 | 
324 |   ctranslate2::StorageView probs = ctranslate2::StorageView(
325 |       probs_shape, ctranslate2::DataType::FLOAT32, device);
326 |   matmul(*p_premise_embeddings, query_emb, probs);
327 |   probs.resize({num_premises});
328 | 
329 |   ctranslate2::StorageView topk_values =
330 |       ctranslate2::StorageView({k}, ctranslate2::DataType::FLOAT32, device);
331 |   ctranslate2::StorageView topk_indices =
332 |       ctranslate2::StorageView({k}, ctranslate2::DataType::INT32, device);
333 |   topk(probs, topk_values, topk_indices);
334 | 
335 |   lean_object *output = lean_mk_empty_array();
336 |   const int *p_topk_indices = topk_indices.data<int>();
337 |   const float *p_topk_values = topk_values.data<float>();
338 | 
339 |   for (int i = 0; i < k; i++) {
340 |     int idx = p_topk_indices[i];
341 |     assert(0 <= idx && idx < num_premises);
342 |     // [NOTE]: This is where the server crash occurs on CUDA.
343 |     const std::string this_premise =
344 |         (*p_premise_dictionary)[std::to_string(idx)]["full_name"];
345 |     const std::string this_path =
346 |         (*p_premise_dictionary)[std::to_string(idx)]["path"];
347 |     const std::string this_code =
348 |         (*p_premise_dictionary)[std::to_string(idx)]["code"];
349 | 
350 |     output = lean_array_push(
351 |         output,
352 |         lean_mk_pair(
353 |             lean_mk_string(this_premise.c_str()),
354 |             lean_mk_pair(lean_mk_string(this_path.c_str()),
355 |                          lean_mk_pair(lean_mk_string(this_code.c_str()),
356 |                                       lean_box_float(p_topk_values[i])))));
357 |   }
358 | 
359 |   return output;
360 | }
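// For intuition: `retrieve` above scores every premise by the dot product
// between its embedding and the query embedding (an (n x d) by (d x 1)
// matmul), then keeps the k highest-scoring premises. A minimal CPU sketch of
// the same scoring, assuming a row-major embedding matrix of size n * d
// (illustrative only, not part of the build):
//
//   std::vector<std::pair<float, int>> score_premises(
//       const std::vector<float> &E, const std::vector<float> &q, int n,
//       int d, int k) {
//     std::vector<std::pair<float, int>> s(n);
//     for (int i = 0; i < n; i++) {
//       float dot = 0.0f;
//       for (int j = 0; j < d; j++) dot += E[i * d + j] * q[j];
//       s[i] = {dot, i};
//     }
//     k = std::min(k, n);
//     std::partial_sort(s.begin(), s.begin() + k, s.end(), std::greater<>());
//     s.resize(k);
//     return s;  // (score, premise index), best first
//   }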
361 | 
--------------------------------------------------------------------------------
/lakefile.lean:
--------------------------------------------------------------------------------
1 | import Lake
2 | 
3 | open Lake DSL System Lean Elab
4 | 
5 | set_option autoImplicit false
6 | 
7 | 
8 | inductive SupportedOS where
9 |   | linux
10 |   | macos
11 |   | windows
12 | deriving Inhabited, BEq
13 | 
14 | 
15 | def getOS! : SupportedOS :=
16 |   if Platform.isWindows then
17 |     .windows
18 |   else if Platform.isOSX then
19 |     .macos
20 |   else
21 |     .linux
22 | 
23 | 
24 | inductive SupportedArch where
25 |   | x86_64
26 |   | arm64
27 | deriving Inhabited, BEq
28 | 
29 | 
30 | def nproc : IO Nat := do
31 |   let cmd := if getOS! == .windows then "cmd" else "nproc"
32 |   let args := if getOS! == .windows then #["/c echo %NUMBER_OF_PROCESSORS%"] else #[]
33 |   let out ← IO.Process.output {cmd := cmd, args := args, stdin := .null}
34 |   return out.stdout.trim.toNat!
35 | 
36 | 
37 | def getArch? : IO (Option SupportedArch) := do
38 |   let cmd := if getOS! == .windows then "cmd" else "uname"
39 |   let args := if getOS! == .windows then #["/c echo %PROCESSOR_ARCHITECTURE%\n"] else #["-m"]
40 | 
41 |   let out ← IO.Process.output {cmd := cmd, args := args, stdin := .null}
42 |   let arch := out.stdout.trim
43 | 
44 |   if arch ∈ ["arm64", "aarch64", "ARM64"] then
45 |     return some .arm64
46 |   else if arch ∈ ["x86_64", "AMD64"] then
47 |     return some .x86_64
48 |   else
49 |     return none
50 | 
51 | 
52 | def getArch! : IO SupportedArch := do
53 |   if let some arch ← getArch? then
54 |     return arch
55 |   else
56 |     error "Unknown architecture"
57 | 
58 | 
59 | def isArm! : IO Bool := do
60 |   return (← getArch!) == .arm64
61 | 
62 | 
63 | def hasCUDA : IO Bool := do
64 |   if getOS! == .windows then
65 |     let ok ← testProc {
66 |       cmd := "nvidia-smi"
67 |       args := #[]
68 |     }
69 |     return ok
70 |   else
71 |     let out ← IO.Process.output {cmd := "which", args := #["nvcc"], stdin := .null}
72 |     return out.exitCode == 0
73 | 
74 | def useCUDA : IO Bool := do
75 |   return (get_config? noCUDA |>.isNone) ∧ (← hasCUDA)
76 | 
77 | 
78 | def buildArchiveName : String :=
79 |   let arch := if run_io isArm! then "arm64" else "x86_64"
80 |   let os := if getOS! == .macos then "macOS" else "linux"
81 |   if run_io useCUDA then
82 |     s!"{arch}-cuda-{os}.tar.gz"
83 |   else
84 |     s!"{arch}-{os}.tar.gz"
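-- For reference: example values of `buildArchiveName` (derived from the
-- definition above) are "x86_64-linux.tar.gz" on a plain x86-64 Linux host,
-- "arm64-macOS.tar.gz" on Apple silicon, and "x86_64-cuda-linux.tar.gz" when
-- CUDA is detected and the `noCUDA` config is unset.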
85 | 
86 | 
87 | structure SupportedPlatform where
88 |   os : SupportedOS
89 |   arch : SupportedArch
90 | 
91 | 
92 | def getPlatform! : IO SupportedPlatform := do
93 |   if Platform.numBits != 64 then
94 |     error "Only 64-bit platforms are supported"
95 |   return ⟨getOS!, ← getArch!⟩
96 | 
97 | def copySingleFile (src dst : FilePath) : LogIO Unit := do
98 |   let cmd := if getOS! == .windows then "cmd" else "cp"
99 |   let args :=
100 |     if getOS! == .windows then
101 |       #[s!"/c copy {src.toString.replace "/" "\\"} {dst.toString.replace "/" "\\"}"]
102 |     else
103 |       #[src.toString, dst.toString]
104 | 
105 |   proc {
106 |     cmd := cmd
107 |     args := args
108 |   }
109 | 
110 | def copyFolder (src dst : FilePath) : LogIO Unit := do
111 |   let cmd := if getOS! == .windows then "robocopy" else "cp"
112 |   let args :=
113 |     if getOS! == .windows then
114 |       #[src.toString, dst.toString, "/E"]
115 |     else
116 |       #["-r", src.toString, dst.toString]
117 | 
118 |   let _out ← rawProc {
119 |     cmd := cmd
120 |     args := args
121 |   }
122 | 
123 | def removeFolder (dir : FilePath) : LogIO Unit := do
124 |   let cmd := if getOS! == .windows then "cmd" else "rm"
125 |   let args :=
126 |     if getOS! == .windows then
127 |       #[s!"/c rmdir /s /q {dir.toString.replace "/" "\\"}"]
128 |     else
129 |       #["-rf", dir.toString]
130 | 
131 |   proc {
132 |     cmd := cmd
133 |     args := args
134 |   }
135 | 
136 | def removeFile (src : FilePath) : LogIO Unit := do
137 |   proc {
138 |     cmd := if getOS! == .windows then "cmd" else "rm"
139 |     args := if getOS! == .windows then #[s!"/c del {src.toString.replace "/" "\\"}"] else #[src.toString]
140 |   }
141 | 
142 | package LeanCopilot where
143 |   preferReleaseBuild := get_config? noCloudRelease |>.isNone
144 |   buildArchive? := buildArchiveName
145 |   precompileModules := true
146 |   buildType := BuildType.release
147 |   moreLinkArgs := #[s!"-L{__dir__}/.lake/build/lib", "-l" ++ if getOS! == .windows then "libctranslate2" else "ctranslate2"]
148 |   weakLeanArgs := #[s!"--load-dynlib={__dir__}/.lake/build/lib/" ++ nameToSharedLib (if getOS! == .windows then "libctranslate2" else "ctranslate2")]
149 | 
150 | 
151 | @[default_target]
152 | lean_lib LeanCopilot {
153 | }
154 | 
155 | 
156 | lean_lib ModelCheckpointManager {
157 | }
158 | 
159 | 
160 | lean_exe download {
161 |   root := `ModelCheckpointManager.Main
162 | }
163 | 
164 | 
165 | lean_lib LeanCopilotTests {
166 |   globs := #[.submodules "LeanCopilotTests".toName]
167 | }
168 | 
169 | 
170 | private def nameToVersionedSharedLib (name : String) (v : String) : String :=
171 |   if Platform.isWindows then s!"lib{name}.{v}.dll"
172 |   else if Platform.isOSX then s!"lib{name}.{v}.dylib"
173 |   else s!"lib{name}.so.{v}"
174 | 
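-- For reference: `nameToVersionedSharedLib "ctranslate2" "4"` yields
-- "libctranslate2.4.dll" on Windows, "libctranslate2.4.dylib" on macOS, and
-- "libctranslate2.so.4" on Linux.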
175 | 
176 | def afterReleaseSync {α : Type} (pkg : Package) (build : SpawnM (Job α)) : FetchM (Job α) := do
177 |   if pkg.preferReleaseBuild ∧ pkg.name ≠ (← getRootPackage).name then
178 |     (← pkg.optGitHubRelease.fetch).bindM fun _ => build
179 |   else
180 |     build
181 | 
182 | 
183 | def afterReleaseAsync {α : Type} (pkg : Package) (build : JobM α) : FetchM (Job α) := do
184 |   if pkg.preferReleaseBuild ∧ pkg.name ≠ (← getRootPackage).name then
185 |     (← pkg.optGitHubRelease.fetch).mapM fun _ => build
186 |   else
187 |     Job.async build
188 | 
189 | 
190 | def ensureDirExists (dir : FilePath) : IO Unit := do
191 |   if !(← dir.pathExists) then
192 |     IO.FS.createDirAll dir
193 | 
194 | 
195 | def gitClone (url : String) (cwd : Option FilePath) : LogIO Unit := do
196 |   proc (quiet := true) {
197 |     cmd := "git"
198 |     args := if getOS! == .windows then #["clone", url] else #["clone", "--recursive", url]
199 |     cwd := cwd
200 |   }
201 | 
202 | 
203 | def runCmake (root : FilePath) (flags : Array String) : LogIO Unit := do
204 |   assert! (← root.pathExists) ∧ (← (root / "CMakeLists.txt").pathExists)
205 |   let buildDir := root / "build"
206 |   if ← buildDir.pathExists then
207 |     IO.FS.removeDirAll buildDir
208 |   IO.FS.createDirAll buildDir
209 |   let ok ← testProc {
210 |     cmd := "cmake"
211 |     args := flags ++ #[".."]
212 |     cwd := buildDir
213 |   }
214 |   if ¬ ok then
215 |     if flags.contains "-DWITH_CUDNN=ON" then  -- Some users may have CUDA but not cuDNN.
216 |       let ok' ← testProc {
217 |         cmd := "cmake"
218 |         args := (flags.erase "-DWITH_CUDNN=ON" |>.push "-DWITH_CUDNN=OFF") ++ #[".."]
219 |         cwd := buildDir
220 |       }
221 |       if ok' then
222 |         return ()
223 |     error "Failed to run cmake"
224 | 
225 | 
226 | target libopenblas pkg : FilePath := do
227 |   afterReleaseAsync pkg do
228 |     let rootDir := pkg.buildDir / "OpenBLAS"
229 |     ensureDirExists rootDir
230 |     let dst := pkg.sharedLibDir / (nameToSharedLib (if getOS! == .windows then "libopenblas" else "openblas"))
231 |     createParentDirs dst
232 |     let url := "https://github.com/OpenMathLib/OpenBLAS"
233 | 
234 |     let depTrace := Hash.ofString url
235 |     setTrace depTrace
236 |     buildFileUnlessUpToDate' dst do
237 |       if getOS! == .windows then
238 |         -- For Windows, the binary for OpenBLAS is provided.
239 |         let _out ← rawProc {
240 |           cmd := "curl"
241 |           args := #["-L", "-o", "OpenBLAS.zip", "https://sourceforge.net/projects/openblas/files/v0.3.29/OpenBLAS-0.3.29_x64.zip/download"]
242 |           cwd := pkg.buildDir
243 |         }
244 |         proc {
245 |           cmd := "tar"
246 |           args := #["-xvf", "OpenBLAS.zip"]
247 |           cwd := pkg.buildDir
248 |         }
249 |         copySingleFile (pkg.buildDir / "bin" / "libopenblas.dll") (pkg.buildDir / "lib" / "libopenblas.dll")
250 |       else
251 |         logInfo s!"Cloning OpenBLAS from {url}"
252 |         gitClone url pkg.buildDir
253 | 
254 |         let numThreads := max 4 $ min 32 (← nproc)
255 |         let flags := #["NO_LAPACK=1", "NO_FORTRAN=1", s!"-j{numThreads}"]
256 |         logInfo s!"Building OpenBLAS with `make{flags.foldl (· ++ " " ++ ·) ""}`"
257 |         proc (quiet := true) {
258 |           cmd := "make"
259 |           args := flags
260 |           cwd := rootDir
261 |         }
262 |         copySingleFile (rootDir / nameToSharedLib "openblas") dst
263 |         -- TODO: Don't hardcode the version "0".
264 |         let dst' := pkg.sharedLibDir / (nameToVersionedSharedLib "openblas" "0")
265 |         copySingleFile dst dst'
266 |     let _ := (← getTrace)
267 |     return dst
268 | 
269 | 
270 | def getCt2CmakeFlags : IO (Array String) := do
271 |   let mut flags := #["-DOPENMP_RUNTIME=NONE", "-DWITH_MKL=OFF"]
272 | 
273 |   match getOS! with
274 |   | .macos => flags := flags ++ #["-DWITH_ACCELERATE=ON", "-DWITH_OPENBLAS=OFF"]
275 |   | .linux => flags := flags ++ #["-DWITH_ACCELERATE=OFF", "-DWITH_OPENBLAS=ON", "-DOPENBLAS_INCLUDE_DIR=../../OpenBLAS", "-DOPENBLAS_LIBRARY=../../OpenBLAS/libopenblas.so"]
276 |   | .windows => flags := flags
277 | 
278 |   -- [TODO] Temporary fix: Do not use CUDA even if it is available.
279 |   -- if ← useCUDA then
280 |   --   flags := flags ++ #["-DWITH_CUDA=ON", "-DWITH_CUDNN=ON"]
281 |   -- else
282 |   --   flags := flags ++ #["-DWITH_CUDA=OFF", "-DWITH_CUDNN=OFF"]
283 | 
284 |   return flags
285 | 
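-- For reference: on Linux the flags above amount to a configure step of
-- roughly
--   cmake -DOPENMP_RUNTIME=NONE -DWITH_MKL=OFF -DWITH_ACCELERATE=OFF
--     -DWITH_OPENBLAS=ON -DOPENBLAS_INCLUDE_DIR=../../OpenBLAS
--     -DOPENBLAS_LIBRARY=../../OpenBLAS/libopenblas.so ..
-- run from CTranslate2/build (see `runCmake` above).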
286 | 
287 | /- Download and build CTranslate2. Copy its C++ header files to `build/include` and shared libraries to `build/lib`. -/
288 | target libctranslate2 pkg : FilePath := do
289 |   if getOS! == .linux ∨ getOS! == .windows then
290 |     let openblas ← libopenblas.fetch
291 |     let _ ← openblas.await
292 | 
293 |   afterReleaseAsync pkg do
294 |     let dst := pkg.sharedLibDir / (nameToSharedLib (if getOS! == .windows then "libctranslate2" else "ctranslate2"))
295 |     createParentDirs dst
296 |     let ct2URL := "https://github.com/OpenNMT/CTranslate2"
297 | 
298 |     let depTrace := Hash.ofString ct2URL
299 |     setTrace depTrace
300 |     buildFileUnlessUpToDate' dst do
301 |       logInfo s!"Cloning CTranslate2 from {ct2URL}"
302 |       if !(← (pkg.buildDir / "CTranslate2").pathExists) then
303 |         let _ ← gitClone ct2URL pkg.buildDir
304 | 
305 |       let ct2Dir := pkg.buildDir / "CTranslate2"
306 |       if getOS! == .windows then
307 |         ensureDirExists $ ct2Dir / "build"
308 |         let _out ← rawProc {
309 |           cmd := "curl"
310 |           args := #["-L", "-o", "libctranslate2.dll", "https://drive.google.com/uc?export=download&id=1W6ZsbBG8gK9FRoMedNCKkg8qqS-bDa9U"]
311 |           cwd := ct2Dir / "build"
312 |         }
313 |       else
314 |         let flags ← getCt2CmakeFlags
315 |         logInfo s!"Configuring CTranslate2 with `cmake{flags.foldl (· ++ " " ++ ·) ""} ..`"
316 |         runCmake ct2Dir flags
317 |         let numThreads := max 4 $ min 32 (← nproc)
318 |         logInfo s!"Building CTranslate2 with `make -j{numThreads}`"
319 |         proc {
320 |           cmd := "make"
321 |           args := #[s!"-j{numThreads}"]
322 |           cwd := ct2Dir / "build"
323 |         }
324 | 
325 |       ensureDirExists $ pkg.buildDir / "include"
326 | 
327 |       copySingleFile (pkg.buildDir / "CTranslate2" / "build" / nameToSharedLib (if getOS! == .windows then "libctranslate2" else "ctranslate2")) dst
328 | 
329 |       -- [TODO]: Don't hardcode the version "4".
330 |       let dst' := pkg.sharedLibDir / (nameToVersionedSharedLib "ctranslate2" "4")
331 |       copySingleFile dst dst'
332 | 
333 |       copyFolder (ct2Dir / "include" / "ctranslate2") (pkg.buildDir / "include" / "ctranslate2")
334 | 
335 |       copyFolder (ct2Dir / "include" / "nlohmann") (pkg.buildDir / "include" / "nlohmann")
336 | 
337 |       copyFolder (ct2Dir / "include" / "half_float") (pkg.buildDir / "include" / "half_float")
338 | 
339 |       removeFolder ct2Dir
340 | 
341 |       if getOS! == .windows then
342 |         removeFolder (pkg.buildDir / "OPENBLAS")
343 |         removeFile (pkg.buildDir / "OPENBLAS.zip")
344 | 
345 |     let _ := (← getTrace)
346 |     return dst
347 | 
348 | 
349 | def buildCpp (pkg : Package) (path : FilePath) (dep : Job FilePath) : SpawnM (Job FilePath) := do
350 |   let optLevel := if pkg.buildType == .release then "-O3" else "-O0"
351 |   let flags := #["-fPIC", "-std=c++17", optLevel]
352 |   let mut args := flags ++ #[
353 |     "-I", (← getLeanIncludeDir).toString,
354 |     "-I", (pkg.buildDir / "include").toString,
355 |   ]
356 |   if getOS! == .windows then
357 |     -- link the headers
358 |     args := args ++ #[
359 |       "-I", (pkg.buildDir / "clang64/include/c++/v1").toString,
360 |       "-I", (pkg.buildDir / "clang64/include").toString,
361 |       "-I", (pkg.buildDir / "clang64/lib/clang/20/include").toString,
362 |     ]
363 |   let oFile := pkg.buildDir / (path.withExtension "o")
364 |   let srcJob ← inputTextFile <| pkg.dir / path
365 |   let leanPath ← Lake.getLeanSysroot
366 | 
367 |   buildFileAfterDep oFile (.collectList [srcJob, dep]) (extraDepTrace := computeHash flags) fun deps =>
368 |     compileO oFile deps[0]! args (if getOS! == .windows then s!"{leanPath}/bin/clang.exe" else "c++")
369 | 
370 | 
371 | target ct2.o pkg : FilePath := do
372 |   let ct2 ← libctranslate2.fetch
373 |   if getOS! == .windows then
374 |     -- download headers from https://repo.msys2.org/mingw/clang64/
375 |     proc {
376 |       cmd := "curl"
377 |       args := #["-L", "-o", "headers.pkg.tar.zst", "https://repo.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-headers-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst"]
378 |       cwd := pkg.buildDir
379 |     }
380 |     proc {
381 |       cmd := "curl"
382 |       args := #["-L", "-o", "clang.pkg.tar.zst", "https://repo.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-20.1.3-1-any.pkg.tar.zst"]
383 |       cwd := pkg.buildDir
384 |     }
385 |     proc {
386 |       cmd := "curl"
387 |       args := #["-L", "-o", "libcxx.pkg.tar.zst", "https://repo.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libc%2B%2B-20.1.3-1-any.pkg.tar.zst"]
388 |       cwd := pkg.buildDir
389 |     }
390 |     proc {
391 |       cmd := "curl"
392 |       args := #["-L", "-o", "pthread.pkg.tar.zst", "https://repo.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r724.g7e3f2dd90-1-any.pkg.tar.zst"]
393 |       cwd := pkg.buildDir
394 |     }
395 |     proc {
396 |       cmd := "tar"
397 |       args := #["-xvf", "clang.pkg.tar.zst"]
398 |       cwd := pkg.buildDir
399 |     }
400 |     proc {
401 |       cmd := "tar"
402 |       args := #["-xvf", "headers.pkg.tar.zst"]
403 |       cwd := pkg.buildDir
404 |     }
405 |     proc {
406 |       cmd := "tar"
407 |       args := #["-xvf", "libcxx.pkg.tar.zst"]
408 |       cwd := pkg.buildDir
409 |     }
410 |     proc {
411 |       cmd := "tar"
412 |       args := #["-xvf", "pthread.pkg.tar.zst"]
413 |       cwd := pkg.buildDir
414 |     }
415 |   let build := buildCpp pkg "cpp/ct2.cpp" ct2
416 |   afterReleaseSync pkg build
417 | 
418 | 
419 | extern_lib libleanffi pkg := do
420 |   let name := nameToStaticLib "leanffi"
421 |   let ct2O ← ct2.o.fetch
422 |   buildStaticLib (pkg.sharedLibDir / name) #[ct2O]
423 | 
424 | 
425 | require batteries from git "https://github.com/leanprover-community/batteries.git" @ "main"
426 | require aesop from git "https://github.com/leanprover-community/aesop" @ "master"
427 | 
428 | meta if get_config? env = some "dev" then  -- `dev` config only, so not everyone has to build doc-gen4
429 | require «doc-gen4» from git "https://github.com/leanprover/doc-gen4" @ "main"
430 | 
--------------------------------------------------------------------------------
/cpp/npy.hpp:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2017-2023 Leon Merten Lohse
3 | 
4 | Permission is hereby granted, free of charge, to any person obtaining a copy
5 | of this software and associated documentation files (the "Software"), to deal
6 | in the Software without restriction, including without limitation the rights
7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | copies of the Software, and to permit persons to whom the Software is
9 | furnished to do so, subject to the following conditions:
10 | 
11 | The above copyright notice and this permission notice shall be included in
12 | all copies or substantial portions of the Software.
13 | 
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20 | SOFTWARE.
21 | */
22 | 
23 | #ifndef NPY_HPP_
24 | #define NPY_HPP_
25 | 
26 | #include <algorithm>
27 | #include <array>
28 | #include <complex>
29 | #include <cstdint>
30 | #include <cstring>
31 | #include <fstream>
32 | #include <iostream>
33 | #include <sstream>
34 | #include <stdexcept>
35 | #include <string>
36 | #include <tuple>
37 | #include <type_traits>
38 | #include <typeindex>
39 | #include <typeinfo>
40 | #include <unordered_map>
41 | #include <utility>
42 | #include <vector>
43 | 
44 | namespace npy {
45 | 
46 | /* Compile-time test for byte order.
47 |    If your compiler does not define these per default, you may want to define
48 |    one of these constants manually.
49 |    Defaults to little endian order. */
50 | #if defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN || defined(__BIG_ENDIAN__) || defined(__ARMEB__) || \
51 |     defined(__THUMBEB__) || defined(__AARCH64EB__) || defined(_MIBSEB) || defined(__MIBSEB) || defined(__MIBSEB__)
52 | const bool big_endian = true;
53 | #else
54 | const bool big_endian = false;
55 | #endif
56 | 
57 | const size_t magic_string_length = 6;
58 | const std::array<char, magic_string_length> magic_string = {'\x93', 'N', 'U', 'M', 'P', 'Y'};
59 | 
60 | const char little_endian_char = '<';
61 | const char big_endian_char = '>';
62 | const char no_endian_char = '|';
63 | 
64 | constexpr std::array<char, 3> endian_chars = {little_endian_char, big_endian_char, no_endian_char};
65 | constexpr std::array<char, 4> numtype_chars = {'f', 'i', 'u', 'c'};
66 | 
67 | constexpr char host_endian_char = (big_endian ? big_endian_char : little_endian_char);
68 | 
69 | /* npy array length */
70 | using ndarray_len_t = unsigned long int;
71 | using shape_t = std::vector<ndarray_len_t>;
72 | 
73 | using version_t = std::pair<char, char>;
74 | 
75 | struct dtype_t {
76 |   char byteorder;
77 |   char kind;
78 |   unsigned int itemsize;
79 | 
80 |   inline std::string str() const {
81 |     std::stringstream ss;
82 |     ss << byteorder << kind << itemsize;
83 |     return ss.str();
84 |   }
85 | 
86 |   inline std::tuple<const char, const char, const unsigned int> tie() const {
87 |     return std::tie(byteorder, kind, itemsize);
88 |   }
89 | };
90 | 
91 | struct header_t {
92 |   dtype_t dtype;
93 |   bool fortran_order;
94 |   shape_t shape;
95 | };
96 | 
97 | inline void write_magic(std::ostream &ostream, version_t version) {
98 |   ostream.write(magic_string.data(), magic_string_length);
99 |   ostream.put(version.first);
100 |   ostream.put(version.second);
101 | }
102 | 
103 | inline version_t read_magic(std::istream &istream) {
104 |   std::array<char, magic_string_length + 2> buf{};
105 |   istream.read(buf.data(), sizeof(buf));
106 | 
107 |   if (!istream) {
108 |     throw std::runtime_error("io error: failed reading file");
109 |   }
110 | 
111 |   if (!std::equal(magic_string.begin(), magic_string.end(), buf.begin()))
112 |     throw std::runtime_error("this file does not have a valid npy format.");
113 | 
114 |   version_t version;
115 |   version.first = buf[magic_string_length];
116 |   version.second = buf[magic_string_length + 1];
117 | 
118 |   return version;
119 | }
120 | 
121 | const std::unordered_map<std::type_index, dtype_t> dtype_map = {
122 |     {std::type_index(typeid(float)), {host_endian_char, 'f', sizeof(float)}},
123 |     {std::type_index(typeid(double)), {host_endian_char, 'f', sizeof(double)}},
124 |     {std::type_index(typeid(long double)), {host_endian_char, 'f', sizeof(long double)}},
125 |     {std::type_index(typeid(char)), {no_endian_char, 'i', sizeof(char)}},
126 |     {std::type_index(typeid(signed char)), {no_endian_char, 'i', sizeof(signed char)}},
127 |     {std::type_index(typeid(short)), {host_endian_char, 'i', sizeof(short)}},
128 |     {std::type_index(typeid(int)), {host_endian_char, 'i', sizeof(int)}},
129 |     {std::type_index(typeid(long)), {host_endian_char, 'i', sizeof(long)}},
130 |     {std::type_index(typeid(long long)), {host_endian_char, 'i', sizeof(long long)}},
131 |     {std::type_index(typeid(unsigned char)), {no_endian_char, 'u', sizeof(unsigned char)}},
132 |     {std::type_index(typeid(unsigned short)), {host_endian_char, 'u', sizeof(unsigned short)}},
133 |     {std::type_index(typeid(unsigned int)), {host_endian_char, 'u', sizeof(unsigned int)}},
134 |     {std::type_index(typeid(unsigned long)), {host_endian_char, 'u', sizeof(unsigned long)}},
135 |     {std::type_index(typeid(unsigned long long)), {host_endian_char, 'u', sizeof(unsigned long long)}},
136 |     {std::type_index(typeid(std::complex<float>)), {host_endian_char, 'c', sizeof(std::complex<float>)}},
137 |     {std::type_index(typeid(std::complex<double>)), {host_endian_char, 'c', sizeof(std::complex<double>)}},
138 |     {std::type_index(typeid(std::complex<long double>)), {host_endian_char, 'c', sizeof(std::complex<long double>)}}};
139 | 
140 | // helpers
141 | inline bool is_digits(const std::string &str) { return std::all_of(str.begin(), str.end(), ::isdigit); }
142 | 
143 | template <typename T, size_t N>
144 | inline bool in_array(T val, const std::array<T, N> &arr) {
145 |   return std::find(std::begin(arr), std::end(arr), val) != std::end(arr);
146 | }
147 | 
148 | inline dtype_t parse_descr(std::string typestring) {
149 |   if (typestring.length() < 3) {
150 |     throw std::runtime_error("invalid typestring (length)");
151 |   }
152 | 
153 |   char byteorder_c = typestring.at(0);
154 |   char kind_c = typestring.at(1);
155 |   std::string itemsize_s = typestring.substr(2);
156 | 
157 |   if (!in_array(byteorder_c, endian_chars)) {
158 |     throw std::runtime_error("invalid typestring (byteorder)");
159 |   }
160 | 
161 |   if (!in_array(kind_c, numtype_chars)) {
162 |     throw std::runtime_error("invalid typestring (kind)");
163 |   }
164 | 
165 |   if (!is_digits(itemsize_s)) {
166 |     throw std::runtime_error("invalid typestring (itemsize)");
167 |   }
168 |   unsigned int itemsize = std::stoul(itemsize_s);
169 | 
170 |   return {byteorder_c, kind_c, itemsize};
171 | }
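// For intuition: the typestring "<f8" parses to {byteorder '<', kind 'f',
// itemsize 8}, i.e. a little-endian 8-byte float (NumPy's float64), which
// matches the `read_npy<double>` call in cpp/ct2.cpp. "|u1" would be a single
// unsigned byte with no byte order.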
172 | 
173 | namespace pyparse {
174 | 
175 | /**
176 |   Removes leading and trailing whitespaces
177 |   */
178 | inline std::string trim(const std::string &str) {
179 |   const std::string whitespace = " \t";
180 |   auto begin = str.find_first_not_of(whitespace);
181 | 
182 |   if (begin == std::string::npos) return "";
183 | 
184 |   auto end = str.find_last_not_of(whitespace);
185 | 
186 |   return str.substr(begin, end - begin + 1);
187 | }
188 | 
189 | inline std::string get_value_from_map(const std::string &mapstr) {
190 |   size_t sep_pos = mapstr.find_first_of(":");
191 |   if (sep_pos == std::string::npos) return "";
192 | 
193 |   std::string tmp = mapstr.substr(sep_pos + 1);
194 |   return trim(tmp);
195 | }
196 | 
197 | /**
198 |   Parses the string representation of a Python dict
199 | 
200 |   The keys need to be known and may not appear anywhere else in the data.
201 |   */
202 | inline std::unordered_map<std::string, std::string> parse_dict(std::string in, const std::vector<std::string> &keys) {
203 |   std::unordered_map<std::string, std::string> map;
204 | 
205 |   if (keys.size() == 0) return map;
206 | 
207 |   in = trim(in);
208 | 
209 |   // unwrap dictionary
210 |   if ((in.front() == '{') && (in.back() == '}'))
211 |     in = in.substr(1, in.length() - 2);
212 |   else
213 |     throw std::runtime_error("Not a Python dictionary.");
214 | 
215 |   std::vector<std::pair<size_t, std::string>> positions;
216 | 
217 |   for (auto const &value : keys) {
218 |     size_t pos = in.find("'" + value + "'");
219 | 
220 |     if (pos == std::string::npos) throw std::runtime_error("Missing '" + value + "' key.");
221 | 
222 |     std::pair<size_t, std::string> position_pair{pos, value};
223 |     positions.push_back(position_pair);
224 |   }
225 | 
226 |   // sort by position in dict
227 |   std::sort(positions.begin(), positions.end());
228 | 
229 |   for (size_t i = 0; i < positions.size(); ++i) {
230 |     std::string raw_value;
231 |     size_t begin{positions[i].first};
232 |     size_t end{std::string::npos};
233 | 
234 |     std::string key = positions[i].second;
235 | 
236 |     if (i + 1 < positions.size()) end = positions[i + 1].first;
237 | 
238 |     raw_value = in.substr(begin, end - begin);
239 | 
240 |     raw_value = trim(raw_value);
241 | 
242 |     if (raw_value.back() == ',') raw_value.pop_back();
243 | 
244 |     map[key] = get_value_from_map(raw_value);
245 |   }
246 | 
247 |   return map;
248 | }
249 | 
250 | /**
251 |   Parses the string representation of a Python boolean
252 |   */
253 | inline bool parse_bool(const std::string &in) {
254 |   if (in == "True") return true;
255 |   if (in == "False") return false;
256 | 
257 |   throw std::runtime_error("Invalid python boolean.");
258 | }
259 | 
260 | /**
261 |   Parses the string representation of a Python str
262 |   */
263 | inline std::string parse_str(const std::string &in) {
264 |   if ((in.front() == '\'') && (in.back() == '\'')) return in.substr(1, in.length() - 2);
265 | 
266 |   throw std::runtime_error("Invalid python string.");
267 | }
268 | 
269 | /**
270 |   Parses the string representation of a Python tuple into a vector of its items
271 |   */
272 | inline std::vector<std::string> parse_tuple(std::string in) {
273 |   std::vector<std::string> v;
274 |   const char separator = ',';
275 | 
276 |   in = trim(in);
277 | 
278 |   if ((in.front() == '(') && (in.back() == ')'))
279 |     in = in.substr(1, in.length() - 2);
280 |   else
281 |     throw std::runtime_error("Invalid Python tuple.");
282 | 
283 |   std::istringstream iss(in);
284 | 
285 |   for (std::string token; std::getline(iss, token, separator);) {
286 |     v.push_back(token);
287 |   }
288 | 
289 |   return v;
290 | }
291 | 
292 | template <typename T>
293 | inline std::string write_tuple(const std::vector<T> &v) {
294 |   if (v.size() == 0) return "()";
295 | 
296 |   std::ostringstream ss;
297 | 
298 |   if (v.size() == 1) {
299 |     ss << "(" << v.front() << ",)";
300 |   } else {
301 |     const std::string delimiter = ", ";
302 |     // v.size() > 1
303 |     ss << "(";
304 |     std::copy(v.begin(), v.end() - 1, std::ostream_iterator<T>(ss, delimiter.c_str()));
305 |     ss << v.back();
306 |     ss << ")";
307 |   }
308 | 
309 |   return ss.str();
310 | }
311 | 
312 | inline std::string write_boolean(bool b) {
313 |   if (b)
314 |     return "True";
315 |   else
316 |     return "False";
317 | }
318 | 
319 | }  // namespace pyparse
320 | 
321 | inline header_t parse_header(std::string header) {
322 |   /*
323 |     The first 6 bytes are a magic string: exactly "\x93NUMPY".
324 |     The next 1 byte is an unsigned byte: the major version number of the file
325 |     format, e.g. \x01. The next 1 byte is an unsigned byte: the minor version
326 |     number of the file format, e.g. \x00. Note: the version of the file format
327 |     is not tied to the version of the numpy package. The next 2 bytes form a
328 |     little-endian unsigned short int: the length of the header data HEADER_LEN.
329 |     The next HEADER_LEN bytes form the header data describing the array's
330 |     format. It is an ASCII string which contains a Python literal expression of
331 |     a dictionary. It is terminated by a newline ('\n') and padded with spaces
332 |     ('\x20') to make the total length of the magic string + 4 + HEADER_LEN be
333 |     evenly divisible by 16 for alignment purposes. The dictionary contains
334 |     three keys:
335 | 
336 |     "descr" : dtype.descr
337 |       An object that can be passed as an argument to the numpy.dtype()
338 |       constructor to create the array's dtype.
339 |     "fortran_order" : bool
340 |       Whether the array data is Fortran-contiguous or not. Since
341 |       Fortran-contiguous arrays are a common form of non-C-contiguity, we
342 |       allow them to be written directly to disk for efficiency.
343 |     "shape" : tuple of int
344 |       The shape of the array. For repeatability and readability, this
345 |       dictionary is formatted using pprint.pformat() so the keys are in
346 |       alphabetic order.
347 |   */
348 | 
349 |   // remove trailing newline
350 |   if (header.back() != '\n') throw std::runtime_error("invalid header");
351 |   header.pop_back();
352 | 
353 |   // parse the dictionary
354 |   std::vector<std::string> keys{"descr", "fortran_order", "shape"};
355 |   auto dict_map = npy::pyparse::parse_dict(header, keys);
356 | 
357 |   if (dict_map.size() == 0) throw std::runtime_error("invalid dictionary in header");
358 | 
359 |   std::string descr_s = dict_map["descr"];
360 |   std::string fortran_s = dict_map["fortran_order"];
361 |   std::string shape_s = dict_map["shape"];
362 | 
363 |   std::string descr = npy::pyparse::parse_str(descr_s);
364 |   dtype_t dtype = parse_descr(descr);
365 | 
366 |   // convert literal Python bool to C++ bool
367 |   bool fortran_order = npy::pyparse::parse_bool(fortran_s);
368 | 
369 |   // parse the shape tuple
370 |   auto shape_v = npy::pyparse::parse_tuple(shape_s);
371 | 
372 |   shape_t shape;
373 |   for (auto item : shape_v) {
374 |     auto dim = static_cast<ndarray_len_t>(std::stoul(item));
375 |     shape.push_back(dim);
376 |   }
377 | 
378 |   return {dtype, fortran_order, shape};
379 | }
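// For intuition: a version 1.0 header for a C-ordered 2 x 3 float64 array
// contains the dictionary
//   {'descr': '<f8', 'fortran_order': False, 'shape': (2, 3), }
// padded with spaces and terminated by '\n' so that the total header size is
// a multiple of 16 (cf. write_header_dict below).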
Note: the version of the file format 327 | is not tied to the version of the numpy package. The next 2 bytes form a 328 | little-endian unsigned short int: the length of the header data HEADER_LEN. 329 | The next HEADER_LEN bytes form the header data describing the array's 330 | format. It is an ASCII string which contains a Python literal expression of 331 | a dictionary. It is terminated by a newline ('n') and padded with spaces 332 | ('x20') to make the total length of the magic string + 4 + HEADER_LEN be 333 | evenly divisible by 16 for alignment purposes. The dictionary contains 334 | three keys: 335 | 336 | "descr" : dtype.descr 337 | An object that can be passed as an argument to the numpy.dtype() 338 | constructor to create the array's dtype. "fortran_order" : bool Whether the 339 | array data is Fortran-contiguous or not. Since Fortran-contiguous arrays 340 | are a common form of non-C-contiguity, we allow them to be written directly 341 | to disk for efficiency. "shape" : tuple of int The shape of the array. For 342 | repeatability and readability, this dictionary is formatted using 343 | pprint.pformat() so the keys are in alphabetic order. 344 | */ 345 | 346 | // remove trailing newline 347 | if (header.back() != '\n') throw std::runtime_error("invalid header"); 348 | header.pop_back(); 349 | 350 | // parse the dictionary 351 | std::vector keys{"descr", "fortran_order", "shape"}; 352 | auto dict_map = npy::pyparse::parse_dict(header, keys); 353 | 354 | if (dict_map.size() == 0) throw std::runtime_error("invalid dictionary in header"); 355 | 356 | std::string descr_s = dict_map["descr"]; 357 | std::string fortran_s = dict_map["fortran_order"]; 358 | std::string shape_s = dict_map["shape"]; 359 | 360 | std::string descr = npy::pyparse::parse_str(descr_s); 361 | dtype_t dtype = parse_descr(descr); 362 | 363 | // convert literal Python bool to C++ bool 364 | bool fortran_order = npy::pyparse::parse_bool(fortran_s); 365 | 366 | // parse the shape tuple 367 | auto shape_v = npy::pyparse::parse_tuple(shape_s); 368 | 369 | shape_t shape; 370 | for (auto item : shape_v) { 371 | auto dim = static_cast(std::stoul(item)); 372 | shape.push_back(dim); 373 | } 374 | 375 | return {dtype, fortran_order, shape}; 376 | } 377 | 378 | inline std::string write_header_dict(const std::string &descr, bool fortran_order, const shape_t &shape) { 379 | std::string s_fortran_order = npy::pyparse::write_boolean(fortran_order); 380 | std::string shape_s = npy::pyparse::write_tuple(shape); 381 | 382 | return "{'descr': '" + descr + "', 'fortran_order': " + s_fortran_order + ", 'shape': " + shape_s + ", }"; 383 | } 384 | 385 | inline void write_header(std::ostream &out, const header_t &header) { 386 | std::string header_dict = write_header_dict(header.dtype.str(), header.fortran_order, header.shape); 387 | 388 | size_t length = magic_string_length + 2 + 2 + header_dict.length() + 1; 389 | 390 | version_t version{1, 0}; 391 | if (length >= 255 * 255) { 392 | length = magic_string_length + 2 + 4 + header_dict.length() + 1; 393 | version = {2, 0}; 394 | } 395 | size_t padding_len = 16 - length % 16; 396 | std::string padding(padding_len, ' '); 397 | 398 | // write magic 399 | write_magic(out, version); 400 | 401 | // write header length 402 | if (version == version_t{1, 0}) { 403 | auto header_len = static_cast(header_dict.length() + padding.length() + 1); 404 | 405 | std::array header_len_le16{static_cast((header_len >> 0) & 0xff), 406 | static_cast((header_len >> 8) & 0xff)}; 407 | 
517 | 
518 | template <typename Scalar>
519 | inline void write_npy(std::ostream &out, const npy_data<Scalar> &data) {
520 |   // static_assert(has_typestring<Scalar>::value, "scalar type not
521 |   // understood");
522 |   const dtype_t dtype = dtype_map.at(std::type_index(typeid(Scalar)));
523 | 
524 |   header_t header{dtype, data.fortran_order, data.shape};
525 |   write_header(out, header);
526 | 
527 |   auto size = static_cast<size_t>(comp_size(data.shape));
528 | 
529 |   out.write(reinterpret_cast<const char *>(data.data.data()), sizeof(Scalar) * size);
530 | }
531 | 
532 | template <typename Scalar>
533 | inline void write_npy(const std::string &filename, const npy_data<Scalar> &data) {
534 |   std::ofstream stream(filename, std::ofstream::binary);
535 |   if (!stream) {
536 |     throw std::runtime_error("io error: failed to open a file.");
537 |   }
538 | 
539 |   write_npy<Scalar>(stream, data);
540 | }
541 | 
542 | template <typename Scalar>
543 | inline void write_npy(std::ostream &out, const npy_data_ptr<Scalar> &data_ptr) {
544 |   const dtype_t dtype = dtype_map.at(std::type_index(typeid(Scalar)));
545 | 
546 |   header_t header{dtype, data_ptr.fortran_order, data_ptr.shape};
547 |   write_header(out, header);
548 | 
549 |   auto size = static_cast<size_t>(comp_size(data_ptr.shape));
550 | 
551 |   out.write(reinterpret_cast<const char *>(data_ptr.data_ptr), sizeof(Scalar) * size);
552 | }
553 | 
554 | template <typename Scalar>
555 | inline void write_npy(const std::string &filename, const npy_data_ptr<Scalar> &data_ptr) {
556 |   std::ofstream stream(filename, std::ofstream::binary);
557 |   if (!stream) {
558 |     throw std::runtime_error("io error: failed to open a file.");
559 |   }
560 | 
561 |   write_npy<Scalar>(stream, data_ptr);
562 | }
563 | 
564 | // old interface
565 | 
566 | // NOLINTBEGIN(*-avoid-c-arrays)
567 | template <typename Scalar>
568 | inline void SaveArrayAsNumpy(const std::string &filename, bool fortran_order, unsigned int n_dims,
569 |                              const unsigned long shape[], const Scalar *data) {
570 |   const npy_data_ptr<Scalar> ptr{data, {shape, shape + n_dims}, fortran_order};
571 | 
572 |   write_npy<Scalar>(filename, ptr);
573 | }
574 | 
575 | template <typename Scalar>
576 | inline void SaveArrayAsNumpy(const std::string &filename, bool fortran_order, unsigned int n_dims,
577 |                              const unsigned long shape[], const std::vector<Scalar> &data) {
578 |   SaveArrayAsNumpy(filename, fortran_order, n_dims, shape, data.data());
579 | }
580 | 
581 | template <typename Scalar>
582 | inline void LoadArrayFromNumpy(const std::string &filename, std::vector<unsigned long> &shape, bool &fortran_order,
583 |                                std::vector<Scalar> &data) {
584 |   const npy_data<Scalar> n_data = read_npy<Scalar>(filename);
585 | 
586 |   shape = n_data.shape;
587 |   fortran_order = n_data.fortran_order;
588 | 
589 |   std::copy(n_data.data.begin(), n_data.data.end(), std::back_inserter(data));
590 | }
591 | 
592 | template <typename Scalar>
593 | inline void LoadArrayFromNumpy(const std::string &filename, std::vector<unsigned long> &shape,
594 |                                std::vector<Scalar> &data) {
595 |   bool fortran_order = false;
596 |   LoadArrayFromNumpy(filename, shape, fortran_order, data);
597 | }
598 | // NOLINTEND(*-avoid-c-arrays)
599 | 
600 | }  // namespace npy
601 | 
602 | #endif  // NPY_HPP_
603 | 
--------------------------------------------------------------------------------