├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── documentation.md │ ├── feature_request.md │ └── how-to-question.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ └── pre-commit.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── lean_universe ├── dataset │ ├── __init__.py │ ├── args.py │ ├── data.py │ └── run.py └── utils │ ├── __init__.py │ ├── logger.py │ ├── params.py │ └── tools.py └── pyproject.toml /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🐛 Bug Report 3 | about: Submit a bug report to help us improve 4 | labels: 'bug, needs triage' 5 | --- 6 | 7 | ## 🐛 Bug 8 | 9 | 10 | 11 | ### To Reproduce 12 | 13 | Steps to reproduce the behavior (**always include the command you ran**): 14 | 15 | 1. Run cmd '....' 16 | 2. See error 17 | 18 | 19 | 20 | 21 | #### Code sample 22 | 24 | 25 | ### Expected behavior 26 | 27 | 28 | 29 | ### Environment 30 | 31 | - OS (e.g., Linux): 32 | - How you installed LeanUniverse (`pip`, source): 33 | - Build command you used (if compiling from source): 34 | - Python version: 35 | - Any other relevant information: 36 | 37 | ### Additional context 38 | 39 | 40 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 📚 Documentation/Typos 3 | about: Report an issue related to documentation or a typo 4 | labels: 'documentation, needs triage' 5 | --- 6 | 7 | ## 📚 Documentation 8 | 9 | For typos and doc fixes, please go ahead and: 10 | 11 | 1. Create an issue. 12 | 2. Fix the typo. 13 | 3. Submit a PR. 14 | 15 | Thanks! 16 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🚀 Feature Request 3 | about: Submit a proposal/request for a new feature 4 | labels: 'enhancement, help wanted, needs triage' 5 | --- 6 | 7 | ## 🚀 Feature Request 8 | 9 | 10 | ### Motivation 11 | 12 | 13 | 14 | ### Pitch 15 | 16 | 17 | 18 | ### Alternatives 19 | 20 | 21 | 22 | ### Additional context 23 | 24 | 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/how-to-question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: ❓ Questions/Help 3 | about: If you have questions, please first search existing issues and docs 4 | labels: 'question, needs triage' 5 | --- 6 | 7 | ## ❓ Questions and Help 8 | 9 | ### Before asking: 10 | 1. search the issues. 11 | 2. search the docs. 12 | 13 | 14 | 15 | #### What is your question? 16 | 17 | #### Code 18 | 19 | 20 | 21 | #### What have you tried? 22 | 23 | #### What's your environment? 24 | 25 | - OS (e.g., Linux): 26 | - How you installed LeanUniverse (`pip`, source): 27 | - Build command you used (if compiling from source): 28 | - Python version: 29 | - Any other relevant information: 30 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Before submitting 2 | 3 | - [ ] Was this discussed/approved via a Github issue? 
(no need for typos, doc improvements) 4 | - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)? 5 | - [ ] Did you make sure to update the docs? 6 | - [ ] Did you write any new necessary tests? 7 | 8 | ## What does this PR do? 9 | Fixes # (issue). 10 | 11 | ## PR review 12 | Anyone in the community is free to review the PR once the tests have passed. 13 | If we didn't discuss your PR in Github issues there's a high chance it will not be merged. 14 | 15 | ## Did you have fun? 16 | Make sure you had fun coding 🙃 17 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Run precommit 3 | on: 4 | pull_request: 5 | workflow_dispatch: 6 | jobs: 7 | pre_commit_check: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v4.1.1 11 | - uses: actions/setup-python@v5.0.0 12 | with: 13 | python-version: '3.10' 14 | - name: Install pre-commit 15 | run: pip install pre-commit 16 | - name: Run pre-commit testing 17 | run: pre-commit run --all-files 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | poetry.lock 163 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Usage 3 | # 4 | # Run the pre-commit hooks on all files using: 5 | # 6 | # $ pre-commit run --all-files 7 | # 8 | # Update this file: 9 | # $ pre-commit autoupdate; pre-commit install 10 | # 11 | # Install the pre-commit hook using: 12 | # 13 | # $ pre-commit install 14 | # 15 | repos: 16 | - repo: https://github.com/pre-commit/pre-commit-hooks 17 | rev: v4.6.0 18 | hooks: 19 | - id: check-ast 20 | - id: check-byte-order-marker 21 | - id: check-case-conflict 22 | - id: check-docstring-first 23 | - id: check-executables-have-shebangs 24 | - id: check-json 25 | - id: check-yaml 26 | - id: debug-statements 27 | - id: detect-aws-credentials 28 | args: [--allow-missing-credentials] 29 | - id: detect-private-key 30 | - id: end-of-file-fixer 31 | - id: trailing-whitespace 32 | - id: mixed-line-ending 33 | - id: check-added-large-files 34 | args: [--maxkb=10000] 35 | - id: check-merge-conflict 36 | - id: check-symlinks 37 | - id: pretty-format-json 38 | args: [--autofix, --indent=2, --no-sort-keys] 39 | - id: requirements-txt-fixer 40 | - repo: https://github.com/timothycrosley/isort 41 | rev: 5.13.2 42 | hooks: 43 | - id: isort 44 | args: [--profile, black] 45 | - repo: https://github.com/psf/black 46 | rev: 24.8.0 47 | hooks: 48 | - id: black 49 | args: 50 | - --line-length=120 51 | language_version: python3 52 | - repo: https://github.com/jumanjihouse/pre-commit-hook-yamlfmt 53 | rev: 0.2.3 54 | hooks: 55 | - id: yamlfmt 56 | - repo: https://github.com/PyCQA/flake8 57 | rev: 7.1.1 58 | hooks: 59 | - id: flake8 60 | args: 
[--max-line-length=120, --ignore=E203] 61 | exclude: ^lean_universe/utils/params.py 62 | additional_dependencies: 63 | - flake8-bugbear 64 | - flake8-comprehensions 65 | - flake8-simplify 66 | - repo: https://github.com/asottile/yesqa 67 | rev: v1.5.0 68 | hooks: 69 | - id: yesqa 70 | additional_dependencies: 71 | - flake8-bugbear 72 | - flake8-comprehensions 73 | - flake8-docstrings 74 | - repo: https://github.com/nametake/pre-commit-prototool 75 | rev: v0.1.0 76 | hooks: 77 | - id: prototool-lint 78 | - id: prototool-format-fix 79 | - repo: https://github.com/BlankSpruce/gersemi 80 | rev: 0.15.1 81 | hooks: 82 | - id: gersemi # CMake formatter 83 | - repo: https://github.com/shellcheck-py/shellcheck-py 84 | rev: v0.10.0.1 85 | hooks: 86 | - id: shellcheck 87 | - repo: https://github.com/commitizen-tools/commitizen 88 | rev: v3.29.0 89 | hooks: 90 | - id: commitizen 91 | - id: commitizen-branch 92 | stages: [push] 93 | - repo: https://github.com/hadialqattan/pycln 94 | rev: v2.4.0 95 | hooks: 96 | - id: pycln 97 | - repo: https://github.com/asottile/pyupgrade 98 | rev: v3.17.0 99 | hooks: 100 | - id: pyupgrade 101 | args: [--py39-plus, --keep-runtime-typing] 102 | - repo: https://github.com/asottile/blacken-docs 103 | rev: 1.18.0 104 | hooks: 105 | - id: blacken-docs 106 | - repo: https://github.com/pre-commit/mirrors-mypy 107 | rev: v1.11.2 108 | hooks: 109 | - id: mypy 110 | #verbose: true 111 | args: [--ignore-missing-imports, --explicit-package-bases, --show-error-codes] 112 | additional_dependencies: [types-requests] 113 | - repo: https://github.com/PyCQA/bandit 114 | rev: 1.7.9 115 | hooks: 116 | - id: bandit 117 | exclude: ^lean_universe/utils/params.py 118 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/LeanUniverse/1c9f6ca76738c82eb42175b1d9b5bbe6d478ae92/CHANGELOG.md -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to LeanUniverse 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | 26 | ## Pre-commit hooks 27 | In order to ensure your code lints, there are pre-commit hooks configured in the repository which you can install. 28 | After installation, they will automatically run each time you commit. 29 | An abbreviated guide is given below; for more information, refer to [the offical pre-commit documentation](https://pre-commit.com/). 30 | 31 | ### Installation 32 | ``` 33 | pip install pre-commit 34 | pre-commit install 35 | ``` 36 | 37 | ### Usage 38 | Just commit your changes: 39 | ``` 40 | git commit -m "My informative commit message" 41 | ``` 42 | 43 | If there was a failure, you will get feedback 44 | ``` 45 | [INFO] Initializing environment for https://github.com/PyCQA/flake8. 46 | [INFO] Installing environment for https://github.com/pre-commit/pre-commit-hooks. 47 | [INFO] Once installed this environment will be reused. 48 | [INFO] This may take a few minutes... 49 | [INFO] Installing environment for https://github.com/PyCQA/flake8. 50 | [INFO] Once installed this environment will be reused. 51 | [INFO] This may take a few minutes... 52 | Trim Trailing Whitespace.................................................Failed 53 | - hook id: trailing-whitespace 54 | - exit code: 1 55 | - files were modified by this hook 56 | Fixing examples/nllb/modeling/wmt15_benchmark/eval_langs2.sh 57 | Fix End of Files.........................................................Failed 58 | - hook id: end-of-file-fixer 59 | - exit code: 1 60 | - files were modified by this hook 61 | Fixing examples/few_shot/scripts/schedule_jobs_few_shot.py 62 | flake8...................................................................Passed 63 | ``` 64 | 65 | Certain hooks modify your files to comply. 66 | To include these modifications, you will need to add them (i.e. `git add ...`) and commit again. 
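For example, a typical retry after the hooks have reformatted files might look like this (the file name is illustrative):
```
git commit -m "My informative commit message"   # fails: hooks rewrite some files
git add lean_universe/dataset/data.py           # stage the files the hooks fixed
git commit -m "My informative commit message"   # passes on the second attempt
```
You can also run every hook by hand before committing with `pre-commit run --all-files`.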
67 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. 
More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. 
For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. 
Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. 
For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. 401 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LeanUniverse: A Library for Consistent and Scalable Lean4 Dataset Management 2 | LeanUniverse is a package designed to create comprehensive datasets from Lean4 repositories on Github. 
Its goal is to simplify and standardize the generation of training datasets for AI models. 3 | The key features include: 4 | - _Consistency_: LeanUniverse ensures that all collected repositories are consistent and can be linked to the same version of dependencies (mathlib). This guarantees reliability and compatibility across datasets created with the library. 5 | - _License Filtering_: Users can define filters for acceptable licenses; users remain responsible for ensuring that their usage of third-party content (GitHub repositories) complies with the associated license and GitHub's terms of service. 6 | - _Caching_: The library incorporates a caching mechanism, enhancing efficiency by reducing redundant computations. This feature enables recurrent updates and incremental growth of datasets over time. 7 | 8 | 9 | ## Getting Started 10 | LeanUniverse uses [Poetry](https://python-poetry.org/) to manage project dependencies and virtual environments. Follow these steps to get started: 11 | 12 | 1. Clone the LeanUniverse repository to your local machine: 13 | ``` 14 | git clone https://github.com/your-repo/LeanUniverse.git 15 | cd LeanUniverse 16 | ``` 17 | 2. Ensure you have Poetry installed for managing dependencies and virtual environments. You can install Poetry with pip: 18 | ``` 19 | pip install poetry 20 | ``` 21 | For other installation methods, refer to the [Poetry installation guide](https://python-poetry.org/docs/). 22 | 3. Install all required dependencies by running: 23 | ``` 24 | poetry install 25 | ``` 26 | 4. Activate the environment created by Poetry: 27 | ``` 28 | poetry shell 29 | ``` 30 | This sets up a proper shell environment with all dependencies installed. 31 | 5. Now, you're ready to use LeanUniverse! Execute the main script or any specific functionality: 32 | ``` 33 | python lean_universe/dataset/run.py 34 | ``` 35 | 6. You can add or remove dependencies using Poetry. 36 | To add a new dependency: 37 | ``` 38 | poetry add <package-name> 39 | ``` 40 | To remove a dependency: 41 | ``` 42 | poetry remove <package-name> 43 | ``` 44 | For more information on Poetry and its features, refer to the [official Poetry documentation](https://python-poetry.org/docs/). 45 | 46 | 47 | ## Development 48 | We use `poetry` to manage the project dependencies and the virtual environments. Once you clone the repo, run `poetry install`. To run the code, run `poetry shell` to get a shell environment with everything installed. To add a new dependency, run e.g. `poetry add numpy`; dependencies can be removed the same way with `poetry remove`. 49 | 50 | ## License 51 | LeanUniverse is licensed under [CC-BY-NC 4.0](LICENSE). Use of this package for commercial purposes is prohibited. 52 | 53 | __Important__: Users are responsible for ensuring that their usage of third-party content (GitHub repositories) complies with the associated license and GitHub's terms of service. LeanUniverse allows filtering of licenses. 54 | 55 | ## Citation 56 | Please cite as: 57 | 58 | ``` bibtex 59 | @inproceedings{ahm2025leanuniverse, 60 | title = {LeanUniverse: A Library for Consistent and Scalable Lean4 Dataset Management}, 61 | author = {Aram H. Markosyan and Gabriel Synnaeve and Hugh Leather}, -------------------------------------------------------------------------------- /lean_universe/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /lean_universe/dataset/args.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import datetime 8 | import getpass 9 | import os 10 | from dataclasses import dataclass, field 11 | from pathlib import Path 12 | 13 | from psutil import cpu_count 14 | 15 | from lean_universe.utils.params import Params 16 | 17 | # from logging import getLogger 18 | 19 | 20 | # logger = getLogger() 21 | 22 | # @cache 23 | # def get_git_repo_root() -> str: 24 | # """ 25 | # Returns the top-level directory of the current Git repository. 26 | # Returns: 27 | # str: The absolute path of the top-level directory of the Git repository. 28 | # Returns None if the current directory is not part of a Git repository. 29 | # """ 30 | 31 | # try: 32 | # # Run the git command to get the top-level directory 33 | # repo_root = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], text=True).strip() # nosec 34 | # return repo_root 35 | # except subprocess.CalledProcessError: 36 | # # Handle the case where the current directory is not part of a Git repo 37 | # logger.error("Current directory is not a Git repository.") 38 | # return "" 39 | 40 | USER = getpass.getuser() 41 | 42 | # ROOT_WORKING_DIR = Path(get_git_repo_root()) 43 | # if not Path(""): 44 | ROOT_WORKING_DIR = Path(__file__).parent.parent.parent 45 | # else: 46 | # ROOT_WORKING_DIR = Path(ROOT_WORKING_DIR) 47 | 48 | cache_path = os.getenv("LEANUNIVERSE_CACHE", None) 49 | if cache_path: 50 | CACHE_DIR = Path(f"{cache_path}") 51 | else: 52 | CACHE_DIR = ROOT_WORKING_DIR / "cache" 53 | 54 | TIMESTAMP = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 55 | 56 | LEAN3_REPO = "leanpkg.toml" 57 | LEAN4_LAKEFILE = "lakefile.lean" 58 | LEAN4_TOOLCHAIN = "lean-toolchain" 59 | 60 | LARGE_DATASET = 100000 61 | LARGE_DATASET_TEST_VAL_PERCENT = 2 62 | MEDIUM_DATASET = 1000 63 | MEDIUM_DATASET_TEST_VAL_PERCENT = 7 64 | SMALL_DATASET = 100 65 | SMALL_DATASET_TEST_VAL_PERCENT = 10 66 | TEST_VAL_MIN_SIZE = 5 67 | 68 | 69 | @dataclass 70 | class EvalArgs(Params): 71 | cache_dir: str = "" 72 | working_dir: str = str(ROOT_WORKING_DIR) 73 | dataset_export_dir: str = "" 74 | raw_dataset_dir: str = "" # raw files to compose the dataset 75 | repos_dir: str = "" # cloned repos 76 | repos_included: list = field(default_factory=list) # repos to include in the dataset 77 | max_num_repos: int = 1 78 | ld_max_num_procs: int = 32 79 | large_dataset: int = LARGE_DATASET 80 | large_dataset_test_val_percent: int = LARGE_DATASET_TEST_VAL_PERCENT 81 | medium_dataset: int = MEDIUM_DATASET 82 | medium_dataset_test_val_percent: int = MEDIUM_DATASET_TEST_VAL_PERCENT 83 | small_dataset: int = SMALL_DATASET 84 | small_dataset_test_val_percent: int = SMALL_DATASET_TEST_VAL_PERCENT 85 | test_val_min_size: int = TEST_VAL_MIN_SIZE 86 | # dependencies_build_dir: str = str(DEPENDENCIES_BUILD_DIR) 87 | # dependencies_bashrc_extension: str = str(BASHRC_EXTENSION_FILE) 88 | log_file: str = "" 89 | timestamp: str = TIMESTAMP 90 | # log_file_dependencies: str = str(LOG_FILE_DEPENDENCIES) 91 | # log_failed_repos: str = 
str(LOG_FAILED_REPOS) 92 | # lean_extractor_file: str = str(LEAN4_DATA_EXTRACTOR_PATH) 93 | # lean_connector_file: str = str(LEAN4_DATA_CONNECTOR_PATH) 94 | num_threads: int = cpu_count(logical=False) 95 | 96 | def __post_init__(self): 97 | if self.cache_dir == "": 98 | self.cache_dir = CACHE_DIR 99 | else: 100 | self.cache_dir = Path(self.cache_dir) 101 | self.cache_dir = Path(self.cache_dir) 102 | if self.log_file == "": 103 | self.log_file = self.cache_dir / f"logs/{TIMESTAMP}_{USER}_lean_universe.log" 104 | else: 105 | self.log_file = Path(self.log_file) 106 | if self.dataset_export_dir == "": 107 | self.dataset_export_dir = self.cache_dir / "dataset" 108 | else: 109 | self.dataset_export_dir = Path(self.dataset_export_dir) 110 | if self.raw_dataset_dir == "": 111 | self.raw_dataset_dir = self.cache_dir / "raw" 112 | else: 113 | self.raw_dataset_dir = Path(self.raw_dataset_dir) 114 | if self.repos_dir == "": 115 | self.repos_dir = self.cache_dir / "repos" 116 | else: 117 | self.repos_dir = Path(self.repos_dir) 118 | self.log_file = Path(self.log_file) 119 | self.working_dir = Path(self.working_dir) 120 | self.dataset_export_dir = Path(self.dataset_export_dir) 121 | self.raw_dataset_dir = Path(self.raw_dataset_dir) 122 | self.repos_dir = Path(self.repos_dir) 123 | self.timestamp = TIMESTAMP 124 | # self.log_failed_repos = Path(self.log_failed_repos) 125 | 126 | self.working_dir.mkdir(parents=True, exist_ok=True) 127 | self.raw_dataset_dir.mkdir(parents=True, exist_ok=True) 128 | self.repos_dir.mkdir(parents=True, exist_ok=True) 129 | self.dataset_export_dir.mkdir(parents=True, exist_ok=True) 130 | self.cache_dir.mkdir(parents=True, exist_ok=True) 131 | (self.cache_dir / "logs").mkdir(parents=True, exist_ok=True) 132 | 133 | if not self.log_file.is_file(): 134 | self.log_file.touch() 135 | 136 | # if not self.log_failed_repos.is_file(): 137 | # self.log_failed_repos.touch() 138 | -------------------------------------------------------------------------------- /lean_universe/dataset/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
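# This module drives the dataset pipeline: it searches GitHub for Lean4
# repositories, filters them by license and exclusion lists, clones or updates
# the accepted repositories, builds them with `lake`, and traces them with
# LeanDojo to extract and export train/val/test dataset splits.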
6 | 7 | import json 8 | import random 9 | import shutil 10 | import time 11 | from collections import defaultdict 12 | from copy import copy 13 | from datetime import datetime 14 | from logging import getLogger 15 | from pathlib import Path 16 | from subprocess import CalledProcessError # nosec 17 | from typing import Optional, Union 18 | 19 | import lean_dojo 20 | import networkx as nx 21 | from github import Github, RateLimitExceededException, Repository 22 | from lean_dojo import LeanGitRepo, TracedRepo, TracedTheorem, constants, trace 23 | from lean_dojo.constants import LEAN4_PACKAGES_DIR 24 | from tqdm import tqdm 25 | 26 | from lean_universe.dataset.args import LEAN3_REPO, LEAN4_LAKEFILE, LEAN4_TOOLCHAIN, USER 27 | from lean_universe.utils.tools import ( 28 | clone_and_checkout, 29 | execute_and_capture, 30 | get_github, 31 | reset_and_pull, 32 | url_to_repo, 33 | ) 34 | 35 | logger = getLogger() 36 | 37 | SPLIT_NAME = str # train/val/test 38 | SPLIT = dict[SPLIT_NAME, list[TracedTheorem]] 39 | SPLIT_STRATEGY = str 40 | 41 | random.seed(3407) 42 | 43 | 44 | def get_lean_repos( 45 | GITHUB: Github, 46 | repos_included: Optional[list[str]] = None, 47 | max_num: int = -1, 48 | query: str = "lean", 49 | language: str = "lean", 50 | ): 51 | """ 52 | Retrieves Lean repositories from GitHub based on the specified query and language. 53 | Args: 54 | repos_included (Optional[list[str]], optional): A list of repository names to include. Defaults to None. 55 | max_num (int, optional): The maximum number of repositories to retrieve. Defaults to None. 56 | query (str, optional): The search query. Defaults to "lean". 57 | language (str, optional): The programming language. Defaults to "lean". 58 | Returns: 59 | List[Tuple[str, str]]: A list of tuples containing the full name and HTML URL of the repositories. 60 | """ 61 | 62 | logger.info("Searching for Lean repositories on GitHub...") 63 | query = f"{query} language:{language}" 64 | results = [] 65 | if repos_included: 66 | for repo_name in repos_included: 67 | repo = url_to_repo(GITHUB=GITHUB, url=repo_name) 68 | if repo: 69 | results.append((repo.full_name, repo.html_url)) 70 | if max_num > 0 and len(results) >= max_num: 71 | break 72 | 73 | logger.info(f"Found {len(results)} repositories in the included list.") 74 | if max_num > 0 and len(results) >= max_num: 75 | return results 76 | else: 77 | max_num = max_num - len(results) if max_num > 0 else max_num 78 | 79 | repositories = GITHUB.search_repositories(query=query, sort="stars", order="desc") 80 | try: 81 | for repo in tqdm(repositories): 82 | results.append((repo.full_name, repo.html_url)) 83 | if max_num > 0 and len(results) >= max_num: 84 | break 85 | if len(results) % 1000 == 0: 86 | logger.info(f"Fetched {len(results)} results so far...") 87 | time.sleep(10) # Sleep for 10 seconds every 1000 results to manage rate limits 88 | except RateLimitExceededException: 89 | print("GitHub API rate limit exceeded. 
Please wait and try again later.") 90 | return results 91 | 92 | 93 | class Dataset: 94 | def __init__( 95 | self, 96 | cache_dir: Path, 97 | log_file: Path, 98 | dataset_export_dir: Path, 99 | raw_dataset_dir: Path, 100 | repos_dir: Path, 101 | timestamp: str, 102 | github: Github, 103 | ld_cache_dir: Optional[Path], 104 | repos_excluded: Optional[list[str]] = None, 105 | repos_included: Optional[list[str]] = None, 106 | licesnes_excluded: Optional[list[str]] = None, 107 | ld_max_num_procs: int = -1, 108 | large_dataset: int = -1, 109 | large_dataset_test_val_percent: int = -1, 110 | medium_dataset: int = -1, 111 | medium_dataset_test_val_percent: int = -1, 112 | small_dataset: int = -1, 113 | small_dataset_test_val_percent: int = -1, 114 | test_val_min_size: int = -1, 115 | ): 116 | self.cache_dir = cache_dir 117 | self.database_file = cache_dir / "database.json" 118 | 119 | self.dataset_export_dir = dataset_export_dir 120 | self.raw_dataset_dir = raw_dataset_dir 121 | self.repos_dir = repos_dir 122 | self.timestamp = timestamp 123 | self.repos_excluded = repos_excluded 124 | self.repos_included = repos_included 125 | self.licenses_excluded = licesnes_excluded 126 | self.log_file = log_file 127 | self.ld_max_num_procs = ld_max_num_procs 128 | self.ld_cache_dir = ld_cache_dir 129 | 130 | self.large_dataset = large_dataset 131 | self.large_dataset_test_val_percent = large_dataset_test_val_percent 132 | self.medium_dataset = medium_dataset 133 | self.medium_dataset_test_val_percent = medium_dataset_test_val_percent 134 | self.small_dataset = small_dataset 135 | self.small_dataset_test_val_percent = small_dataset_test_val_percent 136 | 137 | self.github = github 138 | 139 | self.database: dict[str, dict] = {} 140 | self.database["report"] = {} 141 | self.database["repos"] = {} 142 | self.database["cache_repos"] = {} 143 | self.database["cache_report"] = {} 144 | if self.database_file.is_file(): 145 | logger.info(f"Loading database file from {self.database_file}") 146 | with self.database_file.open("r", encoding="utf-8") as file: 147 | self.database = json.load(file) 148 | for repo in self.database["repos"]: 149 | if self.database["repos"][repo]["is_correct"]: 150 | self.database["cache_repos"][repo] = self.database["repos"][repo] 151 | self.database["repos"] = {} 152 | self.database["cache_report"] = self.database["report"] 153 | self.database["report"] = {} 154 | 155 | self.database["report"]["cache_dir"] = self.cache_dir.as_posix() 156 | self.database["report"]["log_file"] = self.log_file.as_posix() 157 | self.database["report"]["raw_dataset_dir"] = self.raw_dataset_dir.as_posix() 158 | self.database["report"]["repos_dir"] = self.repos_dir.as_posix() 159 | self.database["report"]["dataset_export_dir"] = self.dataset_export_dir.as_posix() 160 | self.database["report"]["timestamp"] = self.timestamp 161 | self.database["report"]["user"] = USER 162 | self.database["report"]["repos_excluded"] = self.repos_excluded 163 | self.database["report"]["repos_included"] = self.repos_included 164 | self.database["report"]["licenses_excluded"] = self.licenses_excluded 165 | 166 | self.new_repos: set[str] = set() 167 | self.updated_repos: set[str] = set() 168 | self.lean_ready_repos: set[str] = set() 169 | 170 | def report(self) -> None: 171 | logger.info(f"Saving database file to {self.database_file}.") 172 | with open(self.database_file, "w", encoding="utf-8") as file: 173 | json.dump(self.database, file, indent=4) 174 | 175 | def __get_repo_info(self, repo: Repository) -> tuple[dict, bool]: 176 | info = 
{} 177 | print_prefix = f"[{repo.full_name}]" 178 | 179 | contents = repo.get_contents("") 180 | contents_names = [content_file.name for content_file in contents] 181 | 182 | info["date_updated"] = repo.updated_at.strftime("%Y-%m-%d %H:%M:%S") 183 | info["date_created"] = repo.created_at.strftime("%Y-%m-%d %H:%M:%S") 184 | info["full_name"] = repo.full_name 185 | info["name"] = repo.name 186 | info["description"] = repo.description 187 | info["branch"] = repo.default_branch 188 | info["sha"] = repo.get_branch(repo.default_branch).commit.sha 189 | info["star_count"] = repo.stargazers_count 190 | info["fork_count"] = repo.forks_count 191 | info["license"] = repo.license.spdx_id if repo.license else "None" 192 | info["path"] = (self.repos_dir / repo.full_name.replace("/", "_")).as_posix() 193 | info["raw_dataset_path"] = (self.raw_dataset_dir / repo.full_name.replace("/", "_") / info["sha"]).as_posix() 194 | info["dataset_export_dir"] = ( 195 | self.dataset_export_dir / repo.full_name.replace("/", "_") / info["sha"] 196 | ).as_posix() 197 | info["is_correct"] = False 198 | 199 | if repo.fork: 200 | info["is_fork"] = True 201 | logger.warning(f"{print_prefix} Repo is a fork.") 202 | return (info, False) 203 | 204 | if LEAN3_REPO in contents_names: 205 | info["lean3_repo"] = True 206 | logger.warning(f"{print_prefix} Repo is a Lean3 repo.") 207 | return (info, False) 208 | 209 | if LEAN4_LAKEFILE not in contents_names: 210 | info["lean4_lakefile_missing"] = True 211 | logger.warning(f"{print_prefix} Repo does not have a lakefile.") 212 | return (info, False) 213 | 214 | if LEAN4_TOOLCHAIN not in contents_names: 215 | info["lean4_toolchain_missing"] = True 216 | logger.warning(f"{print_prefix} Repodoes not have a lean-toolchain.") 217 | return (info, False) 218 | 219 | # Get the Lean version 220 | toolchain = repo.get_contents(LEAN4_TOOLCHAIN) 221 | info["toolchain"] = toolchain.decoded_content.decode("utf-8") 222 | info["is_correct"] = True 223 | return (info, True) 224 | 225 | def __license_excluded(self, license: str) -> bool: 226 | if self.licenses_excluded: 227 | for license_exclude in self.licenses_excluded: 228 | if license_exclude in license: 229 | return True 230 | return False 231 | 232 | def filter_new_repos(self, repos: list) -> None: 233 | self.database["report"]["repos_found"] = len(repos) 234 | self.database["report"]["repos_incorrect_lean"] = 0 235 | self.database["report"]["repos_incorrect_exclusion"] = 0 236 | self.database["report"]["repos_incorrect_license"] = 0 237 | self.database["report"]["repos_correct"] = 0 238 | 239 | for repo in tqdm(repos): 240 | url = repo[1] 241 | rep: Repository = url_to_repo(self.github, repo[1]) 242 | print_prefix = f"[{rep.full_name}]" 243 | logger.info(f"{print_prefix} Processing ...") 244 | 245 | info, correct = self.__get_repo_info(rep) 246 | if not correct: 247 | self.database["report"]["repos_incorrect_lean"] += 1 248 | self.database["repos"][url] = info 249 | logger.info(f"{print_prefix} is not a Lean4 repo.") 250 | continue 251 | 252 | # Filter out repositories with excluded licenses 253 | if self.__license_excluded(info["license"]): 254 | info["Excluded_license"] = True 255 | info["is_correct"] = False 256 | self.database["report"]["repos_incorrect_license"] += 1 257 | self.database["repos"][url] = info 258 | logger.info(f"{print_prefix} has an excluded license. 
Skipping...") 259 | continue 260 | 261 | # Filter out repositories that are in the exclude list 262 | if self.repos_excluded and url in self.repos_excluded: 263 | info["Excluded_repo"] = True 264 | info["is_correct"] = False 265 | self.database["report"]["repos_incorrect_exclusion"] += 1 266 | self.database["repos"][url] = info 267 | logger.info(f"{print_prefix} is in the exclude list. Skipping...") 268 | continue 269 | 270 | self.database["repos"][url] = info 271 | self.database["report"]["repos_correct"] += 1 272 | self.new_repos.add(url) 273 | 274 | correct = self.database["report"]["repos_correct"] 275 | found = self.database["report"]["repos_found"] 276 | logger.info(f"Using {correct}/{found} repositories.") 277 | 278 | def clone_or_pull_repos(self) -> None: 279 | for repo in tqdm(self.new_repos): 280 | repo_info = self.database["repos"][repo] 281 | print_prefix = f"[{repo_info['full_name']}]" 282 | if repo in self.database["cache_repos"]: 283 | if self.database["cache_repos"][repo]["sha"] != repo_info["sha"]: 284 | if Path(repo_info["path"]).is_dir(): 285 | reset_and_pull( 286 | repo, f"""[{repo_info["full_name"]}]""", repo_info["branch"], Path(repo_info["path"]) 287 | ) 288 | logger.info(f"{print_prefix} Repository has been updated. Resetting and pulling...") 289 | else: 290 | clone_and_checkout( 291 | repo, f"""[{repo_info["full_name"]}]""", repo_info["branch"], Path(repo_info["path"]) 292 | ) 293 | logger.info(f"{print_prefix} Repository has been updated. Cloning...") 294 | self.updated_repos.add(repo) 295 | else: 296 | if not Path(repo_info["path"]).is_dir(): 297 | clone_and_checkout( 298 | repo, f"""[{repo_info["full_name"]}]""", repo_info["branch"], Path(repo_info["path"]) 299 | ) 300 | logger.info(f"{print_prefix} Repository has been updated. Cloning...") 301 | self.updated_repos.add(repo) 302 | 303 | else: 304 | if Path(repo_info["path"]).is_dir(): 305 | reset_and_pull( 306 | repo, f"""[{repo_info["full_name"]}]""", repo_info["branch"], Path(repo_info["path"]) 307 | ) 308 | else: 309 | clone_and_checkout( 310 | repo, f"""[{repo_info["full_name"]}]""", repo_info["branch"], Path(repo_info["path"]) 311 | ) 312 | self.updated_repos.add(repo) 313 | self.database["report"]["repos_correct_updated"] = len(self.updated_repos) 314 | logger.info(f"""[Clone or Fetch] Updated {self.database["report"]["repos_correct_updated"]} repositories.""") 315 | 316 | def __check_for_incorrect_datasets(self) -> None: 317 | for repo in self.new_repos: 318 | if not Path(self.database["repos"][repo]["raw_dataset_path"]).is_dir(): 319 | self.updated_repos.add(repo) 320 | 321 | def __build_package(self, dir: Path, name: str) -> bool: 322 | try: 323 | logger.info(f"[{name}] Get caches.") 324 | command = "lake exe cache get" 325 | execute_and_capture(command, dir) 326 | except CalledProcessError as ex: 327 | logger.warning(f"[{name}] Failed to get caches. {ex.stderr}") 328 | 329 | # return False 330 | # raise ex 331 | 332 | try: 333 | logger.info(f"[{name}] Building package.") 334 | command = "lake build" 335 | execute_and_capture(command, dir) 336 | except CalledProcessError as ex: 337 | logger.error(f"[{name}] Failed to build package. {ex.stderr}") 338 | return False 339 | # raise ex 340 | 341 | try: 342 | logger.info(f"[{name}] Upgrading package.") 343 | command = "lake update" 344 | execute_and_capture(command, dir) 345 | except CalledProcessError as ex: 346 | logger.error(f"[{name}] Failed to upgrade package. 
{ex.stderr}") 347 | return False 348 | # raise ex 349 | 350 | try: 351 | logger.info(f"[{name}] Building package.") 352 | command = "lake build" 353 | execute_and_capture(command, dir) 354 | except CalledProcessError as ex: 355 | logger.error(f"[{name}] Failed to build package. {ex.stderr}") 356 | return False 357 | # raise ex 358 | 359 | logger.info(f"[{name}] Package built and updated successfully.") 360 | return True 361 | 362 | def build_lake(self) -> None: 363 | self.__check_for_incorrect_datasets() 364 | for repo in tqdm(self.updated_repos): 365 | repo_info = self.database["repos"][repo] 366 | if self.__build_package(Path(repo_info["path"]), repo_info["full_name"]): 367 | self.database["repos"][repo]["builds"] = True 368 | self.lean_ready_repos.add(repo) 369 | else: 370 | self.database["repos"][repo]["builds"] = False 371 | 372 | def configure_leandojo(self) -> None: 373 | if self.ld_max_num_procs > 0: 374 | constants.MAX_NUM_PROCS = self.ld_max_num_procs 375 | logger.info(f"[LeanDojo] Using {constants.MAX_NUM_PROCS} processes.") 376 | if self.ld_cache_dir: 377 | constants.CACHE_DIR = self.ld_cache_dir 378 | logger.info(f"[LeanDojo] Using cache directory {constants.CACHE_DIR}") 379 | 380 | self.database["report"]["LeanDojo"] = {} 381 | self.database["report"]["LeanDojo"]["version"] = constants.__version__ 382 | self.database["report"]["LeanDojo"]["MAX_NUM_PROCS"] = constants.MAX_NUM_PROCS 383 | self.database["report"]["LeanDojo"]["CACHE_DIR"] = constants.CACHE_DIR.as_posix() 384 | 385 | def run_leandojo(self) -> None: 386 | logger.info("Running LeanDojo for the data extraction.") 387 | self.configure_leandojo() 388 | for repo in tqdm(self.lean_ready_repos if self.lean_ready_repos else self.new_repos): 389 | repo_info = self.database["repos"][repo] 390 | print_prefix = f"[{repo_info['full_name']}]" 391 | logger.info(f"{print_prefix} Extracting dataset.") 392 | 393 | try: 394 | lean_repo = LeanGitRepo(repo_info["path"], repo_info["branch"]) 395 | traced_repo = trace( 396 | repo=lean_repo, 397 | dst_dir=repo_info["raw_dataset_path"], 398 | # build_deps=True 399 | ) 400 | self.database["repos"][repo]["Traced"] = True 401 | splits = self.split_data(traced_repo, repo) 402 | self.export_data( 403 | traced_repo, 404 | splits, 405 | repo_info["dataset_export_dir"], 406 | dataset_name=f"{print_prefix} dataset", 407 | repo=repo, 408 | ) 409 | except Exception as ex: 410 | logger.info(f"{print_prefix} Error fetching results: {repr(ex)}") 411 | 412 | def split_data(self, traced_repo: TracedRepo, repo: str) -> dict[SPLIT_STRATEGY, SPLIT]: 413 | # Skip theorems in the Lean 4 repo itself. 
414 | traced_theorems = [thm for thm in traced_repo.get_traced_theorems() if not thm.repo.is_lean4] 415 | repo_info = self.database["repos"][repo] 416 | print_prefix = f"[{repo_info['full_name']}]" 417 | 418 | total_theorems_num = len(traced_theorems) 419 | logger.info(f"{print_prefix} Theorems in total: {total_theorems_num}") 420 | 421 | NUM_VAL = NUM_TEST = 0 422 | if total_theorems_num > self.large_dataset: 423 | NUM_VAL = NUM_TEST = int(total_theorems_num * self.large_dataset_test_val_percent / 100) 424 | else: 425 | if total_theorems_num > self.medium_dataset: 426 | NUM_VAL = NUM_TEST = int(total_theorems_num * self.medium_dataset_test_val_percent / 100) 427 | else: 428 | if total_theorems_num > self.small_dataset: 429 | NUM_VAL = NUM_TEST = int(total_theorems_num * self.small_dataset_test_val_percent / 100) 430 | 431 | self.database["repos"][repo][ 432 | "Split train/val/test" 433 | ] = f"{total_theorems_num - NUM_TEST - NUM_VAL}/{NUM_VAL}/{NUM_TEST}" 434 | return { 435 | "random": self.split_randomly(traced_theorems, NUM_TEST, NUM_VAL, repo), 436 | "novel_premises": self.split_by_premise(traced_theorems, NUM_TEST, NUM_VAL, repo), 437 | } 438 | 439 | def split_randomly(self, traced_theorems: list[TracedTheorem], NUM_TEST: int, NUM_VAL: int, repo: str) -> SPLIT: 440 | """Split ``traced_theorems`` randomly into train/val/test.""" 441 | repo_info = self.database["repos"][repo] 442 | print_prefix = f"[{repo_info['full_name']}]" 443 | logger.info(f"{print_prefix} Splitting the theorems randomly") 444 | traced_theorems = copy(traced_theorems) 445 | random.shuffle(traced_theorems) 446 | return self._split_sequentially(traced_theorems, NUM_TEST, NUM_VAL) 447 | 448 | def _split_sequentially(self, traced_theorems: list[TracedTheorem], NUM_TEST: int, NUM_VAL: int) -> SPLIT: 449 | """Split ``traced_theorems`` sequentially into train/val/test.""" 450 | num_theorems = len(traced_theorems) 451 | num_train = num_theorems - NUM_VAL - NUM_TEST 452 | return { 453 | "train": traced_theorems[:num_train], 454 | "val": traced_theorems[num_train : num_train + NUM_VAL], 455 | "test": traced_theorems[num_train + NUM_VAL :], 456 | } 457 | 458 | def split_by_premise(self, traced_theorems: list[TracedTheorem], NUM_TEST: int, NUM_VAL: int, repo: str) -> SPLIT: 459 | """ 460 | Split theorems into train/val/test so that proofs in val/test rely on at 461 | least one novel premise that does not appear in train. 462 | """ 463 | repo_info = self.database["repos"][repo] 464 | print_prefix = f"[{repo_info['full_name']}]" 465 | logger.info(f"{print_prefix} Splitting the theorems by premises") 466 | 467 | # Figure out the number of theorems in train/val/test. 468 | # num_theorems = len(traced_theorems) 469 | num_val_test = NUM_VAL + NUM_TEST 470 | # num_train = num_theorems - num_val_test 471 | theorems_val_test_: set = set() 472 | 473 | # Map each premise to a list of theorems using it. 474 | theorems_by_premises_ = defaultdict(list) 475 | for t in traced_theorems: 476 | for p in t.get_premise_full_names(): 477 | theorems_by_premises_[p].append(t) 478 | 479 | # Sort the premises by the number of theorems using them (in ascending order). 480 | theorems_by_premises = sorted(theorems_by_premises_.items(), key=lambda x: len(x[1])) 481 | 482 | # For each premise, put all theorems using it into val_test so that it does not appear in train. 483 | for _, thms in theorems_by_premises: 484 | if len(theorems_val_test_) < num_val_test: 485 | theorems_val_test_.update(thms) 486 | 487 | # All other theorems go to train. 
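# Illustrative walk-through of the loop above (premise and theorem names are
# hypothetical, not taken from any real repository): with
# NUM_VAL + NUM_TEST = 2 and premises used as
#   obscure_lemma -> {T7}
#   helper_lemma  -> {T3, T7}
#   Nat.add_comm  -> {T1, T2, T3, T4, T5, T6}
# premises are visited from least to most used: obscure_lemma moves T7 into
# val/test, helper_lemma then moves T3 in as well, and the quota of 2 is full,
# so Nat.add_comm is never drained. T3 and T7 each rely on a premise that no
# train theorem uses, while T1, T2, T4, T5 and T6 fall through to train below.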
488 | theorems_train = [t for t in traced_theorems if t not in theorems_val_test_] 489 | theorems_val_test = list(theorems_val_test_) 490 | random.shuffle(theorems_val_test) 491 | 492 | return { 493 | "train": theorems_train, 494 | "val": theorems_val_test[:NUM_VAL], 495 | "test": theorems_val_test[NUM_VAL:], 496 | } 497 | 498 | def export_data( 499 | self, 500 | traced_repo: TracedRepo, 501 | splits: dict[SPLIT_STRATEGY, SPLIT], 502 | dst_path: Union[str, Path], 503 | repo: str, 504 | **kwargs, 505 | ) -> None: 506 | """Export a traced repo whose theorems have been splitted to ``dst_path``.""" 507 | repo_info = self.database["repos"][repo] 508 | print_prefix = f"[{repo_info['full_name']}]" 509 | if isinstance(dst_path, str): 510 | dst_path = Path(dst_path) 511 | if dst_path.exists(): 512 | logger.warning(f"{dst_path} already exists. Removing it now.") 513 | shutil.rmtree(dst_path) 514 | 515 | # Export the proofs. 516 | logger.info(f"{print_prefix} Exporting the proofs.") 517 | self.export_proofs(splits, dst_path, traced_repo, repo) 518 | 519 | # Export the premises (theorems, definitions, etc.). 520 | logger.info(f"{print_prefix} Exporting the premises.") 521 | self.export_premises(traced_repo, dst_path, repo) 522 | 523 | # Export the licenses. 524 | logger.info(f"{print_prefix} Exporting the licenses.") 525 | self.export_licenses(traced_repo, dst_path) 526 | 527 | # Export metadata. 528 | logger.info(f"{print_prefix} Exporting the metadata.") 529 | self.export_metadata(traced_repo, dst_path, **kwargs) 530 | 531 | logger.info(f"{print_prefix} Exporting the data is completed.") 532 | 533 | def export_proofs( 534 | self, splits: dict[SPLIT_STRATEGY, SPLIT], dst_path: Path, traced_repo: TracedRepo, repo: str 535 | ) -> None: 536 | """Export all proofs in a traced repo to ``dst_path''.""" 537 | repo_info = self.database["repos"][repo] 538 | print_prefix = f"[{repo_info['full_name']}]" 539 | for strategy, split in splits.items(): 540 | split_dir = dst_path / strategy 541 | split_dir.mkdir(parents=True) 542 | 543 | for name, theorems in split.items(): 544 | data = [] 545 | num_tactics = 0 546 | 547 | for thm in theorems: 548 | tactics = [ 549 | { 550 | "tactic": t.tactic, 551 | "annotated_tactic": t.get_annotated_tactic(), 552 | "state_before": t.state_before, 553 | "state_after": t.state_after, 554 | } 555 | for t in thm.get_traced_tactics() 556 | if t.state_before != "no goals" and "·" not in t.tactic # Ignore "·". 557 | ] 558 | num_tactics += len(tactics) 559 | data.append( 560 | { 561 | "url": traced_repo.repo.url, 562 | "commit": traced_repo.repo.commit, 563 | "file_path": self._get_file_path(traced_repo, thm), 564 | "full_name": thm.theorem.full_name, 565 | "start": list(thm.start), 566 | "end": list(thm.end), 567 | "traced_tactics": tactics, 568 | } 569 | ) 570 | oup_path = split_dir / f"{name}.json" 571 | json.dump(data, oup_path.open("wt")) 572 | logger.info(f"{print_prefix} {len(theorems)} theorems and {num_tactics} tactics saved to {oup_path}") 573 | 574 | def _get_file_path(self, traced_repo: TracedRepo, thm: TracedTheorem) -> str: 575 | if thm.repo == traced_repo.repo: 576 | # The theorem belongs to the traced repo itself. 577 | return str(thm.theorem.file_path) 578 | else: 579 | # The theorem belongs to one of the dependencies. 
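# For reference, each theorem entry that export_proofs above writes into the
# train/val/test .json files has the shape sketched here; all values are
# hypothetical placeholders, and "file_path" is the value returned by this
# helper:
#     {
#         "url": "https://github.com/example/lean-repo",
#         "commit": "<commit sha>",
#         "file_path": "Example/Basic.lean",
#         "full_name": "Example.add_zero",
#         "start": [10, 1],
#         "end": [14, 20],
#         "traced_tactics": [
#             {
#                 "tactic": "simp",
#                 "annotated_tactic": "...",
#                 "state_before": "⊢ n + 0 = n",
#                 "state_after": "no goals"
#             }
#         ]
#     }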
580 | for name, dep in traced_repo.dependencies.items(): 581 | if dep == thm.repo: 582 | return f"{LEAN4_PACKAGES_DIR}/{name}/{thm.theorem.file_path}" 583 | raise ValueError(f"Unable to find the dependency {thm.repo}") 584 | 585 | def export_premises(self, traced_repo: TracedRepo, dst_path: Path, repo: str) -> None: 586 | """Export all premise definitions in a traced repo to ``dst_path``.""" 587 | repo_info = self.database["repos"][repo] 588 | print_prefix = f"[{repo_info['full_name']}]" 589 | oup_path = dst_path / "corpus.jsonl" 590 | num_premises = 0 591 | 592 | with oup_path.open("wt") as oup: 593 | G = traced_repo.traced_files_graph 594 | logger.info(f"Printing1 {traced_repo}") 595 | logger.info(f"Printing {G}") 596 | for tf_node in reversed(list(nx.topological_sort(G))): 597 | tf = G.nodes[tf_node]["traced_file"] 598 | imports = [str(_) for _ in G.successors(tf_node)] 599 | premises = tf.get_premise_definitions() 600 | num_premises += len(premises) 601 | oup.write(json.dumps({"path": str(tf.path), "imports": imports, "premises": premises}) + "\n") 602 | length = len(traced_repo.traced_files) 603 | logger.info(f"{print_prefix} {num_premises} theorems/definitions from {length} files saved to {oup_path}") 604 | 605 | def export_licenses(self, traced_repo: TracedRepo, dst_path: Path) -> None: 606 | """Export the licenses of a traced repo and all its dependencies to ``dst_path``.""" 607 | license_dir = dst_path / "licenses" 608 | license_dir.mkdir() 609 | all_repos = [traced_repo.repo] + list(traced_repo.dependencies.values()) 610 | 611 | for repo in all_repos: 612 | lic = repo.get_license() 613 | if lic is None: 614 | continue 615 | with (license_dir / repo.name).open("wt") as oup: 616 | oup.write(lic) 617 | 618 | with (license_dir / "README.md").open("wt") as oup: 619 | oup.write( 620 | """This directory contains licenses of Lean repos used to generate this dataset. 
\\ 621 | The dataset itself is released under [CC BY 2.0](https://creativecommons.org/licenses/by/2.0/).""" 622 | ) 623 | 624 | def export_metadata(self, traced_repo: TracedRepo, dst_path: Path, **kwargs) -> None: 625 | """Export the metadata of a traced repo to ``dst_path''.""" 626 | metadata = dict(kwargs) 627 | metadata["creation_time"] = str(datetime.now()) 628 | metadata["from_repo"] = { 629 | "url": traced_repo.repo.url, 630 | "commit": traced_repo.repo.commit, 631 | } 632 | metadata["leandojo_version"] = lean_dojo.__version__ 633 | json.dump(metadata, (dst_path / "metadata.json").open("wt")) 634 | 635 | 636 | class LeanUniverseData: 637 | def __init__( 638 | self, 639 | cache_dir: Path, 640 | log_file: Path, 641 | dataset_export_dir: Path, 642 | raw_dataset_dir: Path, 643 | repos_dir: Path, 644 | timestamp: str, 645 | ld_cache_dir: Optional[Path], 646 | repos_excluded: Optional[list[str]] = None, 647 | repos_included: Optional[list[str]] = None, 648 | licesnes_excluded: Optional[list[str]] = None, 649 | max_num_repos: int = -1, 650 | language: str = "lean", 651 | query: str = "lean", 652 | ld_max_num_procs: int = -1, 653 | large_dataset: int = -1, 654 | large_dataset_test_val_percent: int = -1, 655 | medium_dataset: int = -1, 656 | medium_dataset_test_val_percent: int = -1, 657 | small_dataset: int = -1, 658 | small_dataset_test_val_percent: int = -1, 659 | test_val_min_size: int = -1, 660 | ): 661 | self.cache_dir = cache_dir 662 | logger.info(f"Loading dataset from {cache_dir}") 663 | self.timestamp = timestamp 664 | self.github = get_github() 665 | self.dataset = Dataset( 666 | cache_dir=cache_dir, 667 | log_file=log_file, 668 | dataset_export_dir=dataset_export_dir, 669 | raw_dataset_dir=raw_dataset_dir, 670 | repos_dir=repos_dir, 671 | timestamp=timestamp, 672 | github=self.github, 673 | repos_excluded=repos_excluded, 674 | repos_included=repos_included, 675 | licesnes_excluded=licesnes_excluded, 676 | ld_cache_dir=ld_cache_dir, 677 | ld_max_num_procs=ld_max_num_procs, 678 | large_dataset=large_dataset, 679 | large_dataset_test_val_percent=large_dataset_test_val_percent, 680 | medium_dataset=medium_dataset, 681 | medium_dataset_test_val_percent=medium_dataset_test_val_percent, 682 | small_dataset=small_dataset, 683 | small_dataset_test_val_percent=small_dataset_test_val_percent, 684 | test_val_min_size=test_val_min_size, 685 | ) 686 | 687 | self.max_num = max_num_repos 688 | self.query = query 689 | self.language = language 690 | 691 | self.repos_excluded = repos_excluded 692 | self.repos_included = repos_included 693 | self.licenses_excluded = licesnes_excluded 694 | 695 | def fetch_repos_and_update_database(self): 696 | self.__pull_repos() 697 | 698 | def __pull_repos(self) -> None: 699 | """ 700 | Pulls repositories from the Lean Universe. 
701 | """ 702 | repos = get_lean_repos( 703 | GITHUB=self.github, 704 | repos_included=self.repos_included, 705 | max_num=self.max_num, 706 | query=self.query, 707 | language=self.language, 708 | ) 709 | self.dataset.filter_new_repos(repos) 710 | self.dataset.clone_or_pull_repos() 711 | 712 | def build_lake(self) -> None: 713 | self.dataset.build_lake() 714 | 715 | def run_leandojo(self) -> None: 716 | self.dataset.configure_leandojo() 717 | self.dataset.run_leandojo() 718 | 719 | def report(self) -> None: 720 | self.dataset.report() 721 | logger.info("Report saved.") 722 | -------------------------------------------------------------------------------- /lean_universe/dataset/run.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from logging import getLogger 8 | from pathlib import Path 9 | 10 | from lean_universe.dataset.args import EvalArgs 11 | from lean_universe.dataset.data import LeanUniverseData 12 | from lean_universe.utils.logger import add_logger_file_handler, init_logger, log_host 13 | from lean_universe.utils.params import cfg_from_cli 14 | 15 | 16 | def main(args: EvalArgs): 17 | init_logger() 18 | log_host() 19 | add_logger_file_handler(args.log_file) 20 | logger = getLogger() 21 | logger.info(f"Logger file is located at {args.log_file}") 22 | 23 | 24 | if __name__ == "__main__": 25 | args: EvalArgs = cfg_from_cli(schema=EvalArgs) 26 | main(args) 27 | logger = getLogger() 28 | 29 | data = LeanUniverseData( 30 | cache_dir=Path(args.cache_dir), 31 | log_file=Path(args.log_file), 32 | dataset_export_dir=Path(args.dataset_export_dir), 33 | raw_dataset_dir=Path(args.raw_dataset_dir), 34 | repos_dir=Path(args.repos_dir), 35 | timestamp=args.timestamp, 36 | repos_excluded=[ 37 | "https://github.com/leanprover-community/mathlib4_with_LeanInfer", 38 | ], 39 | repos_included=args.repos_included, 40 | # repos_included=[ 41 | # # "https://github.com/dwrensha/compfiles", 42 | # # "https://github.com/goens/lost-pop-lean", 43 | # # "https://github.com/RustyYato/lean-algebra", 44 | # # "https://github.com/Junology/dijkstra", 45 | # # "https://github.com/arthurpaulino/viper", 46 | # # "https://github.com/isubasinghe/leftpad-lean", 47 | # # "https://github.com/FizzyElt/lean-pratice", 48 | # ], 49 | max_num_repos=args.max_num_repos, 50 | # max_num_repos=1, 51 | ld_max_num_procs=args.ld_max_num_procs, 52 | # ld_max_num_procs=100, 53 | ld_cache_dir=Path(args.cache_dir) / "ld_cache", 54 | large_dataset=args.large_dataset, 55 | large_dataset_test_val_percent=args.large_dataset_test_val_percent, 56 | medium_dataset=args.medium_dataset, 57 | medium_dataset_test_val_percent=args.medium_dataset_test_val_percent, 58 | small_dataset=args.small_dataset, 59 | small_dataset_test_val_percent=args.small_dataset_test_val_percent, 60 | test_val_min_size=args.test_val_min_size, 61 | ) 62 | 63 | data.fetch_repos_and_update_database() 64 | # data.build_lake() 65 | data.run_leandojo() 66 | 67 | data.report() 68 | -------------------------------------------------------------------------------- /lean_universe/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /lean_universe/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import math 9 | import os 10 | import socket 11 | import sys 12 | import time 13 | from datetime import timedelta 14 | from functools import cache 15 | 16 | 17 | @cache 18 | def get_is_dora_run() -> bool: 19 | # Set in apps.codegen_dora.compat.py 20 | return os.environ.get("DORA_FORCE_DISTRIB") is not None 21 | 22 | 23 | @cache 24 | def get_is_local_run() -> bool: 25 | return os.environ.get("LOCAL_RUN") is not None 26 | 27 | 28 | @cache 29 | def get_is_slurm_job() -> bool: 30 | return "SLURM_JOB_ID" in os.environ and not get_is_local_run() 31 | 32 | 33 | @cache 34 | def get_global_rank() -> int: 35 | if get_is_dora_run(): 36 | return int(os.environ["RANK"]) 37 | elif get_is_local_run(): 38 | return 0 39 | elif get_is_slurm_job(): 40 | return int(os.environ["SLURM_PROCID"]) 41 | else: 42 | return 0 43 | 44 | 45 | def log_host(): 46 | logger = logging.getLogger() 47 | logger.info(f"Host: {socket.gethostname()}") 48 | logger.info(f"Job hosts: {os.environ.get('SLURM_JOB_NODELIST', '')}") 49 | logger.info(f"Slurm job id: {int(os.environ.get('SLURM_JOB_ID', -1))}") 50 | 51 | 52 | class LogFormatter(logging.Formatter): 53 | def __init__(self): 54 | self.start_time = time.time() 55 | self.rank = get_global_rank() 56 | self.show_rank = not get_is_slurm_job() # srun has --label 57 | 58 | def format(self, record): 59 | # define prefix 60 | # record.pathname / record.filename / record.lineno 61 | subsecond, seconds = math.modf(record.created) 62 | curr_date = time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds)) + f".{int(subsecond * 1_000_000):06d}" 63 | delta = timedelta(seconds=round(record.created - self.start_time)) 64 | if self.show_rank: 65 | prefix = f"{self.rank}: {record.levelname:<7} {curr_date} - {delta} - " 66 | else: 67 | prefix = f"{record.levelname:<7} {curr_date} - {delta} - " 68 | 69 | # logged content 70 | content = record.getMessage() 71 | indent = " " * len(prefix) 72 | content = content.replace("\n", "\n" + indent) 73 | 74 | # Exception handling as in the default formatter, albeit with indenting 75 | # according to our custom prefix 76 | if record.exc_info and not record.exc_text: 77 | # Cache the traceback text to avoid converting it multiple times 78 | # (it's constant anyway) 79 | record.exc_text = self.formatException(record.exc_info) 80 | if record.exc_text: 81 | if content[-1:] != "\n": 82 | content = content + "\n" + indent 83 | content = content + indent.join([line + "\n" for line in record.exc_text.splitlines()]) 84 | if content[-1:] == "\n": 85 | content = content[:-1] 86 | if record.stack_info: 87 | if content[-1:] != "\n": 88 | content = content + "\n" + indent 89 | stack_text = self.formatStack(record.stack_info) 90 | content = content + indent.join([line + "\n" for line in stack_text.splitlines()]) 91 | if content[-1:] == "\n": 92 | content = content[:-1] 93 | 94 | return prefix + content 95 | 96 | 97 | def init_logger() -> logging.Logger: 98 | # log everything 99 | logger = logging.getLogger() 100 
| logger.setLevel(logging.NOTSET) 101 | 102 | # stdout: everything 103 | stdout_handler = logging.StreamHandler(sys.stdout) 104 | stdout_handler.setLevel(logging.NOTSET) 105 | stdout_handler.setFormatter(LogFormatter()) 106 | 107 | # stderr: warnings / errors and above 108 | stderr_handler = logging.StreamHandler(sys.stderr) 109 | stderr_handler.setLevel(logging.WARNING) 110 | stderr_handler.setFormatter(LogFormatter()) 111 | 112 | # set stream handlers 113 | logger.handlers.append(stdout_handler) 114 | logger.handlers.append(stderr_handler) 115 | 116 | # turn package loggers silent 117 | logging.getLogger("filelock").setLevel(logging.WARNING) 118 | 119 | return logger 120 | 121 | 122 | def add_logger_file_handler(filepath: str): 123 | # build file handler 124 | file_handler = logging.FileHandler(filepath, "a") 125 | file_handler.setLevel(logging.NOTSET) 126 | file_handler.setFormatter(LogFormatter()) 127 | 128 | # update logger 129 | logger = logging.getLogger() 130 | logger.addHandler(file_handler) 131 | -------------------------------------------------------------------------------- /lean_universe/utils/params.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # mypy: ignore-errors 8 | import abc 9 | import argparse 10 | import copy 11 | import dataclasses 12 | import json 13 | from argparse import ArgumentParser 14 | from dataclasses import asdict, dataclass, field, fields, is_dataclass 15 | from inspect import signature 16 | from logging import getLogger 17 | from typing import Any, Optional, TypeVar 18 | 19 | import typeguard 20 | 21 | FALSY_STRINGS = {"off", "false", "0"} 22 | TRUTHY_STRINGS = {"on", "true", "1"} 23 | 24 | SEPARATOR = "." 25 | 26 | 27 | logger = getLogger() 28 | 29 | 30 | class DeepCopyDict(dict): 31 | def __getitem__(self, item): 32 | res = super().__getitem__(item) 33 | if callable(res): 34 | return res() 35 | return copy.deepcopy(res) 36 | 37 | def __setitem__(self, key, value): 38 | if key in self: 39 | raise ValueError(f"{key} already in ConfStore") 40 | 41 | super().__setitem__(key, value) 42 | 43 | 44 | ConfStore = DeepCopyDict() 45 | 46 | 47 | MISSING: Any = "???" 48 | 49 | 50 | class NOTSET: 51 | pass 52 | 53 | 54 | class MissingArg(Exception): 55 | pass 56 | 57 | 58 | class WrongConfName(Exception): 59 | pass 60 | 61 | 62 | class WrongConfType(Exception): 63 | pass 64 | 65 | 66 | class WrongArgType(Exception): 67 | pass 68 | 69 | 70 | class WrongFieldType(Exception): 71 | pass 72 | 73 | 74 | class OptionalDataClass(Exception): 75 | pass 76 | 77 | 78 | class DefaultDataClassValue(Exception): 79 | pass 80 | 81 | 82 | def bool_flag(s): 83 | """ 84 | Parse boolean arguments from the command line. 
85 | """ 86 | if s.lower() in FALSY_STRINGS: 87 | return False 88 | elif s.lower() in TRUTHY_STRINGS: 89 | return True 90 | elif s == NOTSET: 91 | return MISSING 92 | else: 93 | raise argparse.ArgumentTypeError("Invalid value for a boolean flag!") 94 | 95 | 96 | def flatten_dict(to_flatten, prefix=""): 97 | flattened = {} 98 | for x, y in to_flatten.items(): 99 | if isinstance(y, dict): 100 | flattened.update(flatten_dict(y, prefix=f"{prefix}{x}{SEPARATOR}")) 101 | else: 102 | flattened[f"{prefix}{x}"] = y 103 | return flattened 104 | 105 | 106 | def is_optional(some_type): 107 | return some_type == Optional[some_type] 108 | 109 | 110 | def get_opt_type(some_type): 111 | if is_optional(some_type): 112 | # allowing Optional dataclasses now: we only recurse in the cli if the name appears in the args 113 | # if is_dataclass(some_type.__args__[0]): 114 | # raise OptionalDataClass("Optional dataclasses not supported") 115 | return some_type.__args__[0] 116 | return some_type 117 | 118 | 119 | def NOCLI(default: Any = None): 120 | return field(default=default, metadata={"NOCLI": True}) 121 | 122 | 123 | def is_nocli(some_field): 124 | return isinstance(some_field, dataclasses.Field) and some_field.metadata.get("NOCLI", False) is True 125 | 126 | 127 | def is_dict_type(ty): 128 | """ 129 | Return whether the argument type is Dict or dict, or an applied 130 | variant of these two (e.g., dict[str, int] or Dict[int, Any]). 131 | """ 132 | return ty is dict or getattr(ty, "__origin__", None) is dict 133 | 134 | 135 | @dataclass 136 | class Params(abc.ABC): 137 | def __new__(cls, *args, **kwargs): 138 | """ 139 | Verify that: 140 | 1) a Params dataclass does not have a field of type Params with a 141 | default value (as it would result in silent configs bugs) 142 | 2) default_factory arguments build dataclasses / objects of expected types 143 | """ 144 | for field in fields(cls): 145 | ftype = get_opt_type(field.type) 146 | if is_dict_type(ftype): 147 | kt = getattr(ftype, "__args__", (str,))[0] 148 | if kt is not str: 149 | raise RuntimeError( 150 | f"Dictionary field {field.name} of class {cls.__name__} " 151 | f"must have 'str' as its key type, got {kt}." 152 | ) 153 | if not is_nocli(field): 154 | raise RuntimeError( 155 | f"Dictionary field {field.name} of class {cls.__name__} " f"must be marked as NOCLI." 156 | ) 157 | if is_dataclass(ftype) and isinstance(field.default, Params): 158 | raise DefaultDataClassValue( 159 | f"Field {field.name} of class {cls.__name__} is set to a shared " 160 | f"default config. Setting a shared default value is dangerous " 161 | f"and can lead to unexpected changes across configs. " 162 | f"Use `default_factory` instead." 163 | ) 164 | if field.default_factory != dataclasses.MISSING: 165 | try: 166 | value = field.default_factory() 167 | if "argname" in signature(typeguard.check_type).parameters: 168 | # typeguard < 3 169 | typeguard.check_type(field.name, value, field.type) # type: ignore 170 | else: 171 | # typeguard >= 3 172 | typeguard.check_type(value, field.type) # type: ignore 173 | except TypeError: 174 | raise DefaultDataClassValue( 175 | f"`default_factory` for field {field.name} of class {cls.__name__} " 176 | f"should build objects of type {field.type.__name__}." 
177 | ) 178 | return super().__new__(cls) 179 | 180 | @classmethod 181 | def to_cli(cls, prefix: str = "", parser: Optional[ArgumentParser] = None) -> ArgumentParser: 182 | # initialize parser 183 | if parser is None: 184 | parser = ArgumentParser(allow_abbrev=False) 185 | parser.add_argument("--cfg", type=str) 186 | parser.add_argument("--json", type=str) 187 | 188 | # for each field in the dataclass 189 | for field in fields(cls): 190 | # sanity check / get field name / field type 191 | if prefix == "": 192 | assert field.name != "cfg", "'cfg' field is reserved for cli parser" 193 | assert field.name != "json", "'json' field is reserved for cli parser" 194 | fullname = f"{prefix}{field.name}" 195 | field_type = get_opt_type(field.type) 196 | 197 | # this field is not allowed in the CLI -> nothing to do 198 | if is_nocli(field): 199 | pass 200 | 201 | # dataclass -> recursively add arguments to the CLI 202 | elif is_dataclass(field_type): 203 | field_type.to_cli(prefix=f"{fullname}.", parser=parser) # type: ignore[attr-defined] 204 | parser.add_argument(f"--{fullname}", type=str, default=NOTSET) 205 | 206 | # standard parameter 207 | else: 208 | parser.add_argument( 209 | f"--{fullname}", 210 | type=bool_flag if field_type == bool else field_type, 211 | help="" if field.metadata is None else field.metadata.get("help"), 212 | default=NOTSET, 213 | ) 214 | 215 | return parser 216 | 217 | @classmethod 218 | def from_cli( 219 | cls, 220 | param_dict: dict[str, Any], 221 | prefix: str = "", 222 | default_instance: Optional["Params"] = None, 223 | allow_incomplete: Optional[bool] = False, 224 | ): 225 | """ 226 | Converts a flat dot separated dictionary into a class object. 227 | The dictionary must come from a CLI (`parse_args`) output directly. 228 | """ 229 | assert default_instance is None or isinstance(default_instance, cls) 230 | 231 | kwargs = {} 232 | 233 | # to crash/or do some warning if some unused args 234 | used_args: set[str] = {"cfg", "json"} 235 | 236 | for field in fields(cls): 237 | fullname = f"{prefix}{field.name}" 238 | field_type = get_opt_type(field.type) 239 | assert allow_incomplete or (fullname not in param_dict) == is_nocli(field), f"Error with field {fullname}" 240 | cli_value = NOTSET 241 | if fullname in param_dict: 242 | cli_value = param_dict[fullname] 243 | used_args |= {k for k in param_dict if k.startswith(fullname)} 244 | 245 | # default value is field is not set in the CLI 246 | default_value = _get_default_value(field, default_instance) 247 | 248 | # this field should not be in the CLI, so we use the default value 249 | if is_nocli(field): 250 | kwargs[field.name] = default_value 251 | 252 | # this is a dataclass 253 | elif is_dataclass(field_type): 254 | # If optional, default value must be None 255 | # In this case, we only recurse if some params start with our full name 256 | must_recurse = any([x.startswith(fullname) and y != NOTSET for x, y in param_dict.items()]) 257 | if not must_recurse and is_optional(field.type): 258 | # Either default_value is None and nothing else was set in the CLI 259 | # Or it might be set in the default instance which is captured by default_value 260 | kwargs[field.name] = default_value 261 | else: 262 | # dataclass no specified, use a default value 263 | if cli_value == NOTSET: 264 | sub_conf = None if default_value == MISSING else default_value 265 | 266 | # dataclass specified using a named config 267 | elif isinstance(cli_value, str) and cli_value != MISSING: 268 | if cli_value not in ConfStore: 269 | raise 
WrongConfName(f"Unknown conf key {cli_value} for field {fullname}") 270 | sub_conf = ConfStore[cli_value] 271 | 272 | # unexpected type (should not be reachable if `param_dict` is the 273 | # output of a CLI, as the argument should be of type `str`). 274 | else: 275 | raise WrongArgType(f"Value for {fullname} should be a string!") 276 | 277 | # check that the current config has a correct type 278 | if sub_conf is not None and not isinstance(sub_conf, field_type): 279 | raise WrongConfType( 280 | f"Invalid configuration. Provided a configuration of type " 281 | f'"{type(sub_conf).__name__}", expected "{field_type.__name__}".' 282 | ) 283 | 284 | # if we specified a.b = some_name and a.b.c = 5, we want to overwrite some_name.c 285 | kwargs[field.name] = field_type.from_cli( # type: ignore[attr-defined] 286 | param_dict=param_dict, 287 | prefix=f"{fullname}.", 288 | default_instance=sub_conf, 289 | ) 290 | 291 | # this is not a dataclass, with a CLI provided value. try to parse it 292 | elif cli_value != NOTSET: 293 | try: 294 | kwargs[field.name] = field_type(cli_value) 295 | except ValueError as e: 296 | raise WrongArgType(e) 297 | 298 | # this is not a dataclass, and it does not appear in the CLI 299 | else: 300 | kwargs[field.name] = default_value 301 | 302 | # all arguments should be used and not MISSING 303 | unused_args = {k: v for k, v in param_dict.items() if (k not in used_args) and k.startswith(prefix)} 304 | if unused_args: 305 | raise RuntimeError(f"Some fields in from_cli are unused to instanciate {cls}: {unused_args}") 306 | for x, y in kwargs.items(): 307 | if y == MISSING: 308 | raise MissingArg(f"Parameter {prefix}{x} is MISSING") 309 | 310 | return cls(**kwargs) 311 | 312 | @classmethod 313 | def from_flat( 314 | cls, 315 | flat: dict[str, Any], 316 | prefix: str = "", 317 | default_instance: Optional["Params"] = None, 318 | unused_warning: bool = True, 319 | ): 320 | """ 321 | Converts a flat dot separated dictionary into a class object. 322 | The dictionary must come from a flattened object from the same class, 323 | and not a CLI. 324 | 325 | NOTE: some issues will happen if a class field is a Union of dataclasses (e.g. DatasetConf), 326 | as `field_type.from_flat` will not know which class to reload (e.g. 
MetamathDatasetConf, or 327 | HolLightDatasetConf in ZMQProverParams.dataset) 328 | """ 329 | kwargs: dict[str, Any] = {} 330 | 331 | # to crash/or do some warning if some unused args 332 | used_args: set[str] = set() 333 | 334 | for field in fields(cls): 335 | fullname = f"{prefix}{field.name}" 336 | field_type = get_opt_type(field.type) 337 | 338 | default_value = _get_default_value(field, default_instance) 339 | 340 | if is_dataclass(field_type): 341 | sub_conf = None if default_value == MISSING else default_value 342 | 343 | if is_optional(field.type) and fullname in flat: 344 | assert flat[fullname] is None 345 | assert len({k for k in flat if k.startswith(f"{fullname}{SEPARATOR}")}) == 0 346 | kwargs[field.name] = None 347 | else: 348 | kwargs[field.name] = field_type.from_flat( # type: ignore[attr-defined] 349 | flat, prefix=f"{fullname}{SEPARATOR}", default_instance=sub_conf 350 | ) 351 | used_args |= {k for k in flat if k.startswith(f"{fullname}{SEPARATOR}")} 352 | elif is_dict_type(field_type): 353 | prefix = f"{fullname}{SEPARATOR}" 354 | 355 | dictparam = {} 356 | for key, value in flat.items(): 357 | if key.startswith(prefix): 358 | used_args.add(key) 359 | dictparam[key[len(prefix) :]] = value 360 | 361 | if not dictparam: 362 | if default_value == MISSING: 363 | if not is_optional(field_type): 364 | raise MissingArg(f"Parameter {fullname} unspecified and has no default") 365 | else: 366 | dictparam = None # type: ignore 367 | else: 368 | dictparam = default_value 369 | 370 | kwargs[field.name] = dictparam 371 | else: 372 | try: 373 | kwargs[field.name] = flat[fullname] 374 | used_args.add(fullname) 375 | except KeyError: 376 | # We only crash if no default is provided 377 | if default_value == MISSING: 378 | raise MissingArg(f"Parameter {fullname} unspecified and has no default") 379 | else: 380 | kwargs[field.name] = default_value 381 | 382 | unused_args = {k: v for k, v in flat.items() if (k not in used_args) and k.startswith(prefix)} 383 | if unused_warning and unused_args: 384 | logger.warning(f"Some fields in flat_dict are unused to instantiate {cls}: {unused_args}") 385 | for x, y in kwargs.items(): 386 | if y == MISSING: 387 | raise MissingArg(f"Parameter {prefix}{x} is MISSING") 388 | return cls(**kwargs) 389 | 390 | @classmethod 391 | def from_dict(cls: type["Params"], src: dict): 392 | """ 393 | @param src: a nested (or flat) dictionary as exported by dataclass.asdict() 394 | @return: a cls object with defaults filled in 395 | """ 396 | flat_dict = flatten_dict(src) 397 | return cls.from_flat(flat_dict) 398 | 399 | @classmethod 400 | def from_json(cls: type["Params"], s: str): 401 | return cls.from_dict(json.loads(s)) 402 | 403 | def to_dict(self) -> dict[str, Any]: 404 | return asdict(self) 405 | 406 | def to_flat(self): 407 | return flatten_dict(asdict(self)) 408 | 409 | def to_json(self: "Params"): 410 | return json.dumps(asdict(self), sort_keys=True, indent=4) 411 | 412 | def check_and_mutate_args(self: "Params", avoid: Optional[set[type]] = None): 413 | self._check_and_mutate_args() 414 | for field in fields(self): 415 | attr = getattr(self, field.name) 416 | if isinstance(attr, Params) and (avoid is None or not type(attr) in avoid): 417 | attr.check_and_mutate_args(avoid=avoid) 418 | elif avoid is not None and type(attr) in avoid: 419 | print(f"Not checking {field.name} of type {type(attr)}") 420 | 421 | def get_missing(self) -> set[str]: 422 | def traverse(x: "Params", prefix: str, missing: set[str]): 423 | for field in fields(x): 424 | value = getattr(x, 
field.name) 425 | if value == MISSING: 426 | missing.add(f"{prefix}{field.name}") 427 | elif is_dataclass(value): 428 | traverse(value, f"{prefix}{field.name}.", missing) # type: ignore[attr-defined] 429 | return missing 430 | 431 | return traverse(self, "", set()) 432 | 433 | def has_missing(self) -> bool: 434 | for field in fields(self): 435 | ftype = get_opt_type(field.type) 436 | value = getattr(self, field.name) 437 | if value == MISSING: 438 | return True 439 | elif is_dataclass(ftype) and value.has_missing(): 440 | return True 441 | return False 442 | 443 | def check_type(self, prefix: str = "") -> None: 444 | for field in fields(self): 445 | ftype = get_opt_type(field.type) 446 | value = getattr(self, field.name) 447 | fullname = f"{prefix}{field.name}" 448 | try: 449 | if "argname" in signature(typeguard.check_type).parameters: 450 | # typeguard < 3 451 | typeguard.check_type(field.name, value, field.type) # type: ignore 452 | else: 453 | # typeguard >= 3 454 | typeguard.check_type(value, field.type) # type: ignore 455 | except TypeError: 456 | raise WrongFieldType(f"Wrong field type for {fullname}: {value}") 457 | if is_dataclass(ftype) and value is not None: 458 | value.check_type(f"{fullname}.") 459 | 460 | def _check_and_mutate_args(self): 461 | pass 462 | 463 | 464 | def _get_default_value(field_: dataclasses.Field, default_instance: Optional["Params"]): 465 | """We check first for default_instance. 466 | If not there we check for field.default_factory() then for field.default 467 | If the default_value is dataclasses.MISSING we normalize it to our MISSING 468 | """ 469 | default_value = ( 470 | (field_.default_factory() if field_.default_factory != dataclasses.MISSING else field_.default) 471 | if default_instance is None 472 | else getattr(default_instance, field_.name) 473 | ) 474 | if default_value == dataclasses.MISSING: 475 | # we normalize to MISSING 476 | return MISSING 477 | return default_value 478 | 479 | 480 | T = TypeVar("T", bound=Params) 481 | 482 | 483 | def cfg_from_cli( 484 | base_config: Optional[T] = None, 485 | schema: Optional[type[T]] = None, 486 | args: Optional[list[str]] = None, 487 | ) -> T: 488 | assert (base_config is None) == (schema is not None) 489 | if base_config is not None: 490 | schema = base_config.__class__ 491 | 492 | assert schema is not None 493 | param_dict = vars(schema.to_cli().parse_args(args)) 494 | if param_dict["cfg"] is not None: 495 | assert param_dict["json"] is None, "Can't specify both --cfg and --json" 496 | base_config = ConfStore[param_dict.pop("cfg")] 497 | assert isinstance(base_config, schema) 498 | elif param_dict["json"] is not None: 499 | with open(param_dict.pop("json")) as f: 500 | json_conf = f.read() 501 | base_config = schema.from_json(json_conf) 502 | ret = schema.from_cli(param_dict=param_dict, default_instance=base_config) 503 | assert isinstance(ret, schema) 504 | return ret 505 | -------------------------------------------------------------------------------- /lean_universe/utils/tools.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
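# A minimal usage sketch of the Params / cfg_from_cli machinery defined in
# lean_universe/utils/params.py above. DemoArgs and its field values are made
# up for illustration; EvalArgs in lean_universe/dataset/args.py is the real
# schema used by lean_universe/dataset/run.py:
#
#     from dataclasses import dataclass
#
#     from lean_universe.utils.params import ConfStore, Params, cfg_from_cli
#
#     @dataclass
#     class DemoArgs(Params):
#         query: str = "lean"
#         max_num_repos: int = -1
#
#     # Named base configs can be registered once and selected with --cfg.
#     ConfStore["demo_small"] = DemoArgs(max_num_repos=5)
#
#     # Normally parsed from sys.argv; an explicit argument list is shown here.
#     args = cfg_from_cli(schema=DemoArgs, args=["--cfg", "demo_small", "--query", "lean4"])
#     assert args.max_num_repos == 5 and args.query == "lean4"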
6 | 7 | 8 | import os 9 | import time 10 | from collections.abc import Generator 11 | from contextlib import contextmanager 12 | from functools import cache 13 | from logging import getLogger 14 | from pathlib import Path 15 | from subprocess import PIPE, CalledProcessError, Popen # nosec 16 | 17 | from github import Auth, Github 18 | from github.Repository import Repository 19 | 20 | logger = getLogger() 21 | 22 | 23 | # @cache 24 | def get_github() -> Github: 25 | """ 26 | Returns a Github object with authentication if the GITHUB_ACCESS_TOKEN 27 | environment variable is set. 28 | """ 29 | 30 | # This provides increased rate limit for GitHub API calls. 31 | access_token = os.getenv("GITHUB_ACCESS_TOKEN", None) 32 | if access_token: 33 | logger.info("Using GitHub personal access token for authentication") 34 | github = Github(auth=Auth.Token(access_token)) 35 | github.get_user().login 36 | return github 37 | else: 38 | logger.info("Using GitHub without authentication. Don't be surprised if you hit the API rate limit.") 39 | return Github() 40 | 41 | 42 | def normalize_url(url: str) -> str: 43 | """ 44 | Normalize the given URL by removing any trailing slashes. 45 | Args: 46 | url (str): The URL to be normalized. 47 | Returns: 48 | str: The normalized URL. 49 | """ 50 | 51 | return url.rstrip("/") 52 | 53 | 54 | @cache 55 | def url_to_repo(GITHUB: Github, url: str, num_retries: int = 2) -> Repository: 56 | """ 57 | Retrieves a GitHub repository object based on the given URL. 58 | Args: 59 | url (str): The URL of the repository. 60 | num_retries (int, optional): The number of retries in case of failure. Defaults to 2. 61 | Returns: 62 | Repository: The GitHub repository object. 63 | Raises: 64 | Exception: If the retrieval fails after the specified number of retries. 65 | """ 66 | logger.debug(f"Fetching repository for {url}") 67 | url = normalize_url(url) 68 | backoff = 1 69 | 70 | while True: 71 | try: 72 | return GITHUB.get_repo("/".join(url.split("/")[-2:])) 73 | except Exception as ex: 74 | if num_retries <= 0: 75 | logger.error(f"url_to_repo({url}) failed after {num_retries} retries.") 76 | raise ex 77 | num_retries -= 1 78 | logger.debug(f'url_to_repo("{url}") failed. Retrying...') 79 | time.sleep(backoff) 80 | backoff 81 | 82 | 83 | @contextmanager 84 | def change_directory(new_dir: str) -> Generator[None, None, None]: 85 | """ 86 | Change the current working directory to the specified directory temporarily. 87 | Args: 88 | new_dir (str): The path of the directory to change to. 89 | Yields: 90 | None 91 | Raises: 92 | OSError: If the specified directory does not exist or cannot be accessed. 93 | Example: 94 | >>> with change_directory('/path/to/new_directory'): 95 | ... # Code to be executed in the new directory 96 | ... 97 | """ 98 | 99 | current_dir = Path.cwd() 100 | try: 101 | # Change to the new directory 102 | os.chdir(new_dir) 103 | yield 104 | finally: 105 | # Restore the original directory 106 | os.chdir(current_dir) 107 | 108 | 109 | def execute_and_capture(command: str, directory: Path) -> tuple[str, str]: 110 | """ 111 | Executes a shell command in the specified directory and captures the stdout and stderr outputs. 112 | Args: 113 | command (str): The shell command to execute. 114 | directory (Path): The directory in which to execute the command. 115 | Returns: 116 | tuple[str, str]: A tuple containing the stdout and stderr outputs of the command. 
117 | """ 118 | 119 | # Use the custom context manager to handle directory changes 120 | with change_directory(directory.as_posix()), Popen( 121 | command, shell=True, stdout=PIPE, stderr=PIPE, text=True # nosec 122 | ) as process: 123 | if process.stdout: 124 | output = process.stdout.read().strip() 125 | for line in output.splitlines(): 126 | logger.info(line.strip()) 127 | process.wait() 128 | if process.returncode != 0 and process.stderr: 129 | error_message = process.stderr.read().strip() 130 | logger.error(f"Command '{command}' failed with exit code {error_message}.") 131 | raise CalledProcessError(process.returncode, command, error_message) 132 | return output, process.stderr.read().strip() if process.stderr else "" 133 | 134 | 135 | def clone_and_checkout(repo_url: str, repo_prefix: str, repo_commit: str, dir: Path, exept: bool = True) -> None: 136 | logger.info(f"{repo_prefix} Cloning repo to {dir.as_posix()}.") 137 | clone_command = f"git clone -n --recursive {repo_url} {dir.name}" 138 | try: 139 | execute_and_capture(clone_command, dir.parent) 140 | except CalledProcessError as ex: 141 | logger.error(f"{repo_prefix} Failed to clone.") 142 | if exept: 143 | raise ex 144 | checkout_command = f"git checkout {repo_commit} && git submodule update --recursive" 145 | try: 146 | execute_and_capture(checkout_command, dir) 147 | except CalledProcessError as ex: 148 | logger.error(f"{repo_prefix} Failed to checkout at {repo_commit}.") 149 | if exept: 150 | raise ex 151 | logger.info(f"{repo_prefix} Cloned to {dir}") 152 | 153 | 154 | def reset_and_pull(repo_url: str, repo_prefix: str, repo_commit: str, dir: Path, exept: bool = True) -> None: 155 | logger.info(f"{repo_prefix} Resetting and pulling repo at {dir.as_posix()}.") 156 | pull_command = f"git checkout {repo_commit} && git pull && git submodule update --recursive" 157 | try: 158 | execute_and_capture(pull_command, dir) 159 | except CalledProcessError as ex: 160 | logger.error(f"{repo_prefix} Failed to clone.") 161 | if exept: 162 | raise ex 163 | logger.info(f"{repo_prefix} Resetted and pulled. {dir}") 164 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | [tool.poetry] 8 | name = "lean_universe" 9 | version = "0.1.0" 10 | description = "LeanUniverse: A Library for Consistent and Scalable Lean4 Dataset Management." 11 | authors = ["Aram H. Markosyan "] 12 | readme = "README.md" 13 | 14 | [tool.poetry.dependencies] 15 | python = "~3.11" 16 | pre-commit = "^3.8.0" 17 | argparse = "^1.4.0" 18 | typeguard = "^4.3.0" 19 | psutil = "^6.0.0" 20 | tqdm = "^4.66.5" 21 | pygithub = "^2.4.0" 22 | lean-dojo = {git = "https://github.com/lean-dojo/LeanDojo.git"} 23 | # lean-dojo = {path = "../LeanDojo", develop = true} 24 | 25 | 26 | [build-system] 27 | requires = ["poetry-core"] 28 | build-backend = "poetry.core.masonry.api" 29 | --------------------------------------------------------------------------------
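As a quick orientation, the snippet below is a minimal end-to-end sketch of driving the pipeline with LeanUniverseData directly. The paths, timestamp and size thresholds are illustrative placeholders; in practice lean_universe/dataset/run.py builds the same object from the EvalArgs CLI schema.

from pathlib import Path

from lean_universe.dataset.data import LeanUniverseData

work = Path("/tmp/lean_universe")  # illustrative working directory
data = LeanUniverseData(
    cache_dir=work / "cache",
    log_file=work / "run.log",
    dataset_export_dir=work / "export",
    raw_dataset_dir=work / "raw",
    repos_dir=work / "repos",
    timestamp="2024-01-01_00-00-00",  # illustrative run identifier
    ld_cache_dir=work / "ld_cache",
    max_num_repos=2,  # keep the GitHub crawl small for a smoke test
    large_dataset=10_000,
    large_dataset_test_val_percent=2,
    medium_dataset=1_000,
    medium_dataset_test_val_percent=4,
    small_dataset=100,
    small_dataset_test_val_percent=8,
    test_val_min_size=10,
)
data.fetch_repos_and_update_database()  # search GitHub, filter, clone/pull
data.build_lake()  # optional: run.py currently skips this step
data.run_leandojo()  # trace with LeanDojo, then split and export
data.report()  # write database.json under cache_dir

Each exported repository ends up under dataset_export_dir/<owner_repo>/<sha>/ with random/ and novel_premises/ splits (train.json, val.json, test.json), a corpus.jsonl premise file, a licenses/ directory and a metadata.json, as produced by Dataset.export_data.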