├── .gitattributes ├── .github └── workflows │ └── codeql.yml ├── .gitignore ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── LICENSE ├── RAI_Transparency_Information.md ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── azure-pipelines.yml ├── datasets ├── DotPrompts │ └── dataset.csv └── PragmaticCode │ ├── fileContentsByRepo.json │ └── repos.csv ├── evaluation_scripts ├── eval_results.py └── eval_utils.py ├── figures ├── figure_2b.png └── motivating_example.png ├── inference_results ├── README.md ├── dotprompts_results.csv └── dotprompts_results_sample.csv ├── pyproject.toml ├── requirements.txt ├── results ├── Report.md ├── all_metrics_table.csv ├── all_metrics_table.md └── figures │ ├── fig_method_dist_by_max_identifier_complexity_dist.png │ ├── rel_changes │ ├── fig_CR_fim_heatmap.png │ ├── fig_CR_models_heatmap.png │ ├── fig_CR_prompt_heatmap.png │ ├── fig_ISM_fim_heatmap.png │ ├── fig_ISM_models_heatmap.png │ ├── fig_ISM_prompt_heatmap.png │ ├── fig_NIM_fim_heatmap.png │ ├── fig_NIM_models_heatmap.png │ ├── fig_NIM_prompt_heatmap.png │ ├── fig_PM_fim_heatmap.png │ ├── fig_PM_models_heatmap.png │ ├── fig_PM_prompt_heatmap.png │ ├── fig_next_identifier_match_fim_id_complexity_heatmap.png │ ├── fig_next_identifier_match_models_id_complexity_heatmap.png │ └── fig_next_identifier_match_prompt_id_complexity_heatmap.png │ └── score_at_k │ ├── fig_CR_fim.png │ ├── fig_CR_models.png │ ├── fig_CR_prompt.png │ ├── fig_ISM_fim.png │ ├── fig_ISM_models.png │ ├── fig_ISM_prompt.png │ ├── fig_NIM_fim.png │ ├── fig_NIM_models.png │ ├── fig_NIM_prompt.png │ ├── fig_PM_fim.png │ ├── fig_PM_models.png │ ├── fig_PM_prompt.png │ ├── fig_next_identifier_match_fim_id_complexity.png │ ├── fig_next_identifier_match_models_id_complexity.png │ └── fig_next_identifier_match_prompt_id_complexity.png ├── results_sample ├── Report.md ├── all_metrics_table.csv ├── all_metrics_table.md └── figures │ ├── fig_method_dist_by_max_identifier_complexity_dist.png │ ├── rel_changes │ ├── fig_CR_fim_heatmap.png │ ├── fig_CR_models_heatmap.png │ ├── fig_CR_prompt_heatmap.png │ ├── fig_ISM_fim_heatmap.png │ ├── fig_ISM_models_heatmap.png │ ├── fig_ISM_prompt_heatmap.png │ ├── fig_NIM_fim_heatmap.png │ ├── fig_NIM_models_heatmap.png │ ├── fig_NIM_prompt_heatmap.png │ ├── fig_PM_fim_heatmap.png │ ├── fig_PM_models_heatmap.png │ ├── fig_PM_prompt_heatmap.png │ ├── fig_next_identifier_match_fim_id_complexity_heatmap.png │ ├── fig_next_identifier_match_models_id_complexity_heatmap.png │ └── fig_next_identifier_match_prompt_id_complexity_heatmap.png │ └── score_at_k │ ├── fig_CR_fim.png │ ├── fig_CR_models.png │ ├── fig_CR_prompt.png │ ├── fig_ISM_fim.png │ ├── fig_ISM_models.png │ ├── fig_ISM_prompt.png │ ├── fig_NIM_fim.png │ ├── fig_NIM_models.png │ ├── fig_NIM_prompt.png │ ├── fig_PM_fim.png │ ├── fig_PM_models.png │ ├── fig_PM_prompt.png │ ├── fig_next_identifier_match_fim_id_complexity.png │ ├── fig_next_identifier_match_models_id_complexity.png │ └── fig_next_identifier_match_prompt_id_complexity.png ├── src └── monitors4codegen │ ├── __init__.py │ ├── monitor_guided_decoding │ ├── hf_gen.py │ ├── mgd_utils.py │ ├── monitor.py │ ├── monitors │ │ ├── class_instantiation_monitor.py │ │ ├── dereferences_monitor.py │ │ ├── numargs_monitor.py │ │ └── switch_enum_monitor.py │ ├── openai_gen.py │ └── tokenizer_wrapper.py │ └── multilspy │ ├── __init__.py │ ├── language_server.py │ ├── language_servers │ ├── eclipse_jdtls │ │ ├── eclipse_jdtls.py │ │ ├── initialize_params.json │ │ └── runtime_dependencies.json │ ├── jedi_language_server │ 
│ ├── initialize_params.json │ │ └── jedi_server.py │ ├── omnisharp │ │ ├── initialize_params.json │ │ ├── omnisharp.py │ │ ├── runtime_dependencies.json │ │ └── workspace_did_change_configuration.json │ └── rust_analyzer │ │ ├── initialize_params.json │ │ ├── runtime_dependencies.json │ │ └── rust_analyzer.py │ ├── lsp_protocol_handler │ ├── lsp_constants.py │ ├── lsp_requests.py │ ├── lsp_types.py │ └── server.py │ ├── multilspy_config.py │ ├── multilspy_exceptions.py │ ├── multilspy_logger.py │ ├── multilspy_settings.py │ ├── multilspy_types.py │ ├── multilspy_utils.py │ └── type_helpers.py └── tests ├── monitor_guided_decoding ├── test_classinstantiation_monitor_java.py ├── test_dereferences_monitor_java.py ├── test_dereferences_monitor_java_openai.py ├── test_joint_monitors.py ├── test_numargs_monitor_java.py ├── test_switchenum_monitor_csharp.py └── test_typestate_monitor_rust.py ├── multilspy ├── multilspy_context.py ├── test_multilspy_csharp.py ├── test_multilspy_java.py ├── test_multilspy_python.py ├── test_multilspy_rust.py ├── test_sync_multilspy_csharp.py ├── test_sync_multilspy_java.py ├── test_sync_multilspy_python.py └── test_sync_multilspy_rust.py ├── pytest.ini └── test_utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | inference_results/dotprompts_results.csv filter=lfs diff=lfs merge=lfs -text 2 | inference_results/dotprompts_results_sample.csv filter=lfs diff=lfs merge=lfs -text 3 | inference_results/ filter=lfs diff=lfs merge=lfs -text 4 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "main" ] 20 | schedule: 21 | - cron: '22 13 * * 2' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | # Runner size impacts CodeQL analysis time. To learn more, please see: 27 | # - https://gh.io/recommended-hardware-resources-for-running-codeql 28 | # - https://gh.io/supported-runners-and-hardware-resources 29 | # - https://gh.io/using-larger-runners 30 | # Consider using larger runners for possible analysis time improvements. 
31 | runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }} 32 | timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }} 33 | permissions: 34 | actions: read 35 | contents: read 36 | security-events: write 37 | 38 | strategy: 39 | fail-fast: false 40 | matrix: 41 | language: [ 'python' ] 42 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby', 'swift' ] 43 | # Use only 'java' to analyze code written in Java, Kotlin or both 44 | # Use only 'javascript' to analyze code written in JavaScript, TypeScript or both 45 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 46 | 47 | steps: 48 | - name: Checkout repository 49 | uses: actions/checkout@v3 50 | 51 | # Initializes the CodeQL tools for scanning. 52 | - name: Initialize CodeQL 53 | uses: github/codeql-action/init@v2 54 | with: 55 | languages: ${{ matrix.language }} 56 | # If you wish to specify custom queries, you can do so here or in a config file. 57 | # By default, queries listed here will override any specified in a config file. 58 | # Prefix the list here with "+" to use these queries and those in the config file. 59 | 60 | # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 61 | # queries: security-extended,security-and-quality 62 | 63 | 64 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, Java, or Swift). 65 | # If this step fails, then you should remove it and run the build manually (see below) 66 | - name: Autobuild 67 | uses: github/codeql-action/autobuild@v2 68 | 69 | # ℹ️ Command-line programs to run using the OS shell. 70 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 71 | 72 | # If the Autobuild fails above, remove it and uncomment the following three lines, 73 | # modifying them (or adding more) as needed to build your code; please refer to the EXAMPLE below for guidance. 74 | 75 | # - run: | 76 | # echo "Run, Build Application using script" 77 | # ./location_of_script_within_repo/buildscript.sh 78 | 79 | - name: Perform CodeQL Analysis 80 | uses: github/codeql-action/analyze@v2 81 | with: 82 | category: "/language:${{matrix.language}}" 83 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | .vscode/ 162 | 163 | src/monitors4codegen/multilspy/language_servers/*/static 164 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # This CITATION.cff file was generated with cffinit. 2 | # Visit https://bit.ly/cffinit to generate yours today! 3 | 4 | cff-version: 1.2.0 5 | title: >- 6 | Monitor-Guided Decoding of Code LMs with Static Analysis 7 | of Repository Context 8 | message: >- 9 | If you use this repository, please cite it using the metadata 10 | from this file. 11 | type: dataset 12 | authors: 13 | - given-names: Lakshya A 14 | family-names: Agrawal 15 | email: t-lakagrawal@microsoft.com 16 | affiliation: Microsoft Research 17 | orcid: 'https://orcid.org/0000-0003-0409-8212' 18 | - given-names: Aditya 19 | family-names: Kanade 20 | email: kanadeaditya@microsoft.com 21 | affiliation: Microsoft Research 22 | - given-names: Navin 23 | family-names: Goyal 24 | email: navingo@microsoft.com 25 | affiliation: Microsoft Research 26 | - given-names: Shuvendu K. 27 | family-names: Lahiri 28 | email: shuvendu.lahiri@microsoft.com 29 | affiliation: Microsoft Research 30 | - given-names: Sriram K. 31 | family-names: Rajamani 32 | email: sriram@microsoft.com 33 | affiliation: Microsoft Research 34 | identifiers: 35 | - type: doi 36 | value: 10.48550/arXiv.2306.10763 37 | - type: url 38 | value: >- 39 | https://openreview.net/forum?id=qPUbKxKvXq&noteId=98Ukj82fSP 40 | abstract: >- 41 | Language models of code (LMs) work well when the 42 | surrounding code provides sufficient context. This is not 43 | true when it becomes necessary to use types, functionality 44 | or APIs defined elsewhere in the repository or a linked 45 | library, especially those not seen during training. LMs 46 | suffer from limited awareness of such global context and 47 | end up hallucinating. 48 | 49 | 50 | Integrated development environments (IDEs) assist 51 | developers in understanding repository context using 52 | static analysis. We extend this assistance, enjoyed by 53 | developers, to LMs. We propose monitor-guided decoding 54 | (MGD) where a monitor uses static analysis to guide the 55 | decoding. We construct a repository-level dataset 56 | PragmaticCode for method-completion in Java and evaluate 57 | MGD on it. On models of varying parameter scale, by 58 | monitoring for type-consistent object dereferences, MGD 59 | consistently improves compilation rates and agreement with 60 | ground truth. Further, LMs with fewer parameters, when 61 | augmented with MGD, can outperform larger LMs. With MGD, 62 | SantaCoder-1.1B achieves better compilation rate and 63 | next-identifier match than the much larger 64 | text-davinci-003 model. 65 | 66 | 67 | We also conduct a generalizability study to evaluate the 68 | ability of MGD to generalize to multiple programming 69 | languages (Java, C# and Rust), coding scenarios (e.g., 70 | correct number of arguments to method calls), and to 71 | enforce richer semantic constraints (e.g., stateful API 72 | protocols). Our data and implementation are available at 73 | https://github.com/microsoft/monitors4codegen.
74 | keywords: 75 | - program analysis 76 | - correctness 77 | - code generation 78 | - Language models 79 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /RAI_Transparency_Information.md: -------------------------------------------------------------------------------- 1 | # Responsible AI Transparency Information 2 | ## What is Monitor-Guided Decoding (MGD)? 3 | Monitor-Guided Decoding (MGD) is a tool that helps Language Models (LMs) generate more reliable code. It combines token-by-token LM decoding with program analysis techniques (methods that check the syntax, semantics, and logic of code, such as those used in Integrated Development Environments). Under MGD, a software component called a monitor runs concurrently with the decoder and iteratively uses the results of continuous program analysis to prevent the generation of potentially problematic tokens, such as identifiers that are inconsistent with the type definitions. For example, a type analysis is performed at identifier dereferences to find the set of type-correct symbols and prevent the generation of type-invalid ones, thus producing code free from a large class of compilation errors. 4 | 5 | The static analysis in MGD is powered by Language Servers served over the Language Server Protocol.
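To make the decoder-monitor interaction concrete, here is a minimal sketch of a monitor-guided greedy decoding loop. It is illustrative only: the `monitor.allowed_token_ids` interface and the greedy setup are our assumptions, not the actual API of this repository (see `src/monitors4codegen/monitor_guided_decoding/hf_gen.py` and `monitor.py` for the real implementation).

```python
import torch

def mgd_greedy_decode(model, tokenizer, prompt_ids, monitor, max_new_tokens=64):
    """Greedy decoding in which a monitor, backed by static analysis,
    may restrict each step's vocabulary to type-consistent tokens."""
    ids = prompt_ids  # shape [1, seq_len]
    for _ in range(max_new_tokens):
        logits = model(ids).logits[0, -1]                    # next-token scores
        allowed = monitor.allowed_token_ids(tokenizer, ids)  # None => unconstrained
        if allowed is not None:
            mask = torch.full_like(logits, float("-inf"))
            mask[list(allowed)] = 0.0                        # veto everything else
            logits = logits + mask
        next_id = int(torch.argmax(logits))
        ids = torch.cat([ids, torch.tensor([[next_id]], device=ids.device)], dim=1)
        if next_id == tokenizer.eos_token_id:
            break
    return ids
```

At a dereference such as `obj.`, the monitor would query the language server's `textDocument/completion` for type-correct members and allow only tokens that extend one of those names; everywhere else it returns `None` and decoding proceeds unchanged.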
MGD takes as input a code repository, a partially completed code file within the repository, and a prompt for the LM to generate the remaining code; it then uses a Language Model (from HuggingFace or OpenAI) to provide a code completion while adhering to the monitored property. 6 | 7 | ## What can Monitor-Guided Decoding do? 8 | 9 | MGD can improve the quality and reliability of code generation by LMs, especially when the code uses types or functionality defined in another module or library, or when the LM has not seen such types or functionality during training (for example, when a library has been upgraded with newly defined APIs, or in private codebases). MGD can also prevent the LM from hallucinating non-existent dereferenced identifiers. Since MGD is prompt-agnostic, it can be used for various code generation tasks, such as code writing, code repair, code refactoring, and code completion, simply by changing the prompt. MGD can also be applied to any programming language for which a Language Server is available (the Language Server must declare the “textDocument/completion” capability). 10 | 11 | ## What is/are Monitor-Guided Decoding’s intended use(s)? 12 | 13 | MGD is intended to be used as a research tool to advance the state of the art in, and explore the potential of, combining LM decoding with program analysis for code generation. It is also intended to be used as a baseline for evaluating and improving the performance of LMs on code generation tasks. It can be integrated into IDEs with LM-based code-completion assistants; however, this use case has not been evaluated with users. MGD is not intended to be used as a substitute for human verification or testing of the generated code and does not guarantee that the generated code is bug-free. 14 | 15 | ## How was Monitor-Guided Decoding evaluated? What metrics are used to measure performance? 16 | 17 | MGD was evaluated on a dataset of open-source Java repositories from GitHub, called PragmaticCode, which contains code snippets with different levels of complexity and context. The dataset was used to curate a code benchmark, called DotPrompts (consisting of >10,000 test cases), whose prompts require the LM to generate the remaining code for a partially completed, nontrivial method. The benchmark is set up such that the LM must generate non-local identifier dereferences to complete the method. 18 | 19 | MGD was applied to several off-the-shelf LMs of different sizes and domains, such as CodeGen-{350M, 2B, 6B}-Multi, SantaCoder-1.1B, and OpenAI text-davinci-003. The performance of LMs with and without MGD was measured using the following metrics: 20 | 21 | 1. Compilation Rate (CR): Fraction of test cases for which the generated code compiled successfully 22 | 2. Next Identifier Match (NIM): Fraction of test cases for which the generated next identifier matches the ground truth 23 | 3. Identifier Sequence Match (ISM): Percent prefix of the ordered identifiers in the ground truth matched by the generated code 24 | 4. Prefix Match (PM): Percent prefix of the ground truth matched by the generated code 25 | 26 | The metrics were aggregated over 6 independent trials for each test case using the following aggregation: 27 | * score@k - an estimate of the best score achievable by the evaluated model, given k independent trials. 28 | 29 | The results show that MGD consistently improved the ability of the LMs to generate code that compiles and matches the ground truth, across different metrics and models. MGD also outperformed the prompting technique on most metrics. The results also demonstrated that LMs with fewer parameters, when guided by MGD, can outperform larger LMs.
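To illustrate the aggregation, the sketch below estimates score@k as the expected maximum score over a random k-subset of the n recorded trials, in the style of the unbiased pass@k estimator. This formulation is our assumption; the repository's `evaluation_scripts/eval_results.py` is the authoritative implementation.

```python
from math import comb

def score_at_k(scores, k):
    """Unbiased estimate of E[max score over k of n trials]: the i-th
    smallest score is the maximum of a random k-subset with probability
    C(i-1, k-1) / C(n, k)."""
    n = len(scores)
    assert 1 <= k <= n
    s = sorted(scores)
    return sum(s[i - 1] * comb(i - 1, k - 1) for i in range(k, n + 1)) / comb(n, k)

# Six trials of a percentage metric (hypothetical values):
trials = [40.0, 55.0, 55.0, 60.0, 70.0, 70.0]
print(score_at_k(trials, 1))  # 58.33... (mean over single trials)
print(score_at_k(trials, 6))  # 70.0    (best of all six)
```

Note that score@1 reduces to the mean over trials and score@n to the maximum, matching the "best score achievable given k independent trials" reading above.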
30 | 31 | ## What are the limitations of Monitor-Guided Decoding? How can users minimize the impact of Monitor-Guided Decoding’s limitations when using the system? 32 | 33 | MGD has some limitations that users should be aware of when using the system. Some of these limitations are: 34 | * The current instantiation of MGD monitors for type-consistent use of identifiers; type-inconsistent identifier use is one of the major sources of compilation errors in LM-based code generation. However, there are other types of errors or bugs that MGD does not monitor or prevent, such as logical, syntactic, semantic, or runtime errors. Users should not rely on MGD to generate error-free code and should always verify and test the generated code for correctness and functionality. 35 | * MGD relies on the availability and accuracy of a Language Server for the programming language of interest. If the Language Server is not available, not compatible, or not reliable, MGD cannot be applied or may produce incorrect results. Users should ensure that the Language Server used is suitable and trustworthy. 36 | * MGD introduces some latency overhead to the code generation process, as it requires invoking the language server and masking the LM output iteratively. In our experiments, we found the latency overhead not to be significant; however, it may vary depending on the complexity of the code repository, the size of the LM, the speed of the static analysis, and the hardware and software configuration of the system. 37 | * MGD is a research tool that has not been extensively tested or validated with human users. It may not generalize well to domains and tasks that are beyond the scope of evaluation. 38 | 39 | ## What operational factors and settings allow for effective and responsible use of Monitor-Guided Decoding? 40 | 41 | MGD has been shown to enhance the output of the LM by preventing a class of errors from appearing in the generated code. However, the underlying generated code is still limited by the capability of the base LM. 42 | Some of the operational factors and settings that can enable effective and responsible use of MGD are: 43 | * Choosing an appropriate LM for the code generation task and the programming language of interest. Users should select an LM that has been trained on a relevant and diverse corpus of code. Users should also be aware of the limitations and assumptions of the LM, and how they may affect the code generation quality and reliability. 44 | * Reviewing and testing the generated code for correctness and functionality. Users should not blindly trust or use the generated code without verifying and testing it for errors, bugs, or vulnerabilities. Users should also document and acknowledge the use of MGD and the LM for their code generation task and cite the relevant sources and references. 45 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new Issue. 8 | 9 | For help and questions about using this project, please create an issue with the label "question". 10 | 11 | ## Microsoft Support Policy 12 | Support for `monitors4codegen` is limited to the resources listed above. 13 | -------------------------------------------------------------------------------- /azure-pipelines.yml: -------------------------------------------------------------------------------- 1 | # Starter pipeline 2 | # Start with a minimal pipeline that you can customize to build and deploy your code. 3 | # Add steps that build, run tests, deploy, and more: 4 | # https://aka.ms/yaml 5 | 6 | trigger: 7 | - main 8 | 9 | pool: 10 | vmImage: ubuntu-latest 11 | 12 | steps: 13 | - script: echo Hello, world! 
14 | displayName: 'Run a one-line script' 15 | 16 | - script: | 17 | echo Add other tasks to build, test, and deploy your project. 18 | echo See https://aka.ms/yaml 19 | displayName: 'Run a multi-line script' 20 | 21 | - task: ComponentGovernanceComponentDetection@0 22 | inputs: 23 | scanType: 'Register' 24 | verbosity: 'Verbose' 25 | alertWarningLevel: 'High' 26 | 27 | - task: CodeQL3000Init@0 28 | - task: CodeQL3000Finalize@0 29 | 30 | # - task: CredScan@2 31 | # inputs: 32 | # toolMajorVersion: 'V2' 33 | 34 | # - task: ESLint@1 35 | # inputs: 36 | # Configuration: 'recommended' 37 | # TargetType: 'eslint' 38 | # ErrorLevel: 'warn' 39 | 40 | # - task: Semmle@0 41 | # env: 42 | # SYSTEM_ACCESSTOKEN: $(System.AccessToken) 43 | # inputs: 44 | # sourceCodeDirectory: '$(Build.SourcesDirectory)' 45 | # language: 'tsandjs' 46 | # includeNodeModules: true 47 | # querySuite: 'Recommended' 48 | # timeout: '1800' 49 | # ram: '16384' 50 | # addProjectDirToScanningExclusionList: true 51 | 52 | # - task: Semmle@1 53 | # inputs: 54 | # sourceCodeDirectory: '$(Build.SourcesDirectory)' 55 | # language: 'python' 56 | # querySuite: 'Recommended' 57 | # timeout: '1800' 58 | # ram: '16384' 59 | # addProjectDirToScanningExclusionList: true -------------------------------------------------------------------------------- /evaluation_scripts/eval_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | from pygments.lexers import JavaLexer 3 | from pygments.token import Token 4 | 5 | def find_method_stop_actual(gen_text: str) -> Tuple[str, int]: 6 | assert gen_text[0] == '{', gen_text 7 | jlex = JavaLexer() 8 | toks = list(jlex.get_tokens(gen_text)) 9 | balance = -1 10 | int_toks = [] 11 | for tok in toks: 12 | int_toks.append(tok[1]) 13 | if tok[1] == '{' and Token.Punctuation in tok[0].split(): 14 | balance += 1 15 | elif tok[1] == '}' and Token.Punctuation in tok[0].split(): 16 | balance -= 1 17 | 18 | if balance == -1: 19 | break 20 | return (''.join(int_toks), balance+1) 21 | 22 | def find_method_stop(gen_text: str) -> str: 23 | """ 24 | Given the completed text for a method, returns the text for the method 25 | removing all of the extra text beyond the matching bracket. 
26 | 27 | For example, if the input is: 28 | ``` 29 | int a = 0; 30 | int b = 1; 31 | if (a == b) { 32 | return a; 33 | } 34 | return b; 35 | } 36 | 37 | public int foo() { 38 | return 0; 39 | } 40 | ``` 41 | 42 | then the output is: 43 | ``` 44 | int a = 0; 45 | int b = 1; 46 | if (a == b) { 47 | return a; 48 | } 49 | return b; 50 | } 51 | ``` 52 | """ 53 | text, balance = find_method_stop_actual('{' + gen_text) 54 | text = text.rstrip() 55 | assert gen_text.startswith(text[1:]), (gen_text, text, balance) 56 | return text[1:] + ('}'*balance) 57 | 58 | def get_identifiers(text: str) -> List[str]: 59 | """ 60 | Returns the list of identifiers in the input text as per the PL tokenizer 61 | """ 62 | if len(text.strip()) == 0: 63 | return [] 64 | else: 65 | j = JavaLexer() 66 | ctok = list(j.get_tokens(text)) 67 | l = [] 68 | for tok in ctok: 69 | if Token.Name in tok[0].split() or tok[1] == 'class': 70 | l.append(tok[1]) 71 | return l 72 | 73 | def tokenizer_pl(text: str) -> List[str]: 74 | """ 75 | Tokenizes the input text as per the PL tokenizer removing all whitespaces 76 | """ 77 | if len(text.strip()) == 0: 78 | return [] 79 | else: 80 | j = JavaLexer() 81 | ctok = list(j.get_tokens(text)) 82 | l = [] 83 | for tok in ctok: 84 | if Token.Text in tok[0].split() and tok[1].strip() != '': 85 | raise Exception(text, tok) 86 | elif Token.Text.Whitespace in tok[0].split(): 87 | assert tok[1].strip() == '', (text, ctok, tok) 88 | else: 89 | l.append(tok[1]) 90 | return l 91 | 92 | def get_first_token(inp_text: str) -> str: 93 | """ 94 | Returns the first token as per the PL tokenizer for the input text 95 | """ 96 | inp_text = inp_text.lstrip() 97 | tokens = tokenizer_pl(inp_text) 98 | if len(tokens) == 0: 99 | return None 100 | return tokens[0] -------------------------------------------------------------------------------- /figures/figure_2b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/figures/figure_2b.png -------------------------------------------------------------------------------- /figures/motivating_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/figures/motivating_example.png -------------------------------------------------------------------------------- /inference_results/README.md: -------------------------------------------------------------------------------- 1 | This directory contains the inference results on DotPrompts for the various models reported in the paper. 2 | 3 | The files in this directory are stored in [git lfs](https://git-lfs.com/). If, after cloning, this directory doesn't contain the files `dotprompts_results.csv` and `dotprompts_results_sample.csv`, then set up git-lfs from the reference above and clone the repository again.
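If git-lfs was not installed at clone time, the two CSVs above will contain small pointer stubs rather than the actual data (the stub contents are visible verbatim just below). A minimal check, assuming it is run from the repository root:

```python
from pathlib import Path

def is_lfs_pointer(path: str) -> bool:
    """True if `path` still holds a git-lfs pointer stub instead of the data."""
    head = Path(path).read_bytes()[:48]
    return head.startswith(b"version https://git-lfs.github.com/spec/v1")

for name in ("dotprompts_results.csv", "dotprompts_results_sample.csv"):
    p = f"inference_results/{name}"
    if is_lfs_pointer(p):
        print(f"{p}: still a pointer; run `git lfs install` then `git lfs pull`")
```

`git lfs install` followed by `git lfs pull` fetches the real files into an existing clone without re-cloning.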
4 | -------------------------------------------------------------------------------- /inference_results/dotprompts_results.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:14d4b0613689eb51bf1f27009c8c262c22f7e1db4e57666b1ef4dba8e9843ddc 3 | size 1719759013 4 | -------------------------------------------------------------------------------- /inference_results/dotprompts_results_sample.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c68645588cb670b74f2469ba3ebe8f626b4235a6ed5581caad5d00b73876b834 3 | size 1656386 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Read https://setuptools.pypa.io/en/latest/userguide/datafiles.html 2 | [build-system] 3 | requires = ["flit_core>=3.4"] 4 | build-backend = "flit_core.buildapi" 5 | 6 | [project] 7 | name = "monitors4codegen" 8 | version = "0.0.1" 9 | authors = [ 10 | { name="Lakshya A Agrawal", email="t-lakagrawal@microsoft.com" }, 11 | ] 12 | description = "Code for running Monitor-Guided Decoding (https://github.com/microsoft/monitors4codegen) including multilspy: A language-agnostic LSP client in Python, with a library interface. Intended to be used to build applications around language servers. Currently multilspy supports language servers for Python, Rust, Java and C#." 13 | readme = "README.md" 14 | requires-python = ">=3.7" 15 | classifiers = [ 16 | "Programming Language :: Python :: 3", 17 | "Operating System :: OS Independent", 18 | "Development Status :: 2 - Pre-Alpha", 19 | "Topic :: Software Development", 20 | "Topic :: Text Editors :: Integrated Development Environments (IDE)", 21 | "Programming Language :: C#", 22 | "Programming Language :: Java", 23 | "Programming Language :: Python", 24 | "Programming Language :: Rust" 25 | ] 26 | 27 | dependencies = [ 28 | "jedi-language-server==0.41.1", 29 | "pydantic==1.10.5", 30 | "code-tokenize==0.2.0", 31 | "code-ast @ git+https://github.com/cedricrupb/code_ast@982940d04b1d721e5ac9a97d433f36d1fb47e8e0", 32 | "openai==1.3.3", 33 | "torch==1.12.0", 34 | "transformers==4.30.0", 35 | "tiktoken==0.3.3", 36 | "pygtrie==2.5.0" 37 | ] 38 | 39 | [project.urls] 40 | "Homepage" = "https://github.com/microsoft/monitors4codegen" 41 | "Bug Tracker" = "https://github.com/microsoft/monitors4codegen/issues" 42 | 43 | [tool.setuptools] 44 | include-package-data = true 45 | 46 | [tool.setuptools.packages.find] 47 | where = ["src"] 48 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.64.1 2 | pandas==1.5.3 3 | tiktoken==0.3.3 4 | transformers==4.30.0 5 | Pygments==2.15.1 6 | matplotlib==3.7.1 7 | seaborn==0.12.2 8 | tabulate==0.9.0 9 | jedi-language-server==0.41.1 10 | pytest==7.3.1 11 | pydantic==1.10.5 12 | pytest-asyncio==0.21.1 13 | pygtrie==2.5.0 14 | openai==1.3.3 15 | code-tokenize==0.2.0 16 | code-ast @ git+https://github.com/cedricrupb/code_ast@982940d04b1d721e5ac9a97d433f36d1fb47e8e0 17 | --extra-index-url https://download.pytorch.org/whl/cu113 18 | torch==1.12.0+cu113 -------------------------------------------------------------------------------- /results/all_metrics_table.csv: 
-------------------------------------------------------------------------------- 1 | ,Compilation Rate (CR),Compilation Rate (CR),Compilation Rate (CR),Compilation Rate (CR),Compilation Rate (CR),Compilation Rate (CR),Next Identifier Match (NIM),Next Identifier Match (NIM),Next Identifier Match (NIM),Next Identifier Match (NIM),Next Identifier Match (NIM),Next Identifier Match (NIM),Identifier Sequence Match (ISM),Identifier Sequence Match (ISM),Identifier Sequence Match (ISM),Identifier Sequence Match (ISM),Identifier Sequence Match (ISM),Identifier Sequence Match (ISM),Prefix Match (PM),Prefix Match (PM),Prefix Match (PM),Prefix Match (PM),Prefix Match (PM),Prefix Match (PM) 2 | ,score@1,score@2,score@3,score@4,score@5,score@6,score@1,score@2,score@3,score@4,score@5,score@6,score@1,score@2,score@3,score@4,score@5,score@6,score@1,score@2,score@3,score@4,score@5,score@6 3 | configuration,,,,,,,,,,,,,,,,,,,,,,,, 4 | CG-350M,35.72467893971026,43.04169039033339,46.82102865818941,49.268045802492566,51.04384133611691,52.429303473144806,68.8887834503701,72.68488644271525,74.40738280508634,75.49756437021574,76.29689378123616,76.94059593850827,26.09951681861391,28.739771250264123,29.9908815173615,30.792002831132674,31.38411036293041,31.860754128118383,21.814904018784468,24.12332222985021,25.22550146653636,25.930035177715478,26.447455373871094,26.859342538292914 5 | CG-350M-MGD,44.58309609666603,53.70405516543304,58.413361169102295,61.47276523059404,63.68697412538749,65.3729360409945,76.22255962548239,79.89055481748592,81.52827861074206,82.54570759789968,83.26374391092553,83.80148035680395,28.1501137132769,30.928529319805843,32.27009998457561,33.144187276699974,33.79315538995351,34.312888330859316,23.37906883761845,25.779591155583336,26.93647317276529,27.688796913163145,28.24729887028087,28.69181933350254 6 | CG-2B,42.05099006769152,48.746125134434116,52.07060163218827,54.21332321123553,55.782248370974884,57.012715885367236,73.81065350793952,77.25501360156892,78.85177453027141,79.85702536850764,80.56873537040552,81.10647181628391,30.196497341320345,32.94233721091501,34.3133867876073,35.20560110086433,35.861304810572555,36.378986411057845,25.523165660650648,27.929734092841272,29.124944804533904,29.906469777225887,30.485036869620174,30.94616454048572 7 | CG-2B-MGD,52.87530840766749,61.22603909660277,65.23154298728412,67.7370785095211,69.5245777187322,70.91478458910609,80.51970645916367,83.75403302334409,85.20971721389257,86.13272600746504,86.80173340924908,87.32207249952553,32.3636839450562,35.312349892255156,36.779485540769514,37.74605568907297,38.46333838318198,39.030567294028316,27.282711532732545,29.843024671606127,31.113881982699514,31.948967783330517,32.5686563353729,33.056896378757116 8 | CG-6B,44.19244638451319,50.58834693490226,53.773486430062626,55.86702094008983,57.41601821977604,58.64490415638641,74.52868982096538,77.83007528310243,79.3252989182008,80.27772505851839,80.98469032707027,81.5524767508066,31.00500918057365,33.7709921600244,35.115850558373104,36.00086952133321,36.657218223572485,37.170356661133894,26.35430859924353,28.791356344605862,29.953969822765238,30.705922660370344,31.25947776948833,31.692726080724754 9 | 
CG-6B-MGD,55.17808565825266,62.92465363446574,66.66350351110268,69.09597013981147,70.88473461124819,72.28126779275004,81.09223761624597,84.13614221547415,85.47969254127918,86.30796482571013,86.89662807616878,87.35054089960144,33.170021664130545,35.99230150446229,37.3855414566833,38.303635312240345,38.99309128574137,39.551422563688384,28.11611358772491,30.575511890896927,31.7612018613006,32.529518440292165,33.10097376296201,33.55998186239897 10 | SC,44.69696969696969,51.31966850129689,54.686847599164935,56.93363699626747,58.625925223002476,59.97342949326247,75.7433415575378,78.90554817485925,80.32074397418866,81.21401910545961,81.87195546276965,82.3970392863921,31.844954534698743,34.58602420469151,35.97138829283025,36.89269316962523,37.583893904329116,38.14246748596227,26.751506031054216,29.10102192360234,30.27223391181647,31.047986517232296,31.62991478153248,32.09694265189863 11 | SC-MGD,55.34415132536219,63.55728474726388,67.48718921996583,69.94369583096096,71.68817612450181,73.03093566141582,82.3005630416904,85.23945087619408,86.53159992408426,87.34484721958627,87.94363256784969,88.42285063579428,34.13716084363118,37.02231451203238,38.459479729814475,39.411010795139795,40.12358091638569,40.69126792613497,28.59519180273788,31.103386534083494,32.348529156030075,33.16616158863896,33.77306875940889,34.25429395470208 12 | SC-classExprTypes,48.06098563927373,55.401404441070426,59.041563864110834,61.41266527487821,63.171379768457015,64.56633137217689,78.56803947618144,81.61131144429685,82.95929018789144,83.80337825014233,84.41829569178212,84.91174795976465,33.208366403106176,36.056266816013874,37.478802411702965,38.41431131263999,39.11160734421112,39.66597179832546,27.93045609078515,30.410749046321904,31.639297765435444,32.448161012012086,33.05745204371009,33.55059359030316 13 | SC-classExprTypes-MGD,56.79920288479787,65.20149300942622,69.23040425128107,71.76820396027077,73.59081419624218,75.01423420003796,83.55475422281268,86.3851458214715,87.62288859366105,88.38995381792876,88.94477130385272,89.37179730499146,34.9193283445047,37.85979581483904,39.31270279654634,40.27142199506433,40.98872703192434,41.561747825635884,29.29705980745126,31.817161566487716,33.057335328558,33.87660551629196,34.489919407919515,34.97838228951299 14 | SC-RLPG,50.59309166824825,57.76048586069463,61.24406908331751,63.47947112038971,65.11355728474726,66.38830897703549,79.65300183463023,82.50585183779339,83.74881381666351,84.49990510533308,85.02562156006832,85.42417916113114,35.57030742117765,38.523676239292755,40.0314989931763,41.02941207514353,41.76928258840009,42.35452507467172,30.3421651443446,32.92850048610419,34.23106580892534,35.084644928030265,35.714143438145676,36.20780978633676 15 | SC-RLPG-MGD,60.50483962801291,68.7828177389764,72.65562725374835,75.07623204909217,76.80141709369266,78.13626874169671,84.71246915923325,87.27778832162967,88.37113304232301,89.04535965078763,89.525210349845,89.89371797304992,37.47518405500656,40.56383208857486,42.11175045927945,43.12632822045407,43.87652277126343,44.4730310865453,31.814080885653013,34.52967851218502,35.8801891067139,36.77091803723661,37.437290631702744,37.97050621903203 16 | 
SC-FIM,51.71917504902891,59.30853419371164,62.942209147845894,65.24704244954768,66.91813753400393,68.22926551527803,79.58499399000442,82.42360979312961,83.68238754981971,84.48535458973872,85.07939520465617,85.55703169481875,35.52535564186261,38.48782964808012,39.94757857525001,40.91082987857955,41.636248362536975,42.22065822909456,30.33911570788072,32.92880416204875,34.190557565158635,35.01584529282747,35.63181795122734,36.12326222678888 17 | SC-FIM-MGD,62.157588410198,70.72626051749225,74.71484152590624,77.16581261466439,78.88593661036249,80.18599354716265,84.67767444802934,87.24362624153856,88.33412412222432,89.00866704624534,89.50148668311509,89.89371797304992,37.552752511782,40.611271224309895,42.12766257464915,43.13394258135106,43.890162105311745,44.49711268122593,31.937943774854844,34.602162814486135,35.90357997798328,36.75709360749582,37.397429605899994,37.90990648077383 18 | SC-FIM-classExprTypes,53.30075283102423,61.42911368381099,65.34541658758778,67.81868792307206,69.597330296704,70.97172138925792,81.36110583918516,84.13108116657178,85.32453976086543,86.05997342949327,86.58347567533372,86.9899411653065,36.097171230579335,39.04408615331272,40.47921804789472,41.42170671893477,42.119985635428044,42.67132096027691,30.62213625758116,33.18586987314861,34.43831532349217,35.26025010271139,35.87383101905703,36.364915283636485 19 | SC-FIM-classExprTypes-MGD,61.711583475675326,70.61175428607581,74.69586259252232,77.20060732586828,78.96817865502626,80.32833554754222,85.32137660530145,87.81362687416967,88.88878345037008,89.5603213766053,90.04396786233946,90.41563864110836,37.6806357413135,40.6137479174813,42.02259470282186,42.944795018524076,43.63307263953589,44.18101757232171,31.87785127651206,34.512492477899784,35.79316383290621,36.62882297603417,37.25316122606348,37.75472435306206 20 | TD-3,51.71284873790092,56.731827671284876,59.108464604289246,60.64022268615171,61.77010185360917,62.65894856709053,80.75536154868097,83.42190168912506,84.56301005883469,85.2704498007212,85.78161573986208,86.18333649648889,38.57358094332489,41.48641650964019,42.871031368079855,43.7698993238355,44.438416245428684,44.97281253211191,33.1271966402104,35.68767550441467,36.90421431038277,37.69651924868723,38.28913981947876,38.76563801243524 21 | TD-3-MGD,61.233946985512745,67.26576833048648,70.11007781362689,71.92636173847029,73.23970392863922,74.26456633137218,86.32093376352249,88.70816726766621,89.73619282596317,90.37325235655089,90.83317517555514,91.19377490985006,40.69493824648326,43.67129848330911,45.12256080060183,46.07262572794244,46.77485377131491,47.32860672373113,34.29410802966261,36.84409443119493,38.06843318567872,38.86682739708561,39.46362889865108,39.940870471594124 22 | -------------------------------------------------------------------------------- /results/figures/fig_method_dist_by_max_identifier_complexity_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/fig_method_dist_by_max_identifier_complexity_dist.png -------------------------------------------------------------------------------- /results/figures/rel_changes/fig_CR_fim_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_CR_fim_heatmap.png 
-------------------------------------------------------------------------------- /results/figures/rel_changes/fig_CR_models_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_CR_models_heatmap.png -------------------------------------------------------------------------------- /results/figures/rel_changes/fig_CR_prompt_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_CR_prompt_heatmap.png -------------------------------------------------------------------------------- /results/figures/rel_changes/fig_ISM_fim_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_ISM_fim_heatmap.png -------------------------------------------------------------------------------- /results/figures/rel_changes/fig_ISM_models_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_ISM_models_heatmap.png -------------------------------------------------------------------------------- /results/figures/rel_changes/fig_ISM_prompt_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_ISM_prompt_heatmap.png -------------------------------------------------------------------------------- /results/figures/rel_changes/fig_NIM_fim_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_NIM_fim_heatmap.png -------------------------------------------------------------------------------- /results/figures/rel_changes/fig_NIM_models_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_NIM_models_heatmap.png -------------------------------------------------------------------------------- /results/figures/rel_changes/fig_NIM_prompt_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_NIM_prompt_heatmap.png -------------------------------------------------------------------------------- /results/figures/rel_changes/fig_PM_fim_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_PM_fim_heatmap.png -------------------------------------------------------------------------------- /results/figures/rel_changes/fig_PM_models_heatmap.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_PM_models_heatmap.png -------------------------------------------------------------------------------- /results/figures/rel_changes/fig_PM_prompt_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_PM_prompt_heatmap.png -------------------------------------------------------------------------------- /results/figures/rel_changes/fig_next_identifier_match_fim_id_complexity_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_next_identifier_match_fim_id_complexity_heatmap.png -------------------------------------------------------------------------------- /results/figures/rel_changes/fig_next_identifier_match_models_id_complexity_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_next_identifier_match_models_id_complexity_heatmap.png -------------------------------------------------------------------------------- /results/figures/rel_changes/fig_next_identifier_match_prompt_id_complexity_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/rel_changes/fig_next_identifier_match_prompt_id_complexity_heatmap.png -------------------------------------------------------------------------------- /results/figures/score_at_k/fig_CR_fim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_CR_fim.png -------------------------------------------------------------------------------- /results/figures/score_at_k/fig_CR_models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_CR_models.png -------------------------------------------------------------------------------- /results/figures/score_at_k/fig_CR_prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_CR_prompt.png -------------------------------------------------------------------------------- /results/figures/score_at_k/fig_ISM_fim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_ISM_fim.png -------------------------------------------------------------------------------- /results/figures/score_at_k/fig_ISM_models.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_ISM_models.png -------------------------------------------------------------------------------- /results/figures/score_at_k/fig_ISM_prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_ISM_prompt.png -------------------------------------------------------------------------------- /results/figures/score_at_k/fig_NIM_fim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_NIM_fim.png -------------------------------------------------------------------------------- /results/figures/score_at_k/fig_NIM_models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_NIM_models.png -------------------------------------------------------------------------------- /results/figures/score_at_k/fig_NIM_prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_NIM_prompt.png -------------------------------------------------------------------------------- /results/figures/score_at_k/fig_PM_fim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_PM_fim.png -------------------------------------------------------------------------------- /results/figures/score_at_k/fig_PM_models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_PM_models.png -------------------------------------------------------------------------------- /results/figures/score_at_k/fig_PM_prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_PM_prompt.png -------------------------------------------------------------------------------- /results/figures/score_at_k/fig_next_identifier_match_fim_id_complexity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_next_identifier_match_fim_id_complexity.png -------------------------------------------------------------------------------- /results/figures/score_at_k/fig_next_identifier_match_models_id_complexity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_next_identifier_match_models_id_complexity.png -------------------------------------------------------------------------------- 
/results/figures/score_at_k/fig_next_identifier_match_prompt_id_complexity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results/figures/score_at_k/fig_next_identifier_match_prompt_id_complexity.png -------------------------------------------------------------------------------- /results_sample/all_metrics_table.csv: -------------------------------------------------------------------------------- 1 | ,Compilation Rate (CR),Compilation Rate (CR),Compilation Rate (CR),Compilation Rate (CR),Compilation Rate (CR),Compilation Rate (CR),Next Identifier Match (NIM),Next Identifier Match (NIM),Next Identifier Match (NIM),Next Identifier Match (NIM),Next Identifier Match (NIM),Next Identifier Match (NIM),Identifier Sequence Match (ISM),Identifier Sequence Match (ISM),Identifier Sequence Match (ISM),Identifier Sequence Match (ISM),Identifier Sequence Match (ISM),Identifier Sequence Match (ISM),Prefix Match (PM),Prefix Match (PM),Prefix Match (PM),Prefix Match (PM),Prefix Match (PM),Prefix Match (PM) 2 | ,score@1,score@2,score@3,score@4,score@5,score@6,score@1,score@2,score@3,score@4,score@5,score@6,score@1,score@2,score@3,score@4,score@5,score@6,score@1,score@2,score@3,score@4,score@5,score@6 3 | configuration,,,,,,,,,,,,,,,,,,,,,,,, 4 | CG-350M,48.148148148148145,56.29629629629629,60.555555555555564,62.962962962962955,64.81481481481482,66.66666666666666,51.85185185185185,54.81481481481482,55.55555555555556,55.55555555555556,55.55555555555556,55.55555555555556,15.544733044733045,16.854256854256853,17.56854256854257,18.00505050505051,18.256373256373255,18.322510822510825,13.674781561928581,14.105552834993762,14.218389686181865,14.259541949556354,14.30069421293084,14.341846476305326 5 | CG-350M-MGD,46.29629629629629,57.03703703703704,62.22222222222222,65.18518518518519,66.66666666666666,66.66666666666666,64.81481481481482,70.37037037037037,72.22222222222221,74.07407407407408,75.92592592592592,77.77777777777779,16.34867190422746,17.47668483779595,18.135121051787717,18.767102239324462,19.399083426861203,20.031064614397952,14.64595185271856,15.287404055742684,15.57911288661363,15.854360812134777,16.12960873765593,16.40485666317708 6 | CG-2B,31.481481481481477,37.77777777777778,41.66666666666667,43.7037037037037,44.44444444444444,44.44444444444444,51.85185185185185,61.48148148148148,65.55555555555556,66.66666666666666,66.66666666666666,66.66666666666666,15.391414141414142,16.726791726791728,17.58627946127946,18.203463203463205,18.71933621933622,19.182299182299182,13.577779798260151,15.531772705658076,16.079059880185397,16.366537834330014,16.543492566840303,16.687525488651005 7 | CG-2B-MGD,46.29629629629629,55.55555555555556,61.66666666666667,65.18518518518519,66.66666666666666,66.66666666666666,83.33333333333334,94.81481481481482,98.88888888888889,100.0,100.0,100.0,19.798848271070494,21.881446742557852,22.80222863556197,23.404080487413825,23.867043450376784,24.330006413339746,17.726441414493742,20.178469161626523,20.696772234662543,20.84930110762152,20.84930110762152,20.84930110762152 8 | 
CG-6B,40.74074074074074,43.7037037037037,44.44444444444444,44.44444444444444,44.44444444444444,44.44444444444444,55.55555555555556,62.22222222222222,64.44444444444444,65.92592592592594,66.66666666666666,66.66666666666666,16.573873657206992,18.018278018278018,18.673039923039923,18.99270482603816,19.116161616161616,19.182299182299182,14.237754217692956,14.810067784032807,15.115509027745652,15.386199471275605,15.622139114622657,15.823327957786807 9 | CG-6B-MGD,51.85185185185186,64.44444444444444,70.0,73.33333333333333,75.92592592592592,77.77777777777779,75.92592592592592,82.22222222222223,86.11111111111111,88.14814814814815,88.88888888888889,88.88888888888889,18.035914702581373,20.02665544332211,21.271143979477316,21.99755491422158,22.434062850729518,22.808842392175727,15.950538722226549,17.284453105497537,18.227636432193407,18.906914276345802,19.388395757827283,19.73818999651041 10 | SC,29.629629629629633,36.29629629629629,42.22222222222222,47.40740740740741,51.85185185185186,55.55555555555556,55.55555555555556,59.25925925925925,61.111111111111114,62.962962962962955,64.81481481481482,66.66666666666666,17.07691999358666,19.045614878948214,20.45875420875421,21.620570787237455,22.577360910694246,23.348965848965854,15.212017722820656,16.44061190840409,17.32950079729298,18.127854706757997,18.88917157918598,19.650488451613967 11 | SC-MGD,38.88888888888889,51.85185185185185,61.66666666666667,68.88888888888889,74.07407407407408,77.77777777777779,74.07407407407408,85.92592592592592,93.33333333333333,97.77777777777779,100.0,100.0,17.532534338089892,19.842939981828874,21.682299182299182,23.25563839452728,24.675391480947034,25.961399711399714,15.944789882331559,17.781012143725686,19.167252939234555,20.25922712349775,21.14096496334124,21.82481213777736 12 | SC-classExprTypes,31.48148148148148,40.0,47.77777777777777,54.81481481481482,61.111111111111114,66.66666666666666,57.4074074074074,65.18518518518519,70.0,73.33333333333333,75.92592592592592,77.77777777777779,15.522687189353856,16.382475549142217,16.774891774891778,16.990941157607825,17.176126342793008,17.330447330447328,14.984795279349708,16.168652972167973,16.909747710873223,17.441983650516573,17.894658547635913,18.292463760255938 13 | SC-classExprTypes-MGD,42.59259259259259,57.03703703703703,66.66666666666666,72.59259259259258,75.92592592592592,77.77777777777779,79.62962962962963,91.85185185185186,97.22222222222221,99.25925925925925,100.0,100.0,18.693616054727162,21.49050024050024,23.280623697290366,24.377037571482013,24.969336219336224,25.12365720699054,15.64731673624477,18.666632838291942,20.043009083964595,20.944048759543836,21.521454742738122,21.83695542860918 14 | SC-RLPG,50.0,54.81481481481482,55.55555555555556,55.55555555555556,55.55555555555556,55.55555555555556,85.18518518518518,88.14814814814815,88.88888888888889,88.88888888888889,88.88888888888889,88.88888888888889,22.362113195446533,23.618726952060282,23.803912137245472,23.803912137245472,23.803912137245472,23.803912137245472,17.89875830188084,19.51675162692389,19.708795522671494,19.763665207170806,19.791100049420464,19.791100049420464 15 | SC-RLPG-MGD,55.55555555555556,65.18518518518519,70.0,73.33333333333333,75.92592592592592,77.77777777777779,94.44444444444444,99.25925925925925,100.0,100.0,100.0,100.0,27.73528940195607,29.884159050825716,30.285393618726953,30.285393618726953,30.285393618726953,30.285393618726953,22.183492771800495,24.23920247283153,24.92827426066875,25.305503341601533,25.634721448597418,25.963939555593306 16 | 
SC-FIM,38.88888888888889,43.7037037037037,47.77777777777777,51.11111111111111,53.70370370370371,55.55555555555556,94.44444444444444,100.0,100.0,100.0,100.0,100.0,27.68949414782748,30.534511784511785,30.999679333012665,31.05258938592272,31.079044412377744,31.079044412377744,22.701279905464745,24.50166246368658,24.91044161320647,25.06407672980455,25.147752998666007,25.182046551478077 17 | SC-FIM-MGD,57.4074074074074,65.18518518518519,70.0,73.33333333333333,75.92592592592592,77.77777777777779,96.29629629629632,100.0,100.0,100.0,100.0,100.0,34.4608786275453,36.40576398909732,36.682219015552356,36.82904441237775,36.89253647586981,36.89253647586981,27.636857705005536,29.034855631470723,29.521533494287972,29.77358790245176,29.881096588015133,29.881096588015133 18 | SC-FIM-classExprTypes,40.74074074074075,54.074074074074076,63.33333333333333,69.62962962962963,74.07407407407408,77.77777777777779,87.03703703703705,92.59259259259258,94.44444444444444,96.29629629629629,98.14814814814815,100.0,23.814935064935064,26.109908609908604,28.106160814494153,30.049502966169634,31.96639009139009,33.85682219015552,19.747661549191843,21.760007228204188,23.44953626341223,25.036641887554893,26.574364795648176,28.062704987692065 19 | SC-FIM-classExprTypes-MGD,46.29629629629629,58.51851851851851,65.55555555555556,70.37037037037037,74.07407407407408,77.77777777777779,92.59259259259261,99.25925925925925,100.0,100.0,100.0,100.0,27.084335417668747,30.41546416546417,32.28274811608145,33.625340708674045,34.78274811608145,35.94015552348886,22.386000245566766,24.97628523460565,26.301388115264086,27.267094562452016,28.15872693556587,29.050359308679724 20 | TD-3,22.22222222222222,22.22222222222222,22.22222222222222,22.22222222222222,22.22222222222222,22.22222222222222,75.92592592592592,83.7037037037037,86.66666666666667,88.14814814814815,88.88888888888889,88.88888888888889,19.263468013468017,20.185385602052268,20.35513868847202,20.447731281064616,20.49402757736091,20.49402757736091,16.6445849275967,17.342213773373693,17.601015785262124,17.708926164777445,17.762881354535104,17.762881354535104 21 | TD-3-MGD,25.92592592592593,29.629629629629626,33.33333333333333,37.03703703703704,40.74074074074075,44.44444444444444,75.92592592592592,91.85185185185186,98.33333333333333,100.0,100.0,100.0,18.626643418310085,21.64081289081289,23.643378226711555,25.092793009459673,26.24248436748437,27.206990540323872,17.092550387509593,19.22532701690112,20.57896992988294,21.539744637571225,22.353644957644388,23.167545277717544 22 | -------------------------------------------------------------------------------- /results_sample/figures/fig_method_dist_by_max_identifier_complexity_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/fig_method_dist_by_max_identifier_complexity_dist.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_CR_fim_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_CR_fim_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_CR_models_heatmap.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_CR_models_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_CR_prompt_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_CR_prompt_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_ISM_fim_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_ISM_fim_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_ISM_models_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_ISM_models_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_ISM_prompt_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_ISM_prompt_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_NIM_fim_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_NIM_fim_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_NIM_models_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_NIM_models_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_NIM_prompt_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_NIM_prompt_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_PM_fim_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_PM_fim_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_PM_models_heatmap.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_PM_models_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_PM_prompt_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_PM_prompt_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_next_identifier_match_fim_id_complexity_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_next_identifier_match_fim_id_complexity_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_next_identifier_match_models_id_complexity_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_next_identifier_match_models_id_complexity_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/rel_changes/fig_next_identifier_match_prompt_id_complexity_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/rel_changes/fig_next_identifier_match_prompt_id_complexity_heatmap.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_CR_fim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_CR_fim.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_CR_models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_CR_models.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_CR_prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_CR_prompt.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_ISM_fim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_ISM_fim.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_ISM_models.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_ISM_models.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_ISM_prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_ISM_prompt.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_NIM_fim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_NIM_fim.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_NIM_models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_NIM_models.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_NIM_prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_NIM_prompt.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_PM_fim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_PM_fim.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_PM_models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_PM_models.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_PM_prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_PM_prompt.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_next_identifier_match_fim_id_complexity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_next_identifier_match_fim_id_complexity.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_next_identifier_match_models_id_complexity.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_next_identifier_match_models_id_complexity.png -------------------------------------------------------------------------------- /results_sample/figures/score_at_k/fig_next_identifier_match_prompt_id_complexity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/results_sample/figures/score_at_k/fig_next_identifier_match_prompt_id_complexity.png -------------------------------------------------------------------------------- /src/monitors4codegen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/monitors4codegen/5003004198eaeb3d06461427337e615b9df4b18c/src/monitors4codegen/__init__.py -------------------------------------------------------------------------------- /src/monitors4codegen/monitor_guided_decoding/hf_gen.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides the logits processor for running Monitor-Guided Decoding over HuggingFace models 3 | """ 4 | 5 | import asyncio 6 | import torch 7 | 8 | from asyncio.events import AbstractEventLoop 9 | from typing import List, Union 10 | from transformers import LogitsProcessor 11 | from monitors4codegen.monitor_guided_decoding.monitor import Monitor 12 | 13 | class MGDLogitsProcessor(LogitsProcessor): 14 | """ 15 | Provides the logits processor for monitor guided decoding 16 | """ 17 | 18 | loop: AbstractEventLoop 19 | 20 | def __init__(self, monitors: List[Monitor], loop: Union[None, AbstractEventLoop] = None) -> None: 21 | super().__init__() 22 | 23 | if loop is None: 24 | self.loop = asyncio.get_event_loop() 25 | else: 26 | self.loop = loop 27 | 28 | self.monitors: List[Monitor] = monitors 29 | 30 | async def process_scores_for_single_input_id( 31 | self, segment_idx: int, input_ids: torch.LongTensor, scores: torch.FloatTensor 32 | ) -> torch.FloatTensor: 33 | """ 34 | Asynchronously processes the scores for a single input id using the MGD framework 35 | """ 36 | blacklisted_ids: List[int] = await self.monitors[segment_idx].maskgen(input_ids.tolist()) 37 | output_scores: torch.FloatTensor = torch.where( 38 | torch.tensor([True if i in blacklisted_ids else False for i in range(scores.shape[0])]).to(scores.device), 39 | float("-inf") * torch.ones(scores.shape[0]).to(scores.device), 40 | scores, 41 | ).to(scores) 42 | return output_scores 43 | 44 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 45 | """ 46 | This method is called by the HuggingFace decoder, for every token generation with 47 | the input_ids (seen so far including prompt) and scores (for the next token). 48 | This method processes the scores using the MGD framework.
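A minimal usage sketch (an illustration, not part of the original file; a HuggingFace `model`, `input_ids`, and a configured `monitor` are assumed to exist, with one monitor per batch sequence):

    from transformers import LogitsProcessorList

    mgd_processor = MGDLogitsProcessor([monitor])  # runs maskgen for every decoded token
    output_ids = model.generate(
        input_ids,                                 # shape (1, prompt_len) to match [monitor]
        max_new_tokens=50,
        logits_processor=LogitsProcessorList([mgd_processor]),
    )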
49 | """ 50 | assert len(input_ids.shape) == 2 51 | assert input_ids.shape[0] == len(self.monitors) 52 | assert len(scores.shape) == 2 53 | 54 | async def f(input_ids_arg: torch.LongTensor, scores_arg: torch.FloatTensor): 55 | new_score_coroutines = [ 56 | self.process_scores_for_single_input_id(i, input_ids_arg[i], scores_arg[i]) 57 | for i in range(input_ids_arg.shape[0]) 58 | ] 59 | new_scores = await asyncio.gather(*new_score_coroutines) 60 | return tuple(new_scores) 61 | 62 | future = asyncio.run_coroutine_threadsafe(f(input_ids, scores), self.loop) 63 | results = future.result() 64 | new_scores = torch.stack(results, dim=0).to(scores) 65 | return new_scores 66 | -------------------------------------------------------------------------------- /src/monitors4codegen/monitor_guided_decoding/mgd_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides the utility functions for handling programming language text 3 | """ 4 | 5 | import code_tokenize as ctok 6 | 7 | from typing import List 8 | from monitors4codegen.multilspy.multilspy_config import Language 9 | 10 | 11 | class PLUtils: 12 | """ 13 | This class provides various utility functions for handling programming language text 14 | """ 15 | 16 | @staticmethod 17 | def tokenizer_pl(inp_text: str, lang: Language) -> List[ctok.tokens.ASTToken]: 18 | """ 19 | Tokenizes the given text using code_tokenize 20 | """ 21 | lang_s = str(lang) if lang != Language.CSHARP else "c-sharp" 22 | if inp_text.strip() == "": 23 | return [] 24 | lsp_text_lang_tokenized: List[ctok.tokens.ASTToken] = ctok.tokenize( 25 | inp_text, lang=lang_s, syntax_error="ignore" 26 | ) 27 | lsp_text_lang_tokenized: List[ctok.tokens.ASTToken] = [tok for tok in lsp_text_lang_tokenized if tok.text != ""] 28 | return lsp_text_lang_tokenized 29 | 30 | @staticmethod 31 | def get_opening_bracket_stream(inp_text: str, lang: Language) -> List[str]: 32 | """ 33 | Returns the list of opened brackets in the given text 34 | """ 35 | bracket_stream: List[str] = [] 36 | err = False 37 | lsp_text_lang_tokenized = PLUtils.tokenizer_pl(inp_text, lang) 38 | for tok in lsp_text_lang_tokenized: 39 | if tok.type in ["{", "(", "["]: 40 | bracket_stream.append(tok.type) 41 | elif tok.type in ["}", ")", "]"]: 42 | if len(bracket_stream) == 0: 43 | err = True 44 | break 45 | if ( 46 | (tok.type == "}" and bracket_stream[-1] == "{") 47 | or (tok.type == ")" and bracket_stream[-1] == "(") 48 | or (tok.type == "]" and bracket_stream[-1] == "[") 49 | ): 50 | bracket_stream.pop() 51 | else: 52 | err = True 53 | break 54 | if err: 55 | raise Exception("Invalid bracket stream") 56 | return bracket_stream 57 | 58 | @staticmethod 59 | def get_closing_bracket_stream(inp_text: str, lang: Language) -> List[str]: 60 | """ 61 | Returns the list of closing brackets in the given text 62 | """ 63 | bracket_stream: List[str] = [] 64 | err = False 65 | lsp_text_lang_tokenized = PLUtils.tokenizer_pl(inp_text, lang) 66 | for tok in lsp_text_lang_tokenized[::-1]: 67 | if tok.type in ["}", ")", "]"]: 68 | bracket_stream.append(tok.type) 69 | elif tok.type in ["{", "(", "["]: 70 | if len(bracket_stream) == 0: 71 | err = True 72 | break 73 | if ( 74 | (tok.type == "{" and bracket_stream[-1] == "}") 75 | or (tok.type == "(" and bracket_stream[-1] == ")") 76 | or (tok.type == "[" and bracket_stream[-1] == "]") 77 | ): 78 | bracket_stream.pop() 79 | else: 80 | err = True 81 | break 82 | if err: 83 | raise Exception("Invalid bracket stream") 84 | return 
bracket_stream[::-1] 85 | -------------------------------------------------------------------------------- /src/monitors4codegen/monitor_guided_decoding/monitor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides the definition of a monitor as per the Monitor-Guided Decoding framework 3 | """ 4 | 5 | from typing import List, Tuple 6 | from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper 7 | from monitors4codegen.multilspy import LanguageServer 8 | from monitors4codegen.multilspy.multilspy_config import Language 9 | from dataclasses import dataclass 10 | from monitors4codegen.multilspy.multilspy_utils import TextUtils 11 | 12 | @dataclass 13 | class MonitorFileBuffer: 14 | """ 15 | Dataclass for storing the state of the monitor for the prompt file in which the generation is happening 16 | """ 17 | 18 | lsp: LanguageServer 19 | file_path: str 20 | prompt_lc: Tuple[int, int] 21 | current_lc: Tuple[int, int] 22 | language: Language 23 | gen_text: str = "" 24 | 25 | def append_text(self, text: str): 26 | """ 27 | Appends the given text to the prompt file and updates the current line and character position 28 | """ 29 | current_lc_index = TextUtils.get_index_from_line_col( 30 | self.lsp.get_open_file_text(self.file_path), self.current_lc[0], self.current_lc[1] 31 | ) 32 | new_lc = self.lsp.insert_text_at_position(self.file_path, self.current_lc[0], self.current_lc[1], text) 33 | self.current_lc = (new_lc["line"], new_lc["character"]) 34 | self.gen_text += text 35 | assert current_lc_index + len(text) == TextUtils.get_index_from_line_col( 36 | self.lsp.get_open_file_text(self.file_path), self.current_lc[0], self.current_lc[1] 37 | ) 38 | 39 | 40 | class Monitor: 41 | """ 42 | Provides the definition of a monitor as per the Monitor-Guided Decoding framework 43 | """ 44 | 45 | def __init__(self, tokenizer: TokenizerWrapper, monitor_file_buffer: MonitorFileBuffer, responsible_for_file_buffer_state: bool = True) -> None: 46 | self.tokenizer = tokenizer 47 | self.monitor_file_buffer = monitor_file_buffer 48 | self.responsible_for_file_buffer_state = responsible_for_file_buffer_state 49 | 50 | async def pre(self) -> None: 51 | """ 52 | If the current state is uninitialized, or s0, this function checks 53 | if the static analysis should be performed at this point. 54 | If yes, it invokes the static analysis, and updates the state. 55 | """ 56 | raise NotImplementedError() 57 | 58 | async def maskgen(self, input_ids: List[int]) -> List[int]: 59 | """ 60 | Given input_ids, which is the list of token ids generated so far (or input for the first time), 61 | this function returns the list of token ids that should be masked for the next token generation. 62 | 63 | This is the function that is invoked by the end user at every decoding step. 64 | """ 65 | raise NotImplementedError() 66 | 67 | def a_phi(self): 68 | """ 69 | This function defines the implementation of the static analysis, 70 | and returns the result of the static analysis. 71 | It is invoked primarily by pre() 72 | """ 73 | raise NotImplementedError() 74 | 75 | def update(self, generated_token: str): 76 | """ 77 | This function updates the state of the monitor, given the generated token.
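As an illustration (a sketch of the shared pattern used by the concrete monitors in this repository, not behavior of this base class): in the constrained state, implementations typically narrow the remaining legal completions by the newly generated token:

    self.legal_completions = [
        c[len(generated_token):]
        for c in self.legal_completions
        if c.startswith(generated_token)
    ]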
78 | """ 79 | raise NotImplementedError() 80 | -------------------------------------------------------------------------------- /src/monitors4codegen/monitor_guided_decoding/monitors/class_instantiation_monitor.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides the class-instantiation monitor, that is invoked when "new " is typed to instantiate new classes 3 | """ 4 | 5 | import os 6 | 7 | from pathlib import PurePath 8 | from typing import List 9 | from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor, DecoderStates 10 | from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer 11 | from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper 12 | from monitors4codegen.multilspy.multilspy_utils import TextUtils, FileUtils 13 | from monitors4codegen.multilspy import multilspy_types 14 | 15 | class ClassInstantiationMonitor(DereferencesMonitor): 16 | """ 17 | Class Instantiation Monitor that is invoked when "new " is typed to instantiate new classes 18 | """ 19 | 20 | def __init__(self, tokenizer: TokenizerWrapper, monitor_file_buffer: MonitorFileBuffer, responsible_for_file_buffer_state: bool = True) -> None: 21 | super().__init__(tokenizer, monitor_file_buffer, responsible_for_file_buffer_state) 22 | 23 | async def pre(self) -> None: 24 | cursor_idx = TextUtils.get_index_from_line_col( 25 | self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path), 26 | self.monitor_file_buffer.current_lc[0], 27 | self.monitor_file_buffer.current_lc[1], 28 | ) 29 | text_upto_cursor = self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path)[ 30 | :cursor_idx 31 | ] 32 | 33 | # TODO: pre can be improved by checking for "new", and obtaining completions, and then prefixing a whitespace 34 | if not text_upto_cursor.endswith("new "): 35 | self.decoder_state = DecoderStates.S0 36 | return 37 | 38 | completions = await self.a_phi() 39 | if len(completions) == 0: 40 | self.decoder_state = DecoderStates.S0 41 | else: 42 | self.decoder_state = DecoderStates.Constrained 43 | self.legal_completions = completions 44 | 45 | async def a_phi(self) -> List[str]: 46 | """ 47 | Find the set of classes in the repository 48 | Filter out the set of abstract classes in the repository 49 | Remaining classes are instantiable. 
Return their names as legal completions 50 | """ 51 | 52 | legal_completions: List[str] = [] 53 | repository_root_path = self.monitor_file_buffer.lsp.repository_root_path 54 | for path, _, files in os.walk(repository_root_path): 55 | for file in files: 56 | if file.endswith(".java"): 57 | filecontents = FileUtils.read_file(self.monitor_file_buffer.lsp.logger, str(PurePath(path, file))) 58 | relative_file_path = str(PurePath(os.path.relpath(str(PurePath(path, file)), repository_root_path))) 59 | document_symbols, _ = await self.monitor_file_buffer.lsp.request_document_symbols(relative_file_path) 60 | for symbol in document_symbols: 61 | if symbol["kind"] != multilspy_types.SymbolKind.Class: 62 | continue 63 | decl_start_idx = TextUtils.get_index_from_line_col(filecontents, symbol["range"]["start"]["line"], symbol["range"]["start"]["character"]) 64 | decl_end_idx = TextUtils.get_index_from_line_col(filecontents, symbol["selectionRange"]["end"]["line"], symbol["selectionRange"]["end"]["character"]) 65 | decl_text = filecontents[decl_start_idx:decl_end_idx] 66 | if "abstract" not in decl_text: 67 | legal_completions.append(symbol["name"]) 68 | 69 | return legal_completions 70 | 71 | async def update(self, generated_token: str): 72 | """ 73 | Updates the monitor state based on the generated token 74 | """ 75 | if self.responsible_for_file_buffer_state: 76 | self.monitor_file_buffer.append_text(generated_token) 77 | if self.decoder_state == DecoderStates.Constrained: 78 | for break_char in self.all_break_chars: 79 | if break_char in generated_token: 80 | self.decoder_state = DecoderStates.S0 81 | self.legal_completions = None 82 | return 83 | 84 | # No breaking characters found. Continue in constrained state 85 | self.legal_completions = [ 86 | legal_completion[len(generated_token) :] 87 | for legal_completion in self.legal_completions 88 | if legal_completion.startswith(generated_token) 89 | ] 90 | else: 91 | # Nothing to be done in other states 92 | return -------------------------------------------------------------------------------- /src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides the switch-enum monitor, that is invoked when "case " is typed in a switch statement 3 | """ 4 | 5 | from typing import List 6 | from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor, DecoderStates 7 | from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer 8 | from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper 9 | from monitors4codegen.multilspy.multilspy_utils import TextUtils 10 | from monitors4codegen.multilspy import multilspy_types 11 | 12 | class SwitchEnumMonitor(DereferencesMonitor): 13 | """ 14 | Provides the switch-enum monitor, that is invoked when "case " is typed in a switch statement to provide 15 | enum values as completions 16 | """ 17 | def __init__(self, tokenizer: TokenizerWrapper, monitor_file_buffer: MonitorFileBuffer, responsible_for_file_buffer_state: bool = True) -> None: 18 | super().__init__(tokenizer, monitor_file_buffer, responsible_for_file_buffer_state) 19 | self.all_break_chars.remove('.') 20 | 21 | async def pre(self) -> None: 22 | cursor_idx = TextUtils.get_index_from_line_col( 23 | self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path), 24 | self.monitor_file_buffer.current_lc[0], 25 | 
self.monitor_file_buffer.current_lc[1], 26 | ) 27 | text_upto_cursor = self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path)[ 28 | :cursor_idx 29 | ] 30 | 31 | # TODO: pre can be improved by checking for r"switch.*case", and obtaining completions, and then prefixing a whitespace 32 | if not text_upto_cursor.endswith("case "): 33 | self.decoder_state = DecoderStates.S0 34 | return 35 | 36 | completions = await self.a_phi() 37 | if len(completions) == 0: 38 | self.decoder_state = DecoderStates.S0 39 | else: 40 | self.decoder_state = DecoderStates.Constrained 41 | self.legal_completions = completions 42 | 43 | async def a_phi(self) -> List[str]: 44 | relative_file_path = self.monitor_file_buffer.file_path 45 | line, column = self.monitor_file_buffer.current_lc 46 | 47 | with self.monitor_file_buffer.lsp.open_file(relative_file_path): 48 | legal_completions = await self.monitor_file_buffer.lsp.request_completions( 49 | relative_file_path, line, column 50 | ) 51 | legal_completions = [ 52 | completion["completionText"] 53 | for completion in legal_completions 54 | if completion["kind"] == multilspy_types.CompletionItemKind.EnumMember 55 | ] 56 | 57 | return legal_completions 58 | 59 | async def update(self, generated_token: str): 60 | """ 61 | Updates the monitor state based on the generated token 62 | """ 63 | if self.responsible_for_file_buffer_state: 64 | self.monitor_file_buffer.append_text(generated_token) 65 | if self.decoder_state == DecoderStates.Constrained: 66 | for break_char in self.all_break_chars: 67 | if break_char in generated_token: 68 | self.decoder_state = DecoderStates.S0 69 | self.legal_completions = None 70 | return 71 | 72 | # No breaking characters found. Continue in constrained state 73 | self.legal_completions = [ 74 | legal_completion[len(generated_token) :] 75 | for legal_completion in self.legal_completions 76 | if legal_completion.startswith(generated_token) 77 | ] 78 | else: 79 | # Nothing to be done in other states 80 | return -------------------------------------------------------------------------------- /src/monitors4codegen/monitor_guided_decoding/openai_gen.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides the functions and classes for running Monitor-Guided Decoding over OpenAI models 3 | """ 4 | 5 | from enum import Enum 6 | import time 7 | from typing import List, Set 8 | import torch 9 | import asyncio 10 | 11 | from openai import OpenAI 12 | from monitors4codegen.monitor_guided_decoding.monitor import Monitor 13 | from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TikTokenWrapper 14 | 15 | class OpenAI_Models(Enum): 16 | TD3 = 'text-davinci-003' 17 | 18 | def openai_mgd( 19 | client: OpenAI, 20 | model: OpenAI_Models, 21 | tokenizer: TikTokenWrapper, 22 | prompt_tokenized: torch.Tensor, 23 | temp: float, 24 | top_p: float, 25 | monitor: Monitor, 26 | num_new_tokens: int 27 | ): 28 | """ 29 | This function generates completions with OpenAI models using the Monitor-Guided Decoding scheme. 30 | """ 31 | prompt_tokenized: torch.Tensor = torch.tensor(prompt_tokenized, dtype=torch.int64) 32 | assert len(prompt_tokenized.shape) == 1 33 | 34 | all_tokens: torch.Tensor = prompt_tokenized 35 | gen_text: bytes = b'' 36 | 37 | gen_tokens: List[int] = [] 38 | 39 | tokens_sort_key = {k:[0, 0] for k in tokenizer.all_token_ids} 40 | 41 | # # TODO: Find a way to prioritize tokens to be blacklisted 42 | 43 | # # Why prioritize? 
OpenAI allows applying logit_bias to up to 300 tokens, whereas the typical number of tokens in vocabulary is 50,000. 44 | # # Because of this, it is necessary to identify the top 300 tokens that we think need to be either blacklisted or whitelisted. 45 | # # This prioritization should be done taking into account which violating token the model is likely to predict in the next step. 46 | 47 | # # Options for prioritization of tokens: 48 | # # 1. The following code uses info about whether the token has a break char in it 49 | # for token, token_id in tokenizer.vocab_trie.iteritems(): 50 | # if token[0] in monitor.all_break_chars: 51 | # tokens_sort_key[token_id][0] = 0 # ".", ", a" 52 | # elif any([c in monitor.all_break_chars for c in token]): 53 | # tokens_sort_key[token_id][0] = 1 # "abc, " 54 | # else: 55 | # tokens_sort_key[token_id][0] = 2 56 | 57 | # # 2. The following code uses frequency of the token in repo as a heuristic 58 | # for freq_token, freq in metadata_batch[seq_idx]['token_freq']: 59 | # tokens_sort_key[freq_token][1] = freq 60 | 61 | # # 3. Use a local-small and very fast language model to score the tokens 62 | 63 | # # 4. Use the prompt to score the tokens 64 | 65 | all_text_bytes: bytes = tokenizer.tokenizer.decode_bytes(all_tokens.tolist()) 66 | prompt_num_tokens: int = all_tokens.shape[0] 67 | 68 | priority_blacklist: List[int] = [] 69 | 70 | while all_tokens.shape[0] < prompt_num_tokens + num_new_tokens: 71 | num_toks_to_gen = (prompt_num_tokens + num_new_tokens) - all_tokens.shape[0] 72 | 73 | blacklisted_ids: List[int] = asyncio.run_coroutine_threadsafe(monitor.maskgen(all_tokens.tolist()), monitor.monitor_file_buffer.lsp.server.loop).result() 74 | white_listed_ids: Set[int] = set(tokenizer.all_token_ids) - set(blacklisted_ids+[50256]) 75 | 76 | logit_bias = {50256:-100} 77 | 78 | for token_id in priority_blacklist: 79 | logit_bias[token_id] = -100 80 | 81 | if len(white_listed_ids) <= (300 - len(logit_bias)): 82 | for white_token_id in white_listed_ids: 83 | logit_bias[white_token_id] = 100 84 | else: 85 | for candidate_token in sorted(blacklisted_ids, key=lambda x: tokens_sort_key[x], reverse=True): 86 | if len(logit_bias) >= 300: 87 | break 88 | if candidate_token in blacklisted_ids: 89 | logit_bias[candidate_token] = -100 90 | 91 | exponential_backoff_wait = 1 92 | while True: 93 | try: 94 | prompt_arg: str = all_text_bytes.decode('utf-8', errors='strict') 95 | except UnicodeDecodeError: 96 | prompt_arg: List[int] = all_tokens.tolist() 97 | 98 | try: 99 | response = client.completions.create( 100 | model=model.value, 101 | prompt=[prompt_arg], 102 | temperature=temp, 103 | max_tokens=num_toks_to_gen if len(logit_bias) <= 1 else 1, 104 | top_p=top_p, 105 | stop=['.'], 106 | logit_bias=logit_bias, 107 | logprobs=5 108 | ) 109 | break 110 | except Exception: 111 | time.sleep(exponential_backoff_wait) 112 | if exponential_backoff_wait < 64: 113 | exponential_backoff_wait = exponential_backoff_wait*2 114 | else: 115 | exponential_backoff_wait = 1 116 | 117 | assert len(response.choices) == 1 118 | 119 | def convert_bytesrep_to_bytes(x: str) -> bytes: 120 | if x.startswith('bytes:'): 121 | return bytes.fromhex(x.replace('bytes:', '').replace('\\x', '')) 122 | else: 123 | return x.encode() 124 | 125 | tokens_gen_bytes_ = list(map(convert_bytesrep_to_bytes, response.choices[0].logprobs.tokens)) 126 | tokens_gen_bytes = [] 127 | dot_found = False 128 | for token_bytes in tokens_gen_bytes_: 129 | gen_text += token_bytes 130 | all_text_bytes += token_bytes 131 |
tokens_gen_bytes.append(token_bytes) 132 | if b'.' in token_bytes: 133 | dot_found = True 134 | break 135 | 136 | # When a "stop" sequence is sent to the OpenAI model, it will not generate text beyond the text sequence within the "stop" parameter. 137 | # However, when it stops because of the "stop" sequence, the returned text does not contain the stop sequence, and only includes 138 | # text up to the stop sequence. So, the following code determines if the stop sequence "." needs to be added manually. 139 | should_manually_add_dot = None 140 | if response.choices[0].finish_reason == 'stop': 141 | if dot_found: 142 | should_manually_add_dot = False 143 | else: 144 | should_manually_add_dot = True 145 | elif response.choices[0].finish_reason == 'length': 146 | should_manually_add_dot = False 147 | else: 148 | raise Exception("Unknown finish reason", response.choices[0].finish_reason) 149 | 150 | tokens_gen = list(map(lambda x: tokenizer.tokenizer.encode_single_token(x), tokens_gen_bytes)) 151 | 152 | assert should_manually_add_dot is not None 153 | if should_manually_add_dot: 154 | gen_text += b'.' 155 | all_text_bytes += b'.' 156 | tokens_gen.append(tokenizer.tokenizer.encode_single_token('.')) 157 | 158 | if len(logit_bias) > 1: 159 | assert len(tokens_gen) == 1, response 160 | if tokens_gen[0] in blacklisted_ids: 161 | priority_blacklist.append(tokens_gen[0]) 162 | continue 163 | priority_blacklist = [] 164 | 165 | new_all_tokens = torch.cat([ 166 | all_tokens, 167 | torch.tensor(tokens_gen) 168 | ]).to(all_tokens) 169 | 170 | assert len(new_all_tokens.shape) == 1 171 | assert new_all_tokens.shape[0] > all_tokens.shape[0], (new_all_tokens.shape, all_tokens.shape) 172 | assert torch.equal(new_all_tokens[:all_tokens.shape[0]], all_tokens) 173 | gen_tokens += new_all_tokens[all_tokens.shape[0]:].tolist() 174 | all_tokens = new_all_tokens 175 | 176 | return gen_tokens, gen_text.decode() 177 | -------------------------------------------------------------------------------- /src/monitors4codegen/monitor_guided_decoding/tokenizer_wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file provides the tokenizer wrapper that is used to provide a common interface over 3 | HF tokenizers and TikToken tokenizers 4 | """ 5 | 6 | import torch 7 | import tiktoken 8 | 9 | from typing import List, Set, Union 10 | from pygtrie import CharTrie 11 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 12 | 13 | 14 | class TokenizerWrapper: 15 | """ 16 | This class provides a common interface over HF tokenizers and TikToken tokenizers 17 | """ 18 | 19 | def __init__(self, tokenizer: Union[PreTrainedTokenizerBase, tiktoken.core.Encoding]): 20 | """ 21 | Initializes the tokenizer wrapper 22 | """ 23 | self.tokenizer = tokenizer 24 | self.vocab_trie = CharTrie() 25 | self.tokenizer_char_set: Set[str] = set() 26 | self.all_token_ids: Set[int] = set() 27 | 28 | def decode(self, *args, **kwargs) -> str: 29 | """ 30 | Decodes the given token ids to a string 31 | 32 | Params: 33 | token_ids, clean_up_tokenization_spaces, skip_special_tokens 34 | """ 35 | raise NotImplementedError() 36 | 37 | def convert_ids_to_tokens(self, x) -> List[str]: 38 | """ 39 | Converts the given token ids to a list of tokens 40 | """ 41 | raise NotImplementedError() 42 | 43 | def convert_tokens_to_string(self, x) -> str: 44 | """ 45 | Converts the given list of tokens to a string 46 | """ 47 | raise
NotImplementedError() 48 | 49 | 50 | class HFTokenizerWrapper(TokenizerWrapper): 51 | """ 52 | This class provides an instance of TokenizerWrapper for HF tokenizers 53 | """ 54 | def __init__(self, tokenizer: PreTrainedTokenizerBase): 55 | super().__init__(tokenizer) 56 | self.__dict__.update(tokenizer.__dict__) 57 | for k, v in tokenizer.vocab.items(): 58 | decoded_token = tokenizer.decode(v, clean_up_tokenization_spaces=False, skip_special_tokens=True) 59 | if decoded_token != "": 60 | self.tokenizer_char_set.update(decoded_token) 61 | self.vocab_trie[decoded_token] = v 62 | self.all_token_ids = set(tokenizer.vocab.values()) 63 | 64 | def decode(self, *args, **kwargs) -> str: 65 | """ 66 | Decodes the given token ids to a string 67 | """ 68 | return self.tokenizer.decode(*args, **kwargs) 69 | 70 | def convert_ids_to_tokens(self, x) -> List[str]: 71 | """ 72 | Converts the given token ids to a list of tokens 73 | """ 74 | return self.tokenizer.convert_ids_to_tokens(x) 75 | 76 | def convert_tokens_to_string(self, x) -> str: 77 | """ 78 | Converts the given list of tokens to a string 79 | """ 80 | return self.tokenizer.convert_tokens_to_string(x) 81 | 82 | 83 | class TikTokenWrapper(TokenizerWrapper): 84 | """ 85 | This class provides an instance of TokenizerWrapper for TikToken tokenizers 86 | """ 87 | def __init__(self, tokenizer: tiktoken.core.Encoding): 88 | super().__init__(tokenizer) 89 | 90 | assert len(tokenizer.special_tokens_set) == 1 91 | self.all_special_ids = {tokenizer.encode_single_token(token) for token in tokenizer.special_tokens_set} 92 | for k_ in tokenizer.token_byte_values(): 93 | v = tokenizer.encode_single_token(k_) 94 | decoded_token = tokenizer.decode([tokenizer.encode_single_token(k_)]) 95 | if decoded_token != "": 96 | self.tokenizer_char_set.update(decoded_token) 97 | self.vocab_trie[decoded_token] = v 98 | self.all_token_ids.add(v) 99 | 100 | def decode(self, token_ids: Union[List[int], torch.Tensor], *args, **kwargs) -> str: 101 | """ 102 | Decodes the given token ids to a string 103 | """ 104 | clean_up_tokenization_spaces, skip_special_tokens = None, None 105 | if len(args) == 0: 106 | pass 107 | elif len(args) == 1: 108 | skip_special_tokens: bool = args[0] 109 | elif len(args) == 2: 110 | skip_special_tokens, clean_up_tokenization_spaces = args[0], args[1] 111 | 112 | if clean_up_tokenization_spaces is None: 113 | clean_up_tokenization_spaces = kwargs.get("clean_up_tokenization_spaces", True) 114 | if skip_special_tokens is None: 115 | skip_special_tokens = kwargs.get("skip_special_tokens", False) 116 | 117 | assert not clean_up_tokenization_spaces 118 | assert skip_special_tokens 119 | if isinstance(token_ids, torch.Tensor): 120 | token_ids = token_ids.tolist() 121 | 122 | token_ids: List[int] = [i for i in token_ids if i not in self.all_special_ids] 123 | 124 | return self.tokenizer.decode(token_ids) 125 | 126 | def convert_ids_to_tokens(self, x) -> List[str]: 127 | """ 128 | Converts the given token ids to a list of tokens 129 | """ 130 | return [self.tokenizer.decode([i]) for i in x] 131 | 132 | def convert_tokens_to_string(self, x) -> str: 133 | """ 134 | Converts the given list of tokens to a string 135 | """ 136 | return "".join(x) 137 | -------------------------------------------------------------------------------- /src/monitors4codegen/multilspy/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the multilspy API 3 | """ 4 | 5 | from . 
import multilspy_types as Types 6 | from .language_server import LanguageServer, SyncLanguageServer 7 | 8 | __all__ = ["LanguageServer", "Types", "SyncLanguageServer"] 9 | -------------------------------------------------------------------------------- /src/monitors4codegen/multilspy/language_servers/eclipse_jdtls/runtime_dependencies.json: -------------------------------------------------------------------------------- 1 | { 2 | "_description": "This file lists the runtime dependencies for the Java Language Server", 3 | "gradle": { 4 | "platform-agnostic": { 5 | "url": "https://services.gradle.org/distributions/gradle-7.3.3-bin.zip", 6 | "archiveType": "zip", 7 | "relative_extraction_path": "." 8 | } 9 | }, 10 | "vscode-java": { 11 | "darwin-arm64": { 12 | "url": "https://github.com/redhat-developer/vscode-java/releases/download/v1.23.0/java@darwin-arm64-1.23.0.vsix", 13 | "archiveType": "zip", 14 | "relative_extraction_path": "vscode-java" 15 | }, 16 | "darwin-x64": { 17 | "url": "https://github.com/redhat-developer/vscode-java/releases/download/v1.23.0/java@darwin-x64-1.23.0.vsix", 18 | "archiveType": "zip", 19 | "relative_extraction_path": "vscode-java" 20 | }, 21 | "linux-arm64": { 22 | "url": "https://github.com/redhat-developer/vscode-java/releases/download/v1.23.0/java@linux-arm64-1.23.0.vsix", 23 | "archiveType": "zip", 24 | "relative_extraction_path": "vscode-java" 25 | }, 26 | "linux-x64": { 27 | "url": "https://github.com/redhat-developer/vscode-java/releases/download/v1.23.0/java@linux-x64-1.23.0.vsix", 28 | "archiveType": "zip", 29 | "relative_extraction_path": "vscode-java", 30 | "jre_home_path": "extension/jre/17.0.8.1-linux-x86_64", 31 | "jre_path": "extension/jre/17.0.8.1-linux-x86_64/bin/java", 32 | "lombok_jar_path": "extension/lombok/lombok-1.18.30.jar", 33 | "jdtls_launcher_jar_path": "extension/server/plugins/org.eclipse.equinox.launcher_1.6.500.v20230717-2134.jar", 34 | "jdtls_readonly_config_path": "extension/server/config_linux" 35 | }, 36 | "win-x64": { 37 | "url": "https://github.com/redhat-developer/vscode-java/releases/download/v1.23.0/java@win32-x64-1.23.0.vsix", 38 | "archiveType": "zip", 39 | "relative_extraction_path": "vscode-java", 40 | "jre_home_path": "extension/jre/17.0.8.1-win32-x86_64", 41 | "jre_path": "extension/jre/17.0.8.1-win32-x86_64/bin/java.exe", 42 | "lombok_jar_path": "extension/lombok/lombok-1.18.30.jar", 43 | "jdtls_launcher_jar_path": "extension/server/plugins/org.eclipse.equinox.launcher_1.6.500.v20230717-2134.jar", 44 | "jdtls_readonly_config_path": "extension/server/config_win" 45 | } 46 | }, 47 | "intellicode": { 48 | "platform-agnostic": { 49 | "url": "https://VisualStudioExptTeam.gallery.vsassets.io/_apis/public/gallery/publisher/VisualStudioExptTeam/extension/vscodeintellicode/1.2.30/assetbyname/Microsoft.VisualStudio.Services.VSIXPackage", 50 | "alternate_url": "https://marketplace.visualstudio.com/_apis/public/gallery/publishers/VisualStudioExptTeam/vsextensions/vscodeintellicode/1.2.30/vspackage", 51 | "archiveType": "zip", 52 | "relative_extraction_path": "intellicode", 53 | "intellicode_jar_path": "extension/dist/com.microsoft.jdtls.intellicode.core-0.7.0.jar", 54 | "intellisense_members_path": "extension/dist/bundledModels/java_intellisense-members" 55 | } 56 | } 57 | } -------------------------------------------------------------------------------- /src/monitors4codegen/multilspy/language_servers/jedi_language_server/jedi_server.py: -------------------------------------------------------------------------------- 1 | """ 2 
| Provides Python specific instantiation of the LanguageServer class. Contains various configurations and settings specific to Python. 3 | """ 4 | 5 | import json 6 | import logging 7 | import os 8 | import pathlib 9 | from contextlib import asynccontextmanager 10 | from typing import AsyncIterator 11 | 12 | from monitors4codegen.multilspy.multilspy_logger import MultilspyLogger 13 | from monitors4codegen.multilspy.language_server import LanguageServer 14 | from monitors4codegen.multilspy.lsp_protocol_handler.server import ProcessLaunchInfo 15 | from monitors4codegen.multilspy.lsp_protocol_handler.lsp_types import InitializeParams 16 | from monitors4codegen.multilspy.multilspy_config import MultilspyConfig 17 | 18 | 19 | class JediServer(LanguageServer): 20 | """ 21 | Provides Python specific instantiation of the LanguageServer class. Contains various configurations and settings specific to Python. 22 | """ 23 | 24 | def __init__(self, config: MultilspyConfig, logger: MultilspyLogger, repository_root_path: str): 25 | """ 26 | Creates a JediServer instance. This class is not meant to be instantiated directly. Use LanguageServer.create() instead. 27 | """ 28 | super().__init__( 29 | config, 30 | logger, 31 | repository_root_path, 32 | ProcessLaunchInfo(cmd="jedi-language-server", cwd=repository_root_path), 33 | "python", 34 | ) 35 | 36 | def _get_initialize_params(self, repository_absolute_path: str) -> InitializeParams: 37 | """ 38 | Returns the initialize params for the Jedi Language Server. 39 | """ 40 | with open(os.path.join(os.path.dirname(__file__), "initialize_params.json"), "r") as f: 41 | d = json.load(f) 42 | 43 | del d["_description"] 44 | 45 | d["processId"] = os.getpid() 46 | assert d["rootPath"] == "$rootPath" 47 | d["rootPath"] = repository_absolute_path 48 | 49 | assert d["rootUri"] == "$rootUri" 50 | d["rootUri"] = pathlib.Path(repository_absolute_path).as_uri() 51 | 52 | assert d["workspaceFolders"][0]["uri"] == "$uri" 53 | d["workspaceFolders"][0]["uri"] = pathlib.Path(repository_absolute_path).as_uri() 54 | 55 | assert d["workspaceFolders"][0]["name"] == "$name" 56 | d["workspaceFolders"][0]["name"] = os.path.basename(repository_absolute_path) 57 | 58 | return d 59 | 60 | @asynccontextmanager 61 | async def start_server(self) -> AsyncIterator["JediServer"]: 62 | """ 63 | Starts the JEDI Language Server, waits for the server to be ready and yields the LanguageServer instance. 64 | 65 | Usage: 66 | ``` 67 | async with lsp.start_server(): 68 | # LanguageServer has been initialized and ready to serve requests 69 | await lsp.request_definition(...) 70 | await lsp.request_references(...) 
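# Other requests use the same interface; e.g. completions (an
# illustrative addition, mirroring the call made by the monitors above):
await lsp.request_completions(...)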
71 | # Shutdown the LanguageServer on exit from scope 72 | # LanguageServer has been shutdown 73 | ``` 74 | """ 75 | 76 | async def execute_client_command_handler(params): 77 | return [] 78 | 79 | async def do_nothing(params): 80 | return 81 | 82 | async def check_experimental_status(params): 83 | if params["quiescent"] == True: 84 | self.completions_available.set() 85 | 86 | async def window_log_message(msg): 87 | self.logger.log(f"LSP: window/logMessage: {msg}", logging.INFO) 88 | 89 | self.server.on_request("client/registerCapability", do_nothing) 90 | self.server.on_notification("language/status", do_nothing) 91 | self.server.on_notification("window/logMessage", window_log_message) 92 | self.server.on_request("workspace/executeClientCommand", execute_client_command_handler) 93 | self.server.on_notification("$/progress", do_nothing) 94 | self.server.on_notification("textDocument/publishDiagnostics", do_nothing) 95 | self.server.on_notification("language/actionableNotification", do_nothing) 96 | self.server.on_notification("experimental/serverStatus", check_experimental_status) 97 | 98 | async with super().start_server(): 99 | self.logger.log("Starting jedi-language-server server process", logging.INFO) 100 | await self.server.start() 101 | initialize_params = self._get_initialize_params(self.repository_root_path) 102 | 103 | self.logger.log( 104 | "Sending initialize request from LSP client to LSP server and awaiting response", 105 | logging.INFO, 106 | ) 107 | init_response = await self.server.send.initialize(initialize_params) 108 | assert init_response["capabilities"]["textDocumentSync"]["change"] == 2 109 | assert "completionProvider" in init_response["capabilities"] 110 | assert init_response["capabilities"]["completionProvider"] == { 111 | "triggerCharacters": [".", "'", '"'], 112 | "resolveProvider": True, 113 | } 114 | 115 | self.server.notify.initialized({}) 116 | 117 | yield self 118 | 119 | await self.server.shutdown() 120 | await self.server.stop() 121 | -------------------------------------------------------------------------------- /src/monitors4codegen/multilspy/language_servers/omnisharp/workspace_did_change_configuration.json: -------------------------------------------------------------------------------- 1 | { 2 | "RoslynExtensionsOptions": { 3 | "EnableDecompilationSupport": false, 4 | "EnableAnalyzersSupport": true, 5 | "EnableImportCompletion": true, 6 | "EnableAsyncCompletion": false, 7 | "DocumentAnalysisTimeoutMs": 30000, 8 | "DiagnosticWorkersThreadCount": 18, 9 | "AnalyzeOpenDocumentsOnly": true, 10 | "InlayHintsOptions": { 11 | "EnableForParameters": false, 12 | "ForLiteralParameters": false, 13 | "ForIndexerParameters": false, 14 | "ForObjectCreationParameters": false, 15 | "ForOtherParameters": false, 16 | "SuppressForParametersThatDifferOnlyBySuffix": false, 17 | "SuppressForParametersThatMatchMethodIntent": false, 18 | "SuppressForParametersThatMatchArgumentName": false, 19 | "EnableForTypes": false, 20 | "ForImplicitVariableTypes": false, 21 | "ForLambdaParameterTypes": false, 22 | "ForImplicitObjectCreation": false 23 | }, 24 | "LocationPaths": null 25 | }, 26 | "FormattingOptions": { 27 | "OrganizeImports": false, 28 | "EnableEditorConfigSupport": true, 29 | "NewLine": "\n", 30 | "UseTabs": false, 31 | "TabSize": 4, 32 | "IndentationSize": 4, 33 | "SpacingAfterMethodDeclarationName": false, 34 | "SeparateImportDirectiveGroups": false, 35 | "SpaceWithinMethodDeclarationParenthesis": false, 36 | "SpaceBetweenEmptyMethodDeclarationParentheses": false, 37 
| "SpaceAfterMethodCallName": false, 38 | "SpaceWithinMethodCallParentheses": false, 39 | "SpaceBetweenEmptyMethodCallParentheses": false, 40 | "SpaceAfterControlFlowStatementKeyword": true, 41 | "SpaceWithinExpressionParentheses": false, 42 | "SpaceWithinCastParentheses": false, 43 | "SpaceWithinOtherParentheses": false, 44 | "SpaceAfterCast": false, 45 | "SpaceBeforeOpenSquareBracket": false, 46 | "SpaceBetweenEmptySquareBrackets": false, 47 | "SpaceWithinSquareBrackets": false, 48 | "SpaceAfterColonInBaseTypeDeclaration": true, 49 | "SpaceAfterComma": true, 50 | "SpaceAfterDot": false, 51 | "SpaceAfterSemicolonsInForStatement": true, 52 | "SpaceBeforeColonInBaseTypeDeclaration": true, 53 | "SpaceBeforeComma": false, 54 | "SpaceBeforeDot": false, 55 | "SpaceBeforeSemicolonsInForStatement": false, 56 | "SpacingAroundBinaryOperator": "single", 57 | "IndentBraces": false, 58 | "IndentBlock": true, 59 | "IndentSwitchSection": true, 60 | "IndentSwitchCaseSection": true, 61 | "IndentSwitchCaseSectionWhenBlock": true, 62 | "LabelPositioning": "oneLess", 63 | "WrappingPreserveSingleLine": true, 64 | "WrappingKeepStatementsOnSingleLine": true, 65 | "NewLinesForBracesInTypes": true, 66 | "NewLinesForBracesInMethods": true, 67 | "NewLinesForBracesInProperties": true, 68 | "NewLinesForBracesInAccessors": true, 69 | "NewLinesForBracesInAnonymousMethods": true, 70 | "NewLinesForBracesInControlBlocks": true, 71 | "NewLinesForBracesInAnonymousTypes": true, 72 | "NewLinesForBracesInObjectCollectionArrayInitializers": true, 73 | "NewLinesForBracesInLambdaExpressionBody": true, 74 | "NewLineForElse": true, 75 | "NewLineForCatch": true, 76 | "NewLineForFinally": true, 77 | "NewLineForMembersInObjectInit": true, 78 | "NewLineForMembersInAnonymousTypes": true, 79 | "NewLineForClausesInQuery": true 80 | }, 81 | "FileOptions": { 82 | "SystemExcludeSearchPatterns": [ 83 | "**/node_modules/**/*", 84 | "**/bin/**/*", 85 | "**/obj/**/*", 86 | "**/.git/**/*", 87 | "**/.git", 88 | "**/.svn", 89 | "**/.hg", 90 | "**/CVS", 91 | "**/.DS_Store", 92 | "**/Thumbs.db" 93 | ], 94 | "ExcludeSearchPatterns": [] 95 | }, 96 | "RenameOptions": { 97 | "RenameOverloads": false, 98 | "RenameInStrings": false, 99 | "RenameInComments": false 100 | }, 101 | "ImplementTypeOptions": { 102 | "InsertionBehavior": 0, 103 | "PropertyGenerationBehavior": 0 104 | }, 105 | "DotNetCliOptions": { 106 | "LocationPaths": null 107 | }, 108 | "Plugins": { 109 | "LocationPaths": null 110 | } 111 | } -------------------------------------------------------------------------------- /src/monitors4codegen/multilspy/language_servers/rust_analyzer/runtime_dependencies.json: -------------------------------------------------------------------------------- 1 | { 2 | "_description": "Used to download the runtime dependencies for running RustAnalyzer. 
Obtained from https://github.com/rust-lang/rust-analyzer/releases", 3 | "runtimeDependencies": [ 4 | { 5 | "id": "RustAnalyzer", 6 | "description": "RustAnalyzer for Linux (x64)", 7 | "url": "https://github.com/rust-lang/rust-analyzer/releases/download/2023-10-09/rust-analyzer-x86_64-unknown-linux-gnu.gz", 8 | "platformId": "linux-x64", 9 | "archiveType": "gz", 10 | "binaryName": "rust_analyzer" 11 | }, 12 | { 13 | "id": "RustAnalyzer", 14 | "description": "RustAnalyzer for Windows (x64)", 15 | "url": "https://github.com/rust-lang/rust-analyzer/releases/download/2023-10-09/rust-analyzer-x86_64-pc-windows-msvc.zip", 16 | "platformId": "win-x64", 17 | "archiveType": "zip", 18 | "binaryName": "rust-analyzer.exe" 19 | } 20 | ] 21 | } -------------------------------------------------------------------------------- /src/monitors4codegen/multilspy/language_servers/rust_analyzer/rust_analyzer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides Rust specific instantiation of the LanguageServer class. Contains various configurations and settings specific to Rust. 3 | """ 4 | 5 | import asyncio 6 | import json 7 | import logging 8 | import os 9 | import stat 10 | import pathlib 11 | from contextlib import asynccontextmanager 12 | from typing import AsyncIterator 13 | 14 | from monitors4codegen.multilspy.multilspy_logger import MultilspyLogger 15 | from monitors4codegen.multilspy.language_server import LanguageServer 16 | from monitors4codegen.multilspy.lsp_protocol_handler.server import ProcessLaunchInfo 17 | from monitors4codegen.multilspy.lsp_protocol_handler.lsp_types import InitializeParams 18 | from monitors4codegen.multilspy.multilspy_config import MultilspyConfig 19 | from monitors4codegen.multilspy.multilspy_utils import FileUtils 20 | from monitors4codegen.multilspy.multilspy_utils import PlatformUtils 21 | 22 | 23 | class RustAnalyzer(LanguageServer): 24 | """ 25 | Provides Rust specific instantiation of the LanguageServer class. Contains various configurations and settings specific to Rust. 26 | """ 27 | 28 | def __init__(self, config: MultilspyConfig, logger: MultilspyLogger, repository_root_path: str): 29 | """ 30 | Creates a RustAnalyzer instance. This class is not meant to be instantiated directly. Use LanguageServer.create() instead. 31 | """ 32 | rustanalyzer_executable_path = self.setup_runtime_dependencies(logger, config) 33 | super().__init__( 34 | config, 35 | logger, 36 | repository_root_path, 37 | ProcessLaunchInfo(cmd=rustanalyzer_executable_path, cwd=repository_root_path), 38 | "rust", 39 | ) 40 | self.server_ready = asyncio.Event() 41 | 42 | def setup_runtime_dependencies(self, logger: MultilspyLogger, config: MultilspyConfig) -> str: 43 | """ 44 | Setup runtime dependencies for RustAnalyzer.
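Downloads the platform-appropriate rust-analyzer binary listed in runtime_dependencies.json from the GitHub releases page, if it is not already present, and returns the path to the executable (a descriptive summary of the method body below).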
45 | """ 46 | platform_id = PlatformUtils.get_platform_id() 47 | 48 | with open(os.path.join(os.path.dirname(__file__), "runtime_dependencies.json"), "r") as f: 49 | d = json.load(f) 50 | del d["_description"] 51 | 52 | assert platform_id.value in [ 53 | "linux-x64", 54 | "win-x64", 55 | ], "Only linux-x64 platform is supported for in multilspy at the moment" 56 | 57 | runtime_dependencies = d["runtimeDependencies"] 58 | runtime_dependencies = [ 59 | dependency for dependency in runtime_dependencies if dependency["platformId"] == platform_id.value 60 | ] 61 | assert len(runtime_dependencies) == 1 62 | dependency = runtime_dependencies[0] 63 | 64 | rustanalyzer_ls_dir = os.path.join(os.path.dirname(__file__), "static", "RustAnalyzer") 65 | rustanalyzer_executable_path = os.path.join(rustanalyzer_ls_dir, dependency["binaryName"]) 66 | if not os.path.exists(rustanalyzer_ls_dir): 67 | os.makedirs(rustanalyzer_ls_dir) 68 | if dependency["archiveType"] == "gz": 69 | FileUtils.download_and_extract_archive( 70 | logger, dependency["url"], rustanalyzer_executable_path, dependency["archiveType"] 71 | ) 72 | else: 73 | FileUtils.download_and_extract_archive( 74 | logger, dependency["url"], rustanalyzer_ls_dir, dependency["archiveType"] 75 | ) 76 | assert os.path.exists(rustanalyzer_executable_path) 77 | os.chmod(rustanalyzer_executable_path, stat.S_IEXEC) 78 | 79 | return rustanalyzer_executable_path 80 | 81 | def _get_initialize_params(self, repository_absolute_path: str) -> InitializeParams: 82 | """ 83 | Returns the initialize params for the Rust Analyzer Language Server. 84 | """ 85 | with open(os.path.join(os.path.dirname(__file__), "initialize_params.json"), "r") as f: 86 | d = json.load(f) 87 | 88 | del d["_description"] 89 | 90 | d["processId"] = os.getpid() 91 | assert d["rootPath"] == "$rootPath" 92 | d["rootPath"] = repository_absolute_path 93 | 94 | assert d["rootUri"] == "$rootUri" 95 | d["rootUri"] = pathlib.Path(repository_absolute_path).as_uri() 96 | 97 | assert d["workspaceFolders"][0]["uri"] == "$uri" 98 | d["workspaceFolders"][0]["uri"] = pathlib.Path(repository_absolute_path).as_uri() 99 | 100 | assert d["workspaceFolders"][0]["name"] == "$name" 101 | d["workspaceFolders"][0]["name"] = os.path.basename(repository_absolute_path) 102 | 103 | return d 104 | 105 | @asynccontextmanager 106 | async def start_server(self) -> AsyncIterator["RustAnalyzer"]: 107 | """ 108 | Starts the Rust Analyzer Language Server, waits for the server to be ready and yields the LanguageServer instance. 109 | 110 | Usage: 111 | ``` 112 | async with lsp.start_server(): 113 | # LanguageServer has been initialized and ready to serve requests 114 | await lsp.request_definition(...) 115 | await lsp.request_references(...) 116 | # Shutdown the LanguageServer on exit from scope 117 | # LanguageServer has been shutdown 118 | """ 119 | 120 | async def register_capability_handler(params): 121 | assert "registrations" in params 122 | for registration in params["registrations"]: 123 | if registration["method"] == "workspace/executeCommand": 124 | self.initialize_searcher_command_available.set() 125 | self.resolve_main_method_available.set() 126 | return 127 | 128 | async def lang_status_handler(params): 129 | # TODO: Should we wait for 130 | # server -> client: {'jsonrpc': '2.0', 'method': 'language/status', 'params': {'type': 'ProjectStatus', 'message': 'OK'}} 131 | # Before proceeding? 
132 | if params["type"] == "ServiceReady" and params["message"] == "ServiceReady": 133 | self.service_ready_event.set() 134 | 135 | async def execute_client_command_handler(params): 136 | return [] 137 | 138 | async def do_nothing(params): 139 | return 140 | 141 | async def check_experimental_status(params): 142 | if params["quiescent"] == True: 143 | self.server_ready.set() 144 | 145 | async def window_log_message(msg): 146 | self.logger.log(f"LSP: window/logMessage: {msg}", logging.INFO) 147 | 148 | self.server.on_request("client/registerCapability", register_capability_handler) 149 | self.server.on_notification("language/status", lang_status_handler) 150 | self.server.on_notification("window/logMessage", window_log_message) 151 | self.server.on_request("workspace/executeClientCommand", execute_client_command_handler) 152 | self.server.on_notification("$/progress", do_nothing) 153 | self.server.on_notification("textDocument/publishDiagnostics", do_nothing) 154 | self.server.on_notification("language/actionableNotification", do_nothing) 155 | self.server.on_notification("experimental/serverStatus", check_experimental_status) 156 | 157 | async with super().start_server(): 158 | self.logger.log("Starting RustAnalyzer server process", logging.INFO) 159 | await self.server.start() 160 | initialize_params = self._get_initialize_params(self.repository_root_path) 161 | 162 | self.logger.log( 163 | "Sending initialize request from LSP client to LSP server and awaiting response", 164 | logging.INFO, 165 | ) 166 | init_response = await self.server.send.initialize(initialize_params) 167 | assert init_response["capabilities"]["textDocumentSync"]["change"] == 2 168 | assert "completionProvider" in init_response["capabilities"] 169 | assert init_response["capabilities"]["completionProvider"] == { 170 | "resolveProvider": True, 171 | "triggerCharacters": [":", ".", "'", "("], 172 | "completionItem": {"labelDetailsSupport": True}, 173 | } 174 | self.server.notify.initialized({}) 175 | self.completions_available.set() 176 | 177 | await self.server_ready.wait() 178 | 179 | yield self 180 | 181 | await self.server.shutdown() 182 | await self.server.stop() 183 | -------------------------------------------------------------------------------- /src/monitors4codegen/multilspy/lsp_protocol_handler/lsp_constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains constants used in the LSP protocol. 3 | """ 4 | 5 | class LSPConstants: 6 | """ 7 | This class contains constants used in the LSP protocol. 8 | """ 9 | 10 | # the key for uri used to represent paths 11 | URI = "uri" 12 | 13 | # the key for range, which is a from and to position within a text document 14 | RANGE = "range" 15 | 16 | # A key used in LocationLink type, used as the span of the origin link 17 | ORIGIN_SELECTION_RANGE = "originSelectionRange" 18 | 19 | # A key used in LocationLink type, used as the target uri of the link 20 | TARGET_URI = "targetUri" 21 | 22 | # A key used in LocationLink type, used as the target range of the link 23 | TARGET_RANGE = "targetRange" 24 | 25 | # A key used in LocationLink type, used as the target selection range of the link 26 | TARGET_SELECTION_RANGE = "targetSelectionRange" 27 | 28 | # key for the textDocument field in the request 29 | TEXT_DOCUMENT = "textDocument" 30 | 31 | # key used to represent the language a document is in - "java", "csharp", etc. 
32 | LANGUAGE_ID = "languageId" 33 | 34 | # key used to represent the version of a document (a shared value between the client and server) 35 | VERSION = "version" 36 | 37 | # key used to represent the text of a document being sent from the client to the server on open 38 | TEXT = "text" 39 | 40 | # key used to represent a position (line and column number) within a text document 41 | POSITION = "position" 42 | 43 | # key used to represent the line number of a position 44 | LINE = "line" 45 | 46 | # key used to represent the column number of a position 47 | CHARACTER = "character" 48 | 49 | # key used to represent the changes made to a document 50 | CONTENT_CHANGES = "contentChanges" 51 | 52 | # key used to represent the name of symbols 53 | NAME = "name" 54 | 55 | # key used to represent the kind of symbols 56 | KIND = "kind" 57 | 58 | # key used to represent children in document symbols 59 | CHILDREN = "children" 60 | -------------------------------------------------------------------------------- /src/monitors4codegen/multilspy/multilspy_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration parameters for Multilspy. 3 | """ 4 | 5 | from enum import Enum 6 | from dataclasses import dataclass 7 | 8 | class Language(str, Enum): 9 | """ 10 | Languages supported by Multilspy. 11 | """ 12 | 13 | CSHARP = "csharp" 14 | PYTHON = "python" 15 | RUST = "rust" 16 | JAVA = "java" 17 | 18 | def __str__(self) -> str: 19 | return self.value 20 | 21 | @dataclass 22 | class MultilspyConfig: 23 | """ 24 | Configuration parameters 25 | """ 26 | code_language: Language 27 | trace_lsp_communication: bool = False 28 | 29 | @classmethod 30 | def from_dict(cls, env: dict): 31 | """ 32 | Create a MultilspyConfig instance from a dictionary 33 | """ 34 | import inspect 35 | return cls(**{ 36 | k: v for k, v in env.items() 37 | if k in inspect.signature(cls).parameters 38 | }) -------------------------------------------------------------------------------- /src/monitors4codegen/multilspy/multilspy_exceptions.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the exceptions raised by the Multilspy framework. 3 | """ 4 | 5 | class MultilspyException(Exception): 6 | """ 7 | Exceptions raised by the Multilspy framework. 8 | """ 9 | 10 | def __init__(self, message: str): 11 | """ 12 | Initializes the exception with the given message. 13 | """ 14 | super().__init__(message) -------------------------------------------------------------------------------- /src/monitors4codegen/multilspy/multilspy_logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multilspy logger module.
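Each call to MultilspyLogger.log emits one JSON object per record, for example (illustrative values only): {"time": "2024-01-01 12:00:00", "level": "INFO", "caller_file": "language_server.py", "caller_name": "start_server", "caller_line": 42, "message": "Starting RustAnalyzer server process"}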
3 | """ 4 | import inspect 5 | import logging 6 | from datetime import datetime 7 | from pydantic import BaseModel 8 | 9 | class LogLine(BaseModel): 10 | """ 11 | Represents a line in the Multilspy log 12 | """ 13 | 14 | time: str 15 | level: str 16 | caller_file: str 17 | caller_name: str 18 | caller_line: int 19 | message: str 20 | 21 | class MultilspyLogger: 22 | """ 23 | Logger class 24 | """ 25 | 26 | def __init__(self) -> None: 27 | self.logger = logging.getLogger("multilspy") 28 | self.logger.setLevel(logging.INFO) 29 | 30 | def log(self, debug_message: str, level: int, sanitized_error_message: str = "") -> None: 31 | """ 32 | Log the debug and santized messages using the logger 33 | """ 34 | 35 | debug_message = debug_message.replace("'", '"').replace("\n", " ") 36 | sanitized_error_message = sanitized_error_message.replace("'", '"').replace("\n", " ") 37 | 38 | # Collect details about the callee 39 | curframe = inspect.currentframe() 40 | calframe = inspect.getouterframes(curframe, 2) 41 | caller_file = calframe[1][1].split("/")[-1] 42 | caller_line = calframe[1][2] 43 | caller_name = calframe[1][3] 44 | 45 | # Construct the debug log line 46 | debug_log_line = LogLine( 47 | time=str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), 48 | level=logging.getLevelName(level), 49 | caller_file=caller_file, 50 | caller_name=caller_name, 51 | caller_line=caller_line, 52 | message=debug_message, 53 | ) 54 | 55 | self.logger.log( 56 | level=level, 57 | msg=debug_log_line.json(), 58 | ) 59 | -------------------------------------------------------------------------------- /src/monitors4codegen/multilspy/multilspy_settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines the settings for multilspy. 3 | """ 4 | 5 | import os 6 | import pathlib 7 | 8 | class MultilspySettings: 9 | """ 10 | Provides the various settings for multilspy. 11 | """ 12 | @staticmethod 13 | def get_language_server_directory() -> str: 14 | """Returns the directory for language servers""" 15 | user_home = pathlib.Path.home() 16 | multilspy_dir = str(pathlib.PurePath(user_home, ".multilspy")) 17 | lsp_dir = str(pathlib.PurePath(multilspy_dir, "lsp")) 18 | os.makedirs(lsp_dir, exist_ok=True) 19 | return lsp_dir 20 | 21 | @staticmethod 22 | def get_global_cache_directory() -> str: 23 | """Returns the cache directory""" 24 | global_cache_dir = os.path.join(str(pathlib.Path.home()), ".multilspy", "global_cache") 25 | os.makedirs(global_cache_dir, exist_ok=True) 26 | return global_cache_dir 27 | -------------------------------------------------------------------------------- /src/monitors4codegen/multilspy/multilspy_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains various utility functions like I/O operations, handling paths, etc. 3 | """ 4 | 5 | import gzip 6 | import logging 7 | import os 8 | from typing import Tuple 9 | import requests 10 | import shutil 11 | import uuid 12 | 13 | import platform 14 | import subprocess 15 | from enum import Enum 16 | 17 | from monitors4codegen.multilspy.multilspy_exceptions import MultilspyException 18 | from pathlib import PurePath, Path 19 | from monitors4codegen.multilspy.multilspy_logger import MultilspyLogger 20 | 21 | class TextUtils: 22 | """ 23 | Utilities for text operations. 
24 | """ 25 | @staticmethod 26 | def get_line_col_from_index(text: str, index: int) -> Tuple[int, int]: 27 | """ 28 | Returns the zero-indexed line and column number of the given index in the given text 29 | """ 30 | l = 0 31 | c = 0 32 | idx = 0 33 | while idx < index: 34 | if text[idx] == '\n': 35 | l += 1 36 | c = 0 37 | else: 38 | c += 1 39 | idx += 1 40 | 41 | return l, c 42 | 43 | @staticmethod 44 | def get_index_from_line_col(text: str, line: int, col: int) -> int: 45 | """ 46 | Returns the index of the given zero-indexed line and column number in the given text 47 | """ 48 | idx = 0 49 | while line > 0: 50 | assert idx < len(text), (idx, len(text), text) 51 | if text[idx] == "\n": 52 | line -= 1 53 | idx += 1 54 | idx += col 55 | return idx 56 | 57 | @staticmethod 58 | def get_updated_position_from_line_and_column_and_edit(l: int, c: int, text_to_be_inserted: str) -> Tuple[int, int]: 59 | """ 60 | Utility function to get the position of the cursor after inserting text at a given line and column. 61 | """ 62 | num_newlines_in_gen_text = text_to_be_inserted.count('\n') 63 | if num_newlines_in_gen_text > 0: 64 | l += num_newlines_in_gen_text 65 | c = len(text_to_be_inserted.split('\n')[-1]) 66 | else: 67 | c += len(text_to_be_inserted) 68 | return (l, c) 69 | 70 | class PathUtils: 71 | """ 72 | Utilities for platform-agnostic path operations. 73 | """ 74 | @staticmethod 75 | def uri_to_path(uri: str) -> str: 76 | """ 77 | Converts a URI to a file path. Works on both Linux and Windows. 78 | 79 | This method was obtained from https://stackoverflow.com/a/61922504 80 | """ 81 | try: 82 | from urllib.parse import urlparse, unquote 83 | from urllib.request import url2pathname 84 | except ImportError: 85 | # backwards compatability 86 | from urlparse import urlparse 87 | from urllib import unquote, url2pathname 88 | parsed = urlparse(uri) 89 | host = "{0}{0}{mnt}{0}".format(os.path.sep, mnt=parsed.netloc) 90 | return os.path.normpath(os.path.join(host, url2pathname(unquote(parsed.path)))) 91 | 92 | class FileUtils: 93 | """ 94 | Utility functions for file operations. 95 | """ 96 | 97 | @staticmethod 98 | def read_file(logger: MultilspyLogger, file_path: str) -> str: 99 | """ 100 | Reads the file at the given path and returns the contents as a string. 
101 | """ 102 | encodings = ["utf-8-sig", "utf-16"] 103 | try: 104 | for encoding in encodings: 105 | try: 106 | with open(file_path, "r", encoding=encoding) as inp_file: 107 | return inp_file.read() 108 | except UnicodeError: 109 | continue 110 | except Exception as exc: 111 | logger.log(f"File read '{file_path}' failed: {exc}", logging.ERROR) 112 | raise MultilspyException("File read failed.") from None 113 | logger.log(f"File read '{file_path}' failed: Unsupported encoding.", logging.ERROR) 114 | raise MultilspyException(f"File read '{file_path}' failed: Unsupported encoding.") from None 115 | 116 | @staticmethod 117 | def download_file(logger: MultilspyLogger, url: str, target_path: str) -> None: 118 | """ 119 | Downloads the file from the given URL to the given {target_path} 120 | """ 121 | try: 122 | response = requests.get(url, stream=True, timeout=60) 123 | if response.status_code != 200: 124 | logger.log(f"Error downloading file '{url}': {response.status_code} {response.text}", logging.ERROR) 125 | raise MultilspyException("Error downoading file.") 126 | with open(target_path, "wb") as f: 127 | shutil.copyfileobj(response.raw, f) 128 | except Exception as exc: 129 | logger.log(f"Error downloading file '{url}': {exc}", logging.ERROR) 130 | raise MultilspyException("Error downoading file.") from None 131 | 132 | @staticmethod 133 | def download_and_extract_archive(logger: MultilspyLogger, url: str, target_path: str, archive_type: str) -> None: 134 | """ 135 | Downloads the archive from the given URL having format {archive_type} and extracts it to the given {target_path} 136 | """ 137 | try: 138 | tmp_files = [] 139 | tmp_file_name = str(PurePath(os.path.expanduser("~"), "multilspy_tmp", uuid.uuid4().hex)) 140 | tmp_files.append(tmp_file_name) 141 | os.makedirs(os.path.dirname(tmp_file_name), exist_ok=True) 142 | FileUtils.download_file(logger, url, tmp_file_name) 143 | if archive_type in ["zip", "tar", "gztar", "bztar", "xztar"]: 144 | assert os.path.isdir(target_path) 145 | shutil.unpack_archive(tmp_file_name, target_path, archive_type) 146 | elif archive_type == "zip.gz": 147 | assert os.path.isdir(target_path) 148 | tmp_file_name_ungzipped = tmp_file_name + ".zip" 149 | tmp_files.append(tmp_file_name_ungzipped) 150 | with gzip.open(tmp_file_name, "rb") as f_in, open(tmp_file_name_ungzipped, "wb") as f_out: 151 | shutil.copyfileobj(f_in, f_out) 152 | shutil.unpack_archive(tmp_file_name_ungzipped, target_path, "zip") 153 | elif archive_type == "gz": 154 | with gzip.open(tmp_file_name, "rb") as f_in, open(target_path, "wb") as f_out: 155 | shutil.copyfileobj(f_in, f_out) 156 | else: 157 | logger.log(f"Unknown archive type '{archive_type}' for extraction", logging.ERROR) 158 | raise MultilspyException(f"Unknown archive type '{archive_type}'") 159 | except Exception as exc: 160 | logger.log(f"Error extracting archive '{tmp_file_name}' obtained from '{url}': {exc}", logging.ERROR) 161 | raise MultilspyException("Error extracting archive.") from exc 162 | finally: 163 | for tmp_file_name in tmp_files: 164 | if os.path.exists(tmp_file_name): 165 | Path.unlink(Path(tmp_file_name)) 166 | 167 | class PlatformId(str, Enum): 168 | """ 169 | multilspy supported platforms 170 | """ 171 | WIN_x86 = "win-x86" 172 | WIN_x64 = "win-x64" 173 | WIN_arm64 = "win-arm64" 174 | OSX = "osx" 175 | OSX_x64 = "osx-x64" 176 | OSX_arm64 = "osx-arm64" 177 | LINUX_x86 = "linux-x86" 178 | LINUX_x64 = "linux-x64" 179 | LINUX_arm64 = "linux-arm64" 180 | LINUX_MUSL_x64 = "linux-musl-x64" 181 | LINUX_MUSL_arm64 = 
"linux-musl-arm64" 182 | 183 | class DotnetVersion(str, Enum): 184 | """ 185 | multilspy supported dotnet versions 186 | """ 187 | V4 = "4" 188 | V6 = "6" 189 | V7 = "7" 190 | VMONO = "mono" 191 | 192 | class PlatformUtils: 193 | """ 194 | This class provides utilities for platform detection and identification. 195 | """ 196 | 197 | @staticmethod 198 | def get_platform_id() -> PlatformId: 199 | """ 200 | Returns the platform id for the current system 201 | """ 202 | system = platform.system() 203 | machine = platform.machine() 204 | bitness = platform.architecture()[0] 205 | system_map = {"Windows": "win", "Darwin": "osx", "Linux": "linux"} 206 | machine_map = {"AMD64": "x64", "x86_64": "x64", "i386": "x86", "i686": "x86", "aarch64": "arm64"} 207 | if system in system_map and machine in machine_map: 208 | platform_id = system_map[system] + "-" + machine_map[machine] 209 | if system == "Linux" and bitness == "64bit": 210 | libc = platform.libc_ver()[0] 211 | if libc != 'glibc': 212 | platform_id += "-" + libc 213 | return PlatformId(platform_id) 214 | else: 215 | raise MultilspyException("Unknown platform: " + system + " " + machine + " " + bitness) 216 | 217 | @staticmethod 218 | def get_dotnet_version() -> DotnetVersion: 219 | """ 220 | Returns the dotnet version for the current system 221 | """ 222 | try: 223 | result = subprocess.run(["dotnet", "--list-runtimes"], capture_output=True, check=True) 224 | version = '' 225 | for line in result.stdout.decode('utf-8').split('\n'): 226 | if line.startswith('Microsoft.NETCore.App'): 227 | version = line.split(' ')[1] 228 | break 229 | if version == '': 230 | raise MultilspyException("dotnet not found on the system") 231 | if version.startswith("7"): 232 | return DotnetVersion.V7 233 | elif version.startswith("6"): 234 | return DotnetVersion.V6 235 | elif version.startswith("4"): 236 | return DotnetVersion.V4 237 | else: 238 | raise MultilspyException("Unknown dotnet version: " + version) 239 | except subprocess.CalledProcessError: 240 | try: 241 | result = subprocess.run(["mono", "--version"], capture_output=True, check=True) 242 | return DotnetVersion.VMONO 243 | except subprocess.CalledProcessError: 244 | raise MultilspyException("dotnet or mono not found on the system") 245 | -------------------------------------------------------------------------------- /src/monitors4codegen/multilspy/type_helpers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides type-helpers used across multilspy implementation 3 | """ 4 | 5 | import inspect 6 | 7 | from typing import Callable, TypeVar, Type 8 | 9 | R = TypeVar("R", bound=object) 10 | 11 | def ensure_all_methods_implemented( 12 | source_cls: Type[object], 13 | ) -> Callable[[Type[R]], Type[R]]: 14 | """ 15 | A decorator to ensure that all methods of source_cls class are implemented in the decorated class. 
16 | """ 17 | 18 | def check_all_methods_implemented(target_cls: R) -> R: 19 | for name, _ in inspect.getmembers(source_cls, inspect.isfunction): 20 | if name not in target_cls.__dict__ or not callable(target_cls.__dict__[name]): 21 | raise NotImplementedError(f"{name} is not implemented in {target_cls}") 22 | 23 | return target_cls 24 | 25 | return check_all_methods_implemented -------------------------------------------------------------------------------- /tests/monitor_guided_decoding/test_classinstantiation_monitor_java.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains tests for Monitor-Guided Decoding for valid class instantiations in Java 3 | """ 4 | 5 | import torch 6 | import transformers 7 | import pytest 8 | 9 | from pathlib import PurePath 10 | from monitors4codegen.multilspy.language_server import SyncLanguageServer 11 | from monitors4codegen.multilspy.multilspy_config import Language 12 | from tests.test_utils import create_test_context, is_cuda_available 13 | from transformers import AutoTokenizer, AutoModelForCausalLM 14 | from monitors4codegen.multilspy.multilspy_utils import TextUtils 15 | from monitors4codegen.monitor_guided_decoding.monitors.class_instantiation_monitor import ClassInstantiationMonitor 16 | from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer 17 | from monitors4codegen.monitor_guided_decoding.hf_gen import MGDLogitsProcessor 18 | from transformers.generation.utils import LogitsProcessorList 19 | from monitors4codegen.multilspy.multilspy_types import Position 20 | from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper 21 | 22 | pytest_plugins = ("pytest_asyncio",) 23 | 24 | @pytest.mark.asyncio 25 | async def test_multilspy_java_example_repo_class_instantiation() -> None: 26 | """ 27 | Test the working of ClassInstantiationMonitor with Java repository - ExampleRepo 28 | """ 29 | code_language = Language.JAVA 30 | params = { 31 | "code_language": code_language, 32 | "repo_url": "https://github.com/LakshyAAAgrawal/ExampleRepo/", 33 | "repo_commit": "f3762fd55a457ff9c6b0bf3b266de2b203a766ab", 34 | } 35 | 36 | device = torch.device('cuda' if is_cuda_available() else 'cpu') 37 | 38 | model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained( 39 | "bigcode/santacoder", trust_remote_code=True 40 | ).to(device) 41 | tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder") 42 | 43 | with create_test_context(params) as context: 44 | lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory) 45 | with lsp.start_server(): 46 | completions_filepath = "Main.java" 47 | with lsp.open_file(completions_filepath): 48 | deleted_text = lsp.delete_text_between_positions( 49 | completions_filepath, 50 | Position(line=16, character=24), 51 | Position(line=36, character=5) 52 | ) 53 | assert deleted_text == """Student("Alice", 10); 54 | Person p2 = new Teacher("Bob", "Science"); 55 | 56 | // Create some course objects 57 | Course c1 = new Course("Math 101", t1, mathStudents); 58 | Course c2 = new Course("English 101", t2, englishStudents); 59 | 60 | // Print some information about the objects 61 | 62 | System.out.println("Person p1's name is " + p1.getName()); 63 | 64 | System.out.println("Student s1's name is " + s1.getName()); 65 | System.out.println("Student s1's id is " + s1.getId()); 66 | 67 | System.out.println("Teacher t1's name is " + t1.getName()); 68 | 
System.out.println("Teacher t1's subject is " + t1.getSubject()); 69 | 70 | System.out.println("Course c1's name is " + c1.getName()); 71 | System.out.println("Course c1's teacher is " + c1.getTeacher().getName()); 72 | 73 | """ 74 | prompt_pos = (16, 24) 75 | 76 | with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f: 77 | filecontent = f.read() 78 | 79 | pos_idx = TextUtils.get_index_from_line_col(filecontent, prompt_pos[0], prompt_pos[1]) 80 | assert filecontent[:pos_idx].endswith('new ') 81 | 82 | prompt = filecontent[:pos_idx] 83 | assert filecontent[pos_idx-1] == " " 84 | prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").cuda()[:, -(2048 - 512) :] 85 | 86 | generated_code_without_mgd = model.generate( 87 | prompt_tokenized, do_sample=False, max_new_tokens=30, early_stopping=True 88 | ) 89 | generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -30:]) 90 | 91 | assert ( 92 | generated_code_without_mgd 93 | == " Person(\"John\", \"Doe\", \"123-4567\", \"kenaa@example.com\", \"1234" 94 | ) 95 | 96 | filebuffer = MonitorFileBuffer( 97 | lsp.language_server, 98 | completions_filepath, 99 | prompt_pos, 100 | prompt_pos, 101 | code_language, 102 | ) 103 | monitor = ClassInstantiationMonitor(HFTokenizerWrapper(tokenizer), filebuffer) 104 | mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop) 105 | 106 | # Generate code using santacoder model with the MGD logits processor and greedy decoding 107 | logits_processor = LogitsProcessorList([mgd_logits_processor]) 108 | generated_code = model.generate( 109 | prompt_tokenized, 110 | do_sample=False, 111 | max_new_tokens=30, 112 | logits_processor=logits_processor, 113 | early_stopping=True, 114 | ) 115 | 116 | generated_code = tokenizer.decode(generated_code[0, -30:]) 117 | 118 | assert ( 119 | generated_code 120 | == "Student(\"John\", 1001);\n Person p2 = new Student(\"Mary\", 1002);\n Person p" 121 | ) 122 | -------------------------------------------------------------------------------- /tests/monitor_guided_decoding/test_dereferences_monitor_java_openai.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains tests for Monitor-Guided Decoding for dereferences in Java, with OpenAI models 3 | """ 4 | import os 5 | import pytest 6 | import tiktoken 7 | import torch 8 | 9 | from openai import OpenAI 10 | from pathlib import PurePath 11 | from monitors4codegen.multilspy.language_server import SyncLanguageServer 12 | from monitors4codegen.multilspy.multilspy_config import Language 13 | from tests.test_utils import create_test_context 14 | from monitors4codegen.multilspy.multilspy_utils import TextUtils 15 | from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor 16 | from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer 17 | from monitors4codegen.multilspy.multilspy_types import Position 18 | from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TikTokenWrapper 19 | from monitors4codegen.monitor_guided_decoding.openai_gen import openai_mgd, OpenAI_Models 20 | 21 | @pytest.mark.asyncio 22 | @pytest.mark.skipif(not "OPENAI_API_KEY" in os.environ, reason="OpenAI API key not found") 23 | async def test_dereferences_monitor_java_openai_clickhouse_highlevel_sinker(): 24 | """ 25 | Test the working of dereferences monitor with Java repository - clickhouse-highlevel-sinker modified 26 | """ 27 | code_language = 
Language.JAVA 28 | params = { 29 | "code_language": code_language, 30 | "repo_url": "https://github.com/Index103000/clickhouse-highlevel-sinker/", 31 | "repo_commit": "ee31d278918fe5e64669a6840c4d8fb53889e573", 32 | } 33 | 34 | client = OpenAI() 35 | client.api_key = os.environ["OPENAI_API_KEY"] 36 | 37 | encoding = tiktoken.encoding_for_model("text-davinci-003") 38 | tokenizer = TikTokenWrapper(encoding) 39 | 40 | with create_test_context(params) as context: 41 | lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory) 42 | completions_filepath = "src/main/java/com/xlvchao/clickhouse/datasource/ClickHouseDataSource.java" 43 | # All the communication with the language server must be performed inside the context manager 44 | # The server process is started when the context manager is entered and is terminated when the context manager is exited. 45 | with lsp.start_server(): 46 | with lsp.open_file(completions_filepath): 47 | filebuffer = MonitorFileBuffer(lsp.language_server, completions_filepath, (74, 17), (74, 17), code_language, "") 48 | deleted_text = filebuffer.lsp.delete_text_between_positions( 49 | completions_filepath, Position(line=74, character=17), Position(line=78, character=4) 50 | ) 51 | assert ( 52 | deleted_text 53 | == """newServerNode() 54 | .withIp(arr[0]) 55 | .withPort(Integer.parseInt(arr[1])) 56 | .build(); 57 | """ 58 | ) 59 | 60 | prompt_pos = (74, 17) 61 | 62 | with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f: 63 | filecontent = f.read() 64 | 65 | pos_idx = TextUtils.get_index_from_line_col(filecontent, prompt_pos[0], prompt_pos[1]) 66 | assert filecontent[pos_idx] == "n" 67 | prompt = filecontent[:pos_idx] 68 | assert prompt[-1] == "." 69 | prompt_tokenized = encoding.encode(prompt)[-(2048 - 512):] 70 | 71 | completions = client.completions.create( 72 | model="text-davinci-003", 73 | prompt=tokenizer.decode(torch.Tensor(prompt_tokenized).long(), skip_special_tokens=True, clean_up_tokenization_spaces=False), 74 | max_tokens=30, 75 | temperature=0 76 | ) 77 | 78 | generated_code_without_mgd = completions.choices[0].text 79 | 80 | assert ( 81 | generated_code_without_mgd == 82 | 'newBuilder()\n .host(arr[0])\n .port(Integer.parseInt(arr[1]))\n .' 
83 | ) 84 | 85 | filebuffer = MonitorFileBuffer( 86 | lsp.language_server, 87 | completions_filepath, 88 | prompt_pos, 89 | prompt_pos, 90 | code_language, 91 | ) 92 | monitor = DereferencesMonitor(tokenizer, filebuffer) 93 | 94 | gen_tokens, generated_code = openai_mgd( 95 | client=client, 96 | model=OpenAI_Models.TD3, 97 | tokenizer=tokenizer, 98 | prompt_tokenized=prompt_tokenized, 99 | temp=0, 100 | top_p=0.95, 101 | monitor=monitor, 102 | num_new_tokens=30, 103 | ) 104 | 105 | assert ( 106 | generated_code == 107 | 'newServerNode()\n .withIp(arr[0])\n .withPort(Integer.parseInt(arr[1]' 108 | ) 109 | 110 | @pytest.mark.asyncio 111 | @pytest.mark.skipif(not "OPENAI_API_KEY" in os.environ, reason="OpenAI API key not found") 112 | async def test_dereferences_monitor_java_openai_clickhouse_highlevel_sinker_modified(): 113 | """ 114 | Test the working of dereferences monitor with Java repository - clickhouse-highlevel-sinker modified 115 | """ 116 | code_language = Language.JAVA 117 | params = { 118 | "code_language": code_language, 119 | "repo_url": "https://github.com/LakshyAAAgrawal/clickhouse-highlevel-sinker/", 120 | "repo_commit": "5775fd7a67e7b60998e1614cf44a8a1fc3190ab0" 121 | } 122 | 123 | client = OpenAI() 124 | client.api_key = os.environ["OPENAI_API_KEY"] 125 | 126 | encoding = tiktoken.encoding_for_model("text-davinci-003") 127 | tokenizer = TikTokenWrapper(encoding) 128 | 129 | with create_test_context(params) as context: 130 | lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory) 131 | # All the communication with the language server must be performed inside the context manager 132 | # The server process is started when the context manager is entered and is terminated when the context manager is exited. 133 | # SyncLanguageServer.start_server is a synchronous context manager, so it is used with a plain with statement. 134 | with lsp.start_server(): 135 | completions_filepath = "src/main/java/com/xlvchao/clickhouse/datasource/ClickHouseDataSource.java" 136 | with lsp.open_file(completions_filepath): 137 | deleted_text = lsp.delete_text_between_positions( 138 | completions_filepath, 139 | Position(line=75, character=17), 140 | Position(line=77, character=4) 141 | ) 142 | assert deleted_text == """withIpPort(arr[0], Integer.parseInt(arr[1])) 143 | .build(); 144 | """ 145 | 146 | prompt_pos = (75, 17) 147 | 148 | with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f: 149 | filecontent = f.read() 150 | 151 | pos_idx = TextUtils.get_index_from_line_col(filecontent, prompt_pos[0], prompt_pos[1]) 152 | assert filecontent[pos_idx] == "w" 153 | prompt = filecontent[:pos_idx] 154 | assert prompt[-1] == "."
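# Keep only the trailing 2048 - 512 = 1536 prompt tokens, so that the prompt plus up to 512 newly generated tokens fit within a 2048-token budget (an explanatory note inferred from the constants on the next line; the budget is not documented in the source).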
155 | prompt_tokenized = encoding.encode(prompt)[-(2048 - 512):] 156 | 157 | completions = client.completions.create( 158 | model="text-davinci-003", 159 | prompt=tokenizer.decode(torch.Tensor(prompt_tokenized).long(), skip_special_tokens=True, clean_up_tokenization_spaces=False), 160 | max_tokens=30, 161 | temperature=0 162 | ) 163 | 164 | generated_code_without_mgd = completions.choices[0].text 165 | 166 | assert ( 167 | generated_code_without_mgd == 168 | 'host(arr[0])\n .port(Integer.parseInt(arr[1]))\n .build();\n }\n\n' 169 | ) 170 | 171 | filebuffer = MonitorFileBuffer( 172 | lsp.language_server, 173 | completions_filepath, 174 | prompt_pos, 175 | prompt_pos, 176 | code_language, 177 | ) 178 | monitor = DereferencesMonitor(tokenizer, filebuffer) 179 | 180 | gen_tokens, generated_code = openai_mgd( 181 | client=client, 182 | model=OpenAI_Models.TD3, 183 | tokenizer=tokenizer, 184 | prompt_tokenized=prompt_tokenized, 185 | temp=0, 186 | top_p=0.95, 187 | monitor=monitor, 188 | num_new_tokens=30, 189 | ) 190 | 191 | assert ( 192 | generated_code == 193 | 'withIpPort(arr[0], Integer.parseInt(arr[1]))\n .build();\n }\n\n private' 194 | ) 195 | -------------------------------------------------------------------------------- /tests/monitor_guided_decoding/test_joint_monitors.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains tests for Monitor-Guided Decoding running 2 monitors simultaneously 3 | """ 4 | 5 | import torch 6 | import transformers 7 | import pytest 8 | 9 | from pathlib import PurePath 10 | from monitors4codegen.multilspy.language_server import SyncLanguageServer 11 | from monitors4codegen.multilspy.multilspy_config import Language 12 | from tests.test_utils import create_test_context, is_cuda_available 13 | from transformers import AutoTokenizer, AutoModelForCausalLM 14 | from monitors4codegen.multilspy.multilspy_utils import TextUtils 15 | from monitors4codegen.monitor_guided_decoding.monitors.switch_enum_monitor import SwitchEnumMonitor 16 | from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor 17 | from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer 18 | from monitors4codegen.monitor_guided_decoding.hf_gen import MGDLogitsProcessor 19 | from transformers.generation.utils import LogitsProcessorList 20 | from monitors4codegen.multilspy.multilspy_types import Position 21 | from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper 22 | 23 | pytest_plugins = ("pytest_asyncio",) 24 | 25 | @pytest.mark.asyncio 26 | @pytest.mark.skip(reason="TODO: This runs too slow. 
Reimplement joint monitoring") 27 | async def test_multilspy_csharp_ryujinx_joint_switch_enum_dereferences() -> None: 28 | """ 29 | Test the working of Joint monitoring with SwitchEnumMonitor and DereferencesMonitor with C# repository - Ryujinx 30 | """ 31 | 32 | code_language = Language.CSHARP 33 | params = { 34 | "code_language": code_language, 35 | "repo_url": "https://github.com/Ryujinx/Ryujinx/", 36 | "repo_commit": "e768a54f17b390c3ac10904c7909e3bef020edbd" 37 | } 38 | 39 | device = torch.device('cuda' if is_cuda_available() else 'cpu') 40 | 41 | model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained( 42 | "bigcode/santacoder", trust_remote_code=True 43 | ).to(device) 44 | tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder") 45 | 46 | with create_test_context(params) as context: 47 | lsp1 = SyncLanguageServer.create(context.config, context.logger, context.source_directory) 48 | lsp2 = SyncLanguageServer.create(context.config, context.logger, context.source_directory) 49 | with lsp1.start_server(), lsp2.start_server(): 50 | completions_filepath = "src/ARMeilleure/CodeGen/X86/CodeGenerator.cs" 51 | with lsp1.open_file(completions_filepath), lsp2.open_file(completions_filepath): 52 | deleted_text1 = lsp1.delete_text_between_positions( 53 | completions_filepath, 54 | Position(line=224, character=37), 55 | Position(line=243, character=28) 56 | ) 57 | deleted_text2 = lsp2.delete_text_between_positions( 58 | completions_filepath, 59 | Position(line=224, character=37), 60 | Position(line=243, character=28) 61 | ) 62 | assert deleted_text1 == deleted_text2 63 | assert deleted_text1 == """Intrinsic.X86Comisdlt: 64 | context.Assembler.Comisd(src1, src2); 65 | context.Assembler.Setcc(dest, X86Condition.Below); 66 | break; 67 | 68 | case Intrinsic.X86Comisseq: 69 | context.Assembler.Comiss(src1, src2); 70 | context.Assembler.Setcc(dest, X86Condition.Equal); 71 | break; 72 | 73 | case Intrinsic.X86Comissge: 74 | context.Assembler.Comiss(src1, src2); 75 | context.Assembler.Setcc(dest, X86Condition.AboveOrEqual); 76 | break; 77 | 78 | case Intrinsic.X86Comisslt: 79 | context.Assembler.Comiss(src1, src2); 80 | context.Assembler.Setcc(dest, X86Condition.Below); 81 | break; 82 | """ 83 | filebuffer_enum = MonitorFileBuffer( 84 | lsp1.language_server, 85 | completions_filepath, 86 | (224, 37), 87 | (224, 37), 88 | code_language, 89 | ) 90 | monitor_switch_enum = SwitchEnumMonitor(HFTokenizerWrapper(tokenizer), filebuffer_enum) 91 | mgd_logits_processor_switch_enum = MGDLogitsProcessor([monitor_switch_enum], lsp1.language_server.server.loop) 92 | 93 | filebuffer_dereferences = MonitorFileBuffer( 94 | lsp2.language_server, 95 | completions_filepath, 96 | (224, 37), 97 | (224, 37), 98 | code_language, 99 | ) 100 | monitor_dereferences = DereferencesMonitor(HFTokenizerWrapper(tokenizer), filebuffer_dereferences) 101 | mgd_logits_processor_dereferences = MGDLogitsProcessor([monitor_dereferences], lsp2.language_server.server.loop) 102 | 103 | with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f: 104 | filecontent = f.read() 105 | 106 | pos_idx = TextUtils.get_index_from_line_col(filecontent, 224, 37) 107 | assert filecontent[:pos_idx].endswith('case ') 108 | prompt = filecontent[:pos_idx] 109 | assert prompt[-1] == " " 110 | prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").cuda()[:, -(2048 - 512) :] 111 | 112 | generated_code_without_mgd = model.generate( 113 | prompt_tokenized, do_sample=False, max_new_tokens=50, 
early_stopping=True 114 | ) 115 | generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -50:]) 116 | 117 | assert ( 118 | generated_code_without_mgd 119 | == " Intrinsic.X86Comisdgt:\n context.Assembler.Comisd(src1, src2);\n context.Assembler.Setcc(dest, X86Condition.GreaterThan);\n break;\n\n case In" 120 | ) 121 | 122 | # Generate code using santacoder model with the MGD logits processor and greedy decoding 123 | logits_processor = LogitsProcessorList([mgd_logits_processor_switch_enum, mgd_logits_processor_dereferences]) 124 | generated_code = model.generate( 125 | prompt_tokenized, 126 | do_sample=False, 127 | max_new_tokens=50, 128 | logits_processor=logits_processor, 129 | early_stopping=True, 130 | ) 131 | 132 | generated_code = tokenizer.decode(generated_code[0, -50:]) 133 | 134 | assert ( 135 | generated_code 136 | == "Intrinsic.X86Comisdlt:\n context.Assembler.Comisd(src1, src2);\n context.Assembler.Setcc(dest, X86Condition.Below);\n break;\n\n case Intrinsic" 137 | ) 138 | -------------------------------------------------------------------------------- /tests/monitor_guided_decoding/test_numargs_monitor_java.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains tests for Monitor-Guided Decoding for correct number of arguments in Java 3 | """ 4 | 5 | import torch 6 | import transformers 7 | import pytest 8 | 9 | from pathlib import PurePath 10 | from monitors4codegen.multilspy.language_server import SyncLanguageServer 11 | from monitors4codegen.multilspy.multilspy_config import Language 12 | from tests.test_utils import create_test_context, is_cuda_available 13 | from transformers import AutoTokenizer, AutoModelForCausalLM 14 | from monitors4codegen.multilspy.multilspy_utils import TextUtils 15 | from monitors4codegen.monitor_guided_decoding.monitors.numargs_monitor import NumMethodArgumentsMonitor 16 | from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer 17 | from monitors4codegen.monitor_guided_decoding.hf_gen import MGDLogitsProcessor 18 | from transformers.generation.utils import LogitsProcessorList 19 | from monitors4codegen.multilspy.multilspy_types import Position 20 | from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper 21 | 22 | pytest_plugins = ("pytest_asyncio",) 23 | 24 | @pytest.mark.asyncio 25 | async def test_multilspy_java_clickhouse_highlevel_sinker_modified_numargs(): 26 | """ 27 | Test the working of numargs_monitor with Java repository - clickhouse-highlevel-sinker modified 28 | """ 29 | code_language = Language.JAVA 30 | params = { 31 | "code_language": code_language, 32 | "repo_url": "https://github.com/LakshyAAAgrawal/clickhouse-highlevel-sinker/", 33 | "repo_commit": "5775fd7a67e7b60998e1614cf44a8a1fc3190ab0" 34 | } 35 | 36 | device = torch.device('cuda' if is_cuda_available() else 'cpu') 37 | 38 | model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained( 39 | "bigcode/santacoder", trust_remote_code=True 40 | ).to(device) 41 | tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder") 42 | 43 | with create_test_context(params) as context: 44 | lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory) 45 | # All the communication with the language server must be performed inside the context manager 46 | # The server process is started when the context manager is entered and is terminated when the context manager is exited. 
47 | # SyncLanguageServer.start_server is a synchronous context manager, so it is used with a plain with statement. 48 | with lsp.start_server(): 49 | completions_filepath = "src/main/java/com/xlvchao/clickhouse/datasource/ClickHouseDataSource.java" 50 | with lsp.open_file(completions_filepath): 51 | deleted_text = lsp.delete_text_between_positions( 52 | completions_filepath, 53 | Position(line=75, character=28), 54 | Position(line=77, character=4) 55 | ) 56 | assert deleted_text == """arr[0], Integer.parseInt(arr[1])) 57 | .build(); 58 | """ 59 | 60 | prompt_pos = (75, 28) 61 | 62 | with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f: 63 | filecontent = f.read() 64 | 65 | pos_idx = TextUtils.get_index_from_line_col(filecontent, prompt_pos[0], prompt_pos[1]) 66 | assert filecontent[pos_idx] == "a" 67 | prompt = filecontent[:pos_idx] 68 | assert prompt[-1] == "(" 69 | prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").to(device)[:, -(2048 - 512) :] 70 | 71 | gen = model.generate( 72 | prompt_tokenized, do_sample=False, max_new_tokens=30, early_stopping=True 73 | ) 74 | generated_code_without_mgd = tokenizer.decode(gen[0, -30:]) 75 | 76 | assert ( 77 | generated_code_without_mgd == 78 | "arr[0])\n .withPort(Integer.parseInt(arr[1]))\n .build();\n }\n\n private List" 79 | ) 80 | 81 | filebuffer = MonitorFileBuffer( 82 | lsp.language_server, 83 | completions_filepath, 84 | prompt_pos, 85 | prompt_pos, 86 | code_language, 87 | ) 88 | monitor = NumMethodArgumentsMonitor(HFTokenizerWrapper(tokenizer), filebuffer) 89 | mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop) 90 | 91 | # Generate code using santacoder model with the MGD logits processor and greedy decoding 92 | logits_processor = LogitsProcessorList([mgd_logits_processor]) 93 | gen = model.generate( 94 | prompt_tokenized, 95 | do_sample=False, 96 | max_new_tokens=30, 97 | logits_processor=logits_processor, 98 | early_stopping=True, 99 | ) 100 | 101 | generated_code = tokenizer.decode(gen[0, -30:]) 102 | 103 | assert ( 104 | generated_code == 105 | "arr[0].trim(), Integer.parseInt(arr[1].trim()))\n .build();\n }\n\n private List convertTo" 106 | ) 107 | -------------------------------------------------------------------------------- /tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains tests for Monitor-Guided Decoding for switch-enum in C# 3 | """ 4 | 5 | import torch 6 | import transformers 7 | import pytest 8 | 9 | from pathlib import PurePath 10 | from monitors4codegen.multilspy.language_server import SyncLanguageServer 11 | from monitors4codegen.multilspy.multilspy_config import Language 12 | from tests.test_utils import create_test_context, is_cuda_available 13 | from transformers import AutoTokenizer, AutoModelForCausalLM 14 | from monitors4codegen.multilspy.multilspy_utils import TextUtils 15 | from monitors4codegen.monitor_guided_decoding.monitors.switch_enum_monitor import SwitchEnumMonitor 16 | from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor 17 | from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer 18 | from monitors4codegen.monitor_guided_decoding.hf_gen import MGDLogitsProcessor 19 | from transformers.generation.utils import LogitsProcessorList 20 | from monitors4codegen.multilspy.multilspy_types import Position 21 | from
monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper 22 | 23 | pytest_plugins = ("pytest_asyncio",) 24 | 25 | @pytest.mark.asyncio 26 | async def test_multilspy_csharp_ryujinx_switch_enum() -> None: 27 | """ 28 | Test the working of SwitchEnumMonitor with C# repository - Ryujinx 29 | """ 30 | code_language = Language.CSHARP 31 | params = { 32 | "code_language": code_language, 33 | "repo_url": "https://github.com/Ryujinx/Ryujinx/", 34 | "repo_commit": "e768a54f17b390c3ac10904c7909e3bef020edbd" 35 | } 36 | 37 | device = torch.device('cuda' if is_cuda_available() else 'cpu') 38 | 39 | model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained( 40 | "bigcode/santacoder", trust_remote_code=True 41 | ).to(device) 42 | tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder") 43 | 44 | with create_test_context(params) as context: 45 | lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory) 46 | with lsp.start_server(): 47 | completions_filepath = "src/ARMeilleure/CodeGen/Arm64/CodeGenerator.cs" 48 | with lsp.open_file(completions_filepath): 49 | deleted_text = lsp.delete_text_between_positions( 50 | completions_filepath, 51 | Position(line=1369, character=21), 52 | Position(line=1385, character=8) 53 | ) 54 | assert deleted_text == """AccessSize.Byte: 55 | context.Assembler.Stlxrb(desired, address, result); 56 | break; 57 | case AccessSize.Hword: 58 | context.Assembler.Stlxrh(desired, address, result); 59 | break; 60 | default: 61 | context.Assembler.Stlxr(desired, address, result); 62 | break; 63 | } 64 | 65 | context.Assembler.Cbnz(result, startOffset - context.StreamOffset); // Retry if store failed. 66 | 67 | context.JumpHere(); 68 | 69 | context.Assembler.Clrex(); 70 | """ 71 | filebuffer = MonitorFileBuffer( 72 | lsp.language_server, 73 | completions_filepath, 74 | (1369, 21), 75 | (1369, 21), 76 | code_language, 77 | ) 78 | monitor = SwitchEnumMonitor(HFTokenizerWrapper(tokenizer), filebuffer) 79 | mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop) 80 | 81 | with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f: 82 | filecontent = f.read() 83 | 84 | pos_idx = TextUtils.get_index_from_line_col(filecontent, 1369, 21) 85 | assert filecontent[:pos_idx].endswith('case ') 86 | prompt = filecontent[:pos_idx] 87 | assert prompt[-1] == " " 88 | prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").cuda()[:, -(2048 - 512) :] 89 | 90 | generated_code_without_mgd = model.generate( 91 | prompt_tokenized, do_sample=False, max_new_tokens=100, early_stopping=True 92 | ) 93 | generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -100:]) 94 | 95 | assert ( 96 | generated_code_without_mgd 97 | == "1:\n context.Assembler.Stb(result, Register(ZrRegister, result.Type));\n break;\n case 2:\n context.Assembler.Stw(result, Register(ZrRegister, result.Type));\n break;\n case 4:\n context.Assembler.Std(result, Register(ZrRegister, result.Type));\n break;\n case 8:\n context.Assembler.Stq(result, Register(Zr" 98 | ) 99 | 100 | # Generate code using santacoder model with the MGD logits processor and greedy decoding 101 | logits_processor = LogitsProcessorList([mgd_logits_processor]) 102 | generated_code = model.generate( 103 | prompt_tokenized, 104 | do_sample=False, 105 | max_new_tokens=100, 106 | logits_processor=logits_processor, 107 | early_stopping=True, 108 | ) 109 | 110 | generated_code = 
tokenizer.decode(generated_code[0, -100:]) 111 | 112 | assert ( 113 | generated_code 114 | == "AccessSize.Byte:\n context.Assembler.Staxrb(actual, address);\n break;\n case AccessSize.Hword:\n context.Assembler.Staxrh(actual, address);\n break;\n default:\n context.Assembler.Staxr(actual, address);\n break;\n }\n\n context.Assembler.Cmp(actual, desired);\n\n context.JumpToNear(ArmCondition.Eq);\n\n context.Assembler.Staxr(result," 115 | ) 116 | 117 | 118 | 119 | completions_filepath = "src/ARMeilleure/CodeGen/X86/CodeGenerator.cs" 120 | with lsp.open_file(completions_filepath): 121 | deleted_text = lsp.delete_text_between_positions( 122 | completions_filepath, 123 | Position(line=224, character=37), 124 | Position(line=243, character=28) 125 | ) 126 | assert deleted_text == """Intrinsic.X86Comisdlt: 127 | context.Assembler.Comisd(src1, src2); 128 | context.Assembler.Setcc(dest, X86Condition.Below); 129 | break; 130 | 131 | case Intrinsic.X86Comisseq: 132 | context.Assembler.Comiss(src1, src2); 133 | context.Assembler.Setcc(dest, X86Condition.Equal); 134 | break; 135 | 136 | case Intrinsic.X86Comissge: 137 | context.Assembler.Comiss(src1, src2); 138 | context.Assembler.Setcc(dest, X86Condition.AboveOrEqual); 139 | break; 140 | 141 | case Intrinsic.X86Comisslt: 142 | context.Assembler.Comiss(src1, src2); 143 | context.Assembler.Setcc(dest, X86Condition.Below); 144 | break; 145 | """ 146 | filebuffer = MonitorFileBuffer( 147 | lsp.language_server, 148 | completions_filepath, 149 | (224, 37), 150 | (224, 37), 151 | code_language, 152 | ) 153 | monitor = SwitchEnumMonitor(HFTokenizerWrapper(tokenizer), filebuffer) 154 | mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop) 155 | 156 | with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f: 157 | filecontent = f.read() 158 | 159 | pos_idx = TextUtils.get_index_from_line_col(filecontent, 224, 37) 160 | assert filecontent[:pos_idx].endswith('case ') 161 | prompt = filecontent[:pos_idx] 162 | assert prompt[-1] == " " 163 | prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").cuda()[:, -(2048 - 512) :] 164 | 165 | generated_code_without_mgd = model.generate( 166 | prompt_tokenized, do_sample=False, max_new_tokens=50, early_stopping=True 167 | ) 168 | generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -50:]) 169 | 170 | assert ( 171 | generated_code_without_mgd 172 | == " Intrinsic.X86Comisdgt:\n context.Assembler.Comisd(src1, src2);\n context.Assembler.Setcc(dest, X86Condition.GreaterThan);\n break;\n\n case In" 173 | ) 174 | 175 | # Generate code using santacoder model with the MGD logits processor and greedy decoding 176 | logits_processor = LogitsProcessorList([mgd_logits_processor]) 177 | generated_code = model.generate( 178 | prompt_tokenized, 179 | do_sample=False, 180 | max_new_tokens=50, 181 | logits_processor=logits_processor, 182 | early_stopping=True, 183 | ) 184 | 185 | generated_code = tokenizer.decode(generated_code[0, -50:]) 186 | 187 | assert ( 188 | generated_code 189 | == "Intrinsic.X86Comisdlt:\n context.Assembler.Comisd(src1, src2);\n context.Assembler.Setcc(dest, X86Condition.LessThan);\n break;\n\n case Intrinsic" 190 | ) 191 | -------------------------------------------------------------------------------- /tests/monitor_guided_decoding/test_typestate_monitor_rust.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains tests for MGD for typestate 3 | """ 4 | 5 | import 
pytest 6 | import torch 7 | import transformers 8 | 9 | from pathlib import PurePath 10 | from monitors4codegen.multilspy.language_server import SyncLanguageServer 11 | from monitors4codegen.multilspy.multilspy_config import Language 12 | from tests.test_utils import create_test_context, is_cuda_available 13 | from transformers import AutoTokenizer, AutoModelForCausalLM 14 | from monitors4codegen.multilspy.multilspy_utils import TextUtils 15 | from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor 16 | from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer 17 | from monitors4codegen.monitor_guided_decoding.hf_gen import MGDLogitsProcessor 18 | from transformers.generation.utils import LogitsProcessorList 19 | from monitors4codegen.multilspy.multilspy_types import Position 20 | from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper 21 | 22 | pytest_plugins = ("pytest_asyncio",) 23 | 24 | @pytest.mark.asyncio 25 | async def test_typestate_monitor_rust_huggingface_models_mediaplayer() -> None: 26 | """ 27 | Test the working of typestate monitor with Rust repository - mediaplayer 28 | """ 29 | code_language = Language.RUST 30 | params = { 31 | "code_language": code_language, 32 | "repo_url": "https://github.com/LakshyAAAgrawal/MediaPlayer_example/", 33 | "repo_commit": "80cd910cfeb2a05c9e74b69773373c077b00b4c2", 34 | } 35 | 36 | device = torch.device('cuda' if is_cuda_available() else 'cpu') 37 | 38 | model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained( 39 | "bigcode/santacoder", trust_remote_code=True 40 | ).to(device) # 41 | tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder") 42 | 43 | with create_test_context(params) as context: 44 | lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory) 45 | filepath = "src/playlist.rs" 46 | # All the communication with the language server must be performed inside the context manager 47 | # The server process is started when the context manager is entered and is terminated when the context manager is exited. 48 | with lsp.start_server(): 49 | with lsp.open_file(filepath): 50 | filebuffer = MonitorFileBuffer(lsp.language_server, filepath, (10, 40), (10, 40), code_language, "") 51 | deleted_text = filebuffer.lsp.delete_text_between_positions( 52 | filepath, Position(line=10, character=40), Position(line=12, character=4) 53 | ) 54 | assert ( 55 | deleted_text 56 | == """reset(); 57 | media_player1 = media_player; 58 | """ 59 | ) 60 | monitor = DereferencesMonitor(HFTokenizerWrapper(tokenizer), filebuffer) 61 | mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop) 62 | 63 | with open(str(PurePath(context.source_directory, filepath)), "r") as f: 64 | filecontent = f.read() 65 | 66 | pos_idx = TextUtils.get_index_from_line_col(filecontent, 10, 40) 67 | assert filecontent[pos_idx] == "r" 68 | 69 | with open(str(PurePath(context.source_directory, "src/media_player.rs")), "r") as f: 70 | classExprTypes = f.read() 71 | 72 | prompt = filecontent[:pos_idx] + """(); 73 | media_player1 = media_player; 74 | } 75 | }.""" 76 | assert prompt[-1] == "." 
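Before the generation calls below, it helps to see the monitor-guided decoding wiring that this test and the C# switch-enum test above both follow: wrap the HF tokenizer, attach a monitor to a MonitorFileBuffer, put the resulting MGDLogitsProcessor into a LogitsProcessorList, and hand it to model.generate. A condensed sketch of that pattern follows; the helper name generate_with_mgd and its argument list are illustrative, but every call inside it mirrors these tests:

    from transformers.generation.utils import LogitsProcessorList
    from monitors4codegen.monitor_guided_decoding.hf_gen import MGDLogitsProcessor
    from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer
    from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper

    def generate_with_mgd(model, tokenizer, lsp, filebuffer: MonitorFileBuffer,
                          monitor_cls, prompt: str, max_new_tokens: int = 50) -> str:
        """Condensed MGD wiring, mirroring the surrounding tests."""
        monitor = monitor_cls(HFTokenizerWrapper(tokenizer), filebuffer)
        processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop)
        # Truncate the prompt to leave room for generation, as the tests do.
        inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)[:, -(2048 - 512):]
        output = model.generate(
            inputs,
            do_sample=False,  # greedy decoding; the monitor masks invalid tokens
            max_new_tokens=max_new_tokens,
            logits_processor=LogitsProcessorList([processor]),
        )
        # Decode only the newly generated tokens.
        return tokenizer.decode(output[0, inputs.shape[-1]:])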
77 | prompt_tokenized = tokenizer.encode("" + classExprTypes + '\n' + prompt, return_tensors="pt")[:, -(2048-512):].to(device) 78 | 79 | generated_code_without_mgd = model.generate( 80 | prompt_tokenized, do_sample=False, max_new_tokens=5, top_p=0.95, temperature=0.2, 81 | ) 82 | 83 | num_gen_tokens = generated_code_without_mgd.shape[-1] - prompt_tokenized.shape[-1] 84 | generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -num_gen_tokens:]) 85 | 86 | assert ( 87 | generated_code_without_mgd 88 | == """stop(); 89 | let media""" 90 | ) 91 | 92 | # Generate code using santacoder model with the MGD logits processor and greedy decoding 93 | logits_processor = LogitsProcessorList([mgd_logits_processor]) 94 | generated_code = model.generate( 95 | prompt_tokenized, 96 | do_sample=False, 97 | max_new_tokens=30, 98 | logits_processor=logits_processor, 99 | early_stopping=True, 100 | ) 101 | 102 | num_gen_tokens = generated_code.shape[-1] - prompt_tokenized.shape[-1] 103 | generated_code = tokenizer.decode(generated_code[0, -num_gen_tokens:]) 104 | 105 | assert ( 106 | generated_code 107 | == """reset(); 108 | let media_player = media_player.set_media_file_path(song); 109 | let media_player = media_""" 110 | ) 111 | -------------------------------------------------------------------------------- /tests/multilspy/multilspy_context.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides the MultilspyContext class, which stores the context for a Multilspy test. 3 | """ 4 | 5 | import dataclasses 6 | from monitors4codegen.multilspy.multilspy_config import MultilspyConfig 7 | from monitors4codegen.multilspy.multilspy_logger import MultilspyLogger 8 | 9 | @dataclasses.dataclass 10 | class MultilspyContext: 11 | """ 12 | Stores the context for a Multilspy test. 13 | """ 14 | config: MultilspyConfig 15 | logger: MultilspyLogger 16 | source_directory: str -------------------------------------------------------------------------------- /tests/multilspy/test_multilspy_csharp.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains tests for running the C# Language Server: OmniSharp 3 | """ 4 | 5 | import pytest 6 | 7 | from monitors4codegen.multilspy import LanguageServer 8 | from monitors4codegen.multilspy.multilspy_config import Language 9 | from monitors4codegen.multilspy.multilspy_types import Position, CompletionItemKind 10 | from tests.test_utils import create_test_context 11 | from pathlib import PurePath 12 | 13 | pytest_plugins = ("pytest_asyncio",) 14 | 15 | 16 | @pytest.mark.asyncio 17 | async def test_multilspy_csharp_ryujinx(): 18 | """ 19 | Test the working of multilspy with C# repository - Ryujinx 20 | """ 21 | code_language = Language.CSHARP 22 | params = { 23 | "code_language": code_language, 24 | "repo_url": "https://github.com/Ryujinx/Ryujinx/", 25 | "repo_commit": "e768a54f17b390c3ac10904c7909e3bef020edbd" 26 | } 27 | with create_test_context(params) as context: 28 | lsp = LanguageServer.create(context.config, context.logger, context.source_directory) 29 | 30 | # All the communication with the language server must be performed inside the context manager 31 | # The server process is started when the context manager is entered and is terminated when the context manager is exited. 32 | # The context manager is an asynchronous context manager, so it must be used with async with. 
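The comments above state the usage contract for the asynchronous API. For reference, the same pattern outside a test harness looks roughly like this; the checkout path, file path, and position are placeholders, and we assume MultilspyConfig.from_dict accepts a dict with just the language field, since the repo-specific keys in the tests' params are consumed by the test fixture rather than the config:

    import asyncio

    from monitors4codegen.multilspy import LanguageServer
    from monitors4codegen.multilspy.multilspy_config import Language, MultilspyConfig
    from monitors4codegen.multilspy.multilspy_logger import MultilspyLogger

    async def main() -> None:
        config = MultilspyConfig.from_dict({"code_language": Language.CSHARP})
        logger = MultilspyLogger()
        lsp = LanguageServer.create(config, logger, "/path/to/local/checkout")
        # The server process starts on entry and is terminated on exit, so
        # every request must be awaited inside this block.
        async with lsp.start_server():
            result = await lsp.request_definition("src/Some/File.cs", 10, 5)
            print(result)

    asyncio.run(main())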
33 | async with lsp.start_server(): 34 | result = await lsp.request_definition(str(PurePath("src/Ryujinx.Audio/Input/AudioInputManager.cs")), 176, 44) 35 | 36 | assert isinstance(result, list) 37 | assert len(result) == 1 38 | item = result[0] 39 | assert item["relativePath"] == str(PurePath("src/Ryujinx.Audio/Constants.cs")) 40 | assert item["range"] == { 41 | "start": {"line": 15, "character": 28}, 42 | "end": {"line": 15, "character": 50}, 43 | } 44 | 45 | result = await lsp.request_references(str(PurePath("src/Ryujinx.Audio/Constants.cs")), 15, 40) 46 | 47 | assert isinstance(result, list) 48 | assert len(result) == 2 49 | 50 | for item in result: 51 | del item["uri"] 52 | del item["absolutePath"] 53 | 54 | assert result == [ 55 | { 56 | "relativePath": str(PurePath("src/Ryujinx.Audio/Input/AudioInputManager.cs")), 57 | "range": { 58 | "start": {"line": 176, "character": 37}, 59 | "end": {"line": 176, "character": 59}, 60 | }, 61 | }, 62 | { 63 | "relativePath": str(PurePath("src/Ryujinx.Audio/Input/AudioInputSystem.cs")), 64 | "range": { 65 | "start": {"line": 77, "character": 29}, 66 | "end": {"line": 77, "character": 51}, 67 | }, 68 | }, 69 | ] 70 | 71 | completions_filepath = "src/ARMeilleure/CodeGen/Arm64/CodeGenerator.cs" 72 | with lsp.open_file(completions_filepath): 73 | deleted_text = lsp.delete_text_between_positions( 74 | completions_filepath, 75 | Position(line=1352, character=21), 76 | Position(line=1385, character=8) 77 | ) 78 | assert deleted_text == """AccessSize.Byte: 79 | context.Assembler.Ldaxrb(actual, address); 80 | break; 81 | case AccessSize.Hword: 82 | context.Assembler.Ldaxrh(actual, address); 83 | break; 84 | default: 85 | context.Assembler.Ldaxr(actual, address); 86 | break; 87 | } 88 | 89 | context.Assembler.Cmp(actual, expected); 90 | 91 | context.JumpToNear(ArmCondition.Ne); 92 | 93 | switch (accessSize) 94 | { 95 | case AccessSize.Byte: 96 | context.Assembler.Stlxrb(desired, address, result); 97 | break; 98 | case AccessSize.Hword: 99 | context.Assembler.Stlxrh(desired, address, result); 100 | break; 101 | default: 102 | context.Assembler.Stlxr(desired, address, result); 103 | break; 104 | } 105 | 106 | context.Assembler.Cbnz(result, startOffset - context.StreamOffset); // Retry if store failed. 
107 | 108 | context.JumpHere(); 109 | 110 | context.Assembler.Clrex(); 111 | """ 112 | completions = await lsp.request_completions(completions_filepath, 1352, 21) 113 | completions = [completion["completionText"] for completion in completions if completion["kind"] == CompletionItemKind.EnumMember] 114 | assert set(completions) == set(['AccessSize.Byte', 'AccessSize.Hword', 'AccessSize.Auto']) 115 | 116 | completions_filepath = "src/ARMeilleure/CodeGen/X86/CodeGenerator.cs" 117 | with lsp.open_file(completions_filepath): 118 | deleted_text = lsp.delete_text_between_positions( 119 | completions_filepath, 120 | Position(line=226, character=79), 121 | Position(line=243, character=28) 122 | ) 123 | assert deleted_text == """Below); 124 | break; 125 | 126 | case Intrinsic.X86Comisseq: 127 | context.Assembler.Comiss(src1, src2); 128 | context.Assembler.Setcc(dest, X86Condition.Equal); 129 | break; 130 | 131 | case Intrinsic.X86Comissge: 132 | context.Assembler.Comiss(src1, src2); 133 | context.Assembler.Setcc(dest, X86Condition.AboveOrEqual); 134 | break; 135 | 136 | case Intrinsic.X86Comisslt: 137 | context.Assembler.Comiss(src1, src2); 138 | context.Assembler.Setcc(dest, X86Condition.Below); 139 | break; 140 | """ 141 | completions = await lsp.request_completions(completions_filepath, 226, 79, allow_incomplete=True) 142 | completions = [completion["completionText"] for completion in completions if completion["kind"] != CompletionItemKind.Keyword] 143 | assert set(completions) == set(['NotSign', 'ParityOdd', 'NotOverflow', 'Less', 'AboveOrEqual', 'LessOrEqual', 'Overflow', 'Greater', 'ParityEven', 'Sign', 'BelowOrEqual', 'Equal', 'GreaterOrEqual', 'Below', 'Above', 'NotEqual']) -------------------------------------------------------------------------------- /tests/multilspy/test_multilspy_python.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains tests for running the Python Language Server: jedi-language-server 3 | """ 4 | 5 | import pytest 6 | from monitors4codegen.multilspy import LanguageServer 7 | from monitors4codegen.multilspy.multilspy_config import Language 8 | from tests.test_utils import create_test_context 9 | from pathlib import PurePath 10 | 11 | pytest_plugins = ("pytest_asyncio",) 12 | 13 | @pytest.mark.asyncio 14 | async def test_multilspy_python_black(): 15 | """ 16 | Test the working of multilspy with python repository - black 17 | """ 18 | code_language = Language.PYTHON 19 | params = { 20 | "code_language": code_language, 21 | "repo_url": "https://github.com/psf/black/", 22 | "repo_commit": "f3b50e466969f9142393ec32a4b2a383ffbe5f23" 23 | } 24 | with create_test_context(params) as context: 25 | lsp = LanguageServer.create(context.config, context.logger, context.source_directory) 26 | 27 | # All the communication with the language server must be performed inside the context manager 28 | # The server process is started when the context manager is entered and is terminated when the context manager is exited. 29 | # The context manager is an asynchronous context manager, so it must be used with async with. 
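One convention to keep in mind when reading the assertions below: all line and character values are zero-based, following the Language Server Protocol, so they are off by one relative to what an editor displays. A small hypothetical formatter for one-based, human-readable output:

    def format_location(item: dict) -> str:
        # multilspy positions are zero-based (LSP convention); editors show
        # one-based lines and columns, so shift both for display.
        start = item["range"]["start"]
        return f"{item['relativePath']}:{start['line'] + 1}:{start['character'] + 1}"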
30 | async with lsp.start_server(): 31 | result = await lsp.request_definition(str(PurePath("src/black/mode.py")), 163, 4) 32 | 33 | assert isinstance(result, list) 34 | assert len(result) == 1 35 | item = result[0] 36 | assert item["relativePath"] == str(PurePath("src/black/mode.py")) 37 | assert item["range"] == { 38 | "start": {"line": 163, "character": 4}, 39 | "end": {"line": 163, "character": 20}, 40 | } 41 | 42 | result = await lsp.request_references(str(PurePath("src/black/mode.py")), 163, 4) 43 | 44 | assert isinstance(result, list) 45 | assert len(result) == 8 46 | 47 | for item in result: 48 | del item["uri"] 49 | del item["absolutePath"] 50 | 51 | assert result == [ 52 | { 53 | "relativePath": str(PurePath("src/black/__init__.py")), 54 | "range": { 55 | "start": {"line": 71, "character": 4}, 56 | "end": {"line": 71, "character": 20}, 57 | }, 58 | }, 59 | { 60 | "relativePath": str(PurePath("src/black/__init__.py")), 61 | "range": { 62 | "start": {"line": 1105, "character": 11}, 63 | "end": {"line": 1105, "character": 27}, 64 | }, 65 | }, 66 | { 67 | "relativePath": str(PurePath("src/black/__init__.py")), 68 | "range": { 69 | "start": {"line": 1113, "character": 11}, 70 | "end": {"line": 1113, "character": 27}, 71 | }, 72 | }, 73 | { 74 | "relativePath": str(PurePath("src/black/mode.py")), 75 | "range": { 76 | "start": {"line": 163, "character": 4}, 77 | "end": {"line": 163, "character": 20}, 78 | }, 79 | }, 80 | { 81 | "relativePath": str(PurePath("src/black/parsing.py")), 82 | "range": { 83 | "start": {"line": 7, "character": 68}, 84 | "end": {"line": 7, "character": 84}, 85 | }, 86 | }, 87 | { 88 | "relativePath": str(PurePath("src/black/parsing.py")), 89 | "range": { 90 | "start": {"line": 37, "character": 11}, 91 | "end": {"line": 37, "character": 27}, 92 | }, 93 | }, 94 | { 95 | "relativePath": str(PurePath("src/black/parsing.py")), 96 | "range": { 97 | "start": {"line": 39, "character": 14}, 98 | "end": {"line": 39, "character": 30}, 99 | }, 100 | }, 101 | { 102 | "relativePath": str(PurePath("src/black/parsing.py")), 103 | "range": { 104 | "start": {"line": 44, "character": 11}, 105 | "end": {"line": 44, "character": 27}, 106 | }, 107 | }, 108 | ] -------------------------------------------------------------------------------- /tests/multilspy/test_multilspy_rust.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains tests for running the Rust Language Server: rust-analyzer 3 | """ 4 | 5 | import unittest 6 | import pytest 7 | 8 | from monitors4codegen.multilspy import LanguageServer 9 | from monitors4codegen.multilspy.multilspy_config import Language 10 | from monitors4codegen.multilspy.multilspy_types import Position, CompletionItemKind 11 | from tests.test_utils import create_test_context 12 | from pathlib import PurePath 13 | 14 | pytest_plugins = ("pytest_asyncio",) 15 | 16 | @pytest.mark.asyncio 17 | async def test_multilspy_rust_carbonyl(): 18 | """ 19 | Test the working of multilspy with rust repository - carbonyl 20 | """ 21 | code_language = Language.RUST 22 | params = { 23 | "code_language": code_language, 24 | "repo_url": "https://github.com/fathyb/carbonyl/", 25 | "repo_commit": "ab80a276b1bd1c2c8dcefc8f248415dfc61dc2bf" 26 | } 27 | with create_test_context(params) as context: 28 | lsp = LanguageServer.create(context.config, context.logger, context.source_directory) 29 | 30 | # All the communication with the language server must be performed inside the context manager 31 | # The server process is 
started when the context manager is entered and is terminated when the context manager is exited. 32 | # The context manager is an asynchronous context manager, so it must be used with async with. 33 | async with lsp.start_server(): 34 | result = await lsp.request_definition(str(PurePath("src/browser/bridge.rs")), 132, 18) 35 | 36 | assert isinstance(result, list) 37 | assert len(result) == 1 38 | item = result[0] 39 | assert item["relativePath"] == str(PurePath("src/input/tty.rs")) 40 | assert item["range"] == { 41 | "start": {"line": 43, "character": 11}, 42 | "end": {"line": 43, "character": 19}, 43 | } 44 | 45 | result = await lsp.request_references(str(PurePath("src/input/tty.rs")), 43, 15) 46 | 47 | assert isinstance(result, list) 48 | assert len(result) == 2 49 | 50 | for item in result: 51 | del item["uri"] 52 | del item["absolutePath"] 53 | 54 | case = unittest.TestCase() 55 | case.assertCountEqual( 56 | result, 57 | [ 58 | { 59 | "relativePath": str(PurePath("src/browser/bridge.rs")), 60 | "range": { 61 | "start": {"line": 132, "character": 13}, 62 | "end": {"line": 132, "character": 21}, 63 | }, 64 | }, 65 | { 66 | "relativePath": str(PurePath("src/input/tty.rs")), 67 | "range": { 68 | "start": {"line": 16, "character": 13}, 69 | "end": {"line": 16, "character": 21}, 70 | }, 71 | }, 72 | ], 73 | ) 74 | 75 | @pytest.mark.asyncio 76 | async def test_multilspy_rust_completions_mediaplayer() -> None: 77 | """ 78 | Test the working of multilspy with Rust repository - mediaplayer 79 | """ 80 | code_language = Language.RUST 81 | params = { 82 | "code_language": code_language, 83 | "repo_url": "https://github.com/LakshyAAAgrawal/MediaPlayer_example/", 84 | "repo_commit": "ba27bb16c7ba1d88808300364af65eb69b1d84a8", 85 | } 86 | 87 | with create_test_context(params) as context: 88 | lsp = LanguageServer.create(context.config, context.logger, context.source_directory) 89 | filepath = "src/playlist.rs" 90 | # All the communication with the language server must be performed inside the context manager 91 | # The server process is started when the context manager is entered and is terminated when the context manager is exited. 
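The completion handling in the test below drops snippet items and then strips the argument list with completionText[:completionText.find('(')], which would silently remove the last character if no parenthesis were present (harmless here, since every surviving item has one). A hypothetical helper that factors out the same filtering with a guard for that edge case:

    from monitors4codegen.multilspy.multilspy_types import CompletionItemKind

    def completion_names(completions: list) -> set:
        # Drop snippet completions and strip any "(...)" suffix from the rest.
        names = set()
        for item in completions:
            if item["kind"] == CompletionItemKind.Snippet:
                continue
            text = item["completionText"]
            paren = text.find("(")
            names.add(text[:paren] if paren != -1 else text)
        return names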
92 | async with lsp.start_server(): 93 | with lsp.open_file(filepath): 94 | deleted_text = lsp.delete_text_between_positions( 95 | filepath, Position(line=10, character=40), Position(line=12, character=4) 96 | ) 97 | assert ( 98 | deleted_text 99 | == """reset(); 100 | media_player1 = media_player; 101 | """ 102 | ) 103 | 104 | response = await lsp.request_completions(filepath, 10, 40, allow_incomplete=True) 105 | 106 | response = [item for item in response if item['kind'] != CompletionItemKind.Snippet] 107 | 108 | for item in response: 109 | item['completionText'] = item['completionText'][:item['completionText'].find('(')] 110 | 111 | assert set([item['completionText'] for item in response]) == {'reset', 'into', 'try_into', 'prepare'} 112 | -------------------------------------------------------------------------------- /tests/multilspy/test_sync_multilspy_csharp.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains tests for running the C# Language Server: OmniSharp 3 | """ 4 | 5 | 6 | from monitors4codegen.multilspy import SyncLanguageServer 7 | from monitors4codegen.multilspy.multilspy_config import Language 8 | from tests.test_utils import create_test_context 9 | from pathlib import PurePath 10 | 11 | 12 | def test_multilspy_csharp_ryujinx() -> None: 13 | """ 14 | Test the working of multilspy with C# repository - Ryujinx 15 | """ 16 | code_language = Language.CSHARP 17 | params = { 18 | "code_language": code_language, 19 | "repo_url": "https://github.com/Ryujinx/Ryujinx/", 20 | "repo_commit": "e768a54f17b390c3ac10904c7909e3bef020edbd" 21 | } 22 | with create_test_context(params) as context: 23 | lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory) 24 | 25 | # All the communication with the language server must be performed inside the context manager 26 | # The server process is started when the context manager is entered and is terminated when the context manager is exited. 
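Every reference test in this suite normalizes results the same way before comparing: it deletes the "uri" and "absolutePath" keys in place, since both embed the machine-specific checkout path. A non-mutating equivalent (the helper name is ours):

    def strip_machine_specific(items: list) -> list:
        # "uri" and "absolutePath" contain local filesystem paths, so they
        # cannot be compared against fixed expectations; keep portable fields.
        return [
            {key: value for key, value in item.items() if key not in ("uri", "absolutePath")}
            for item in items
        ]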
27 | with lsp.start_server(): 28 | result = lsp.request_definition(str(PurePath("src/Ryujinx.Audio/Input/AudioInputManager.cs")), 176, 44) 29 | 30 | assert isinstance(result, list) 31 | assert len(result) == 1 32 | item = result[0] 33 | assert item["relativePath"] == str(PurePath("src/Ryujinx.Audio/Constants.cs")) 34 | assert item["range"] == { 35 | "start": {"line": 15, "character": 28}, 36 | "end": {"line": 15, "character": 50}, 37 | } 38 | 39 | result = lsp.request_references(str(PurePath("src/Ryujinx.Audio/Constants.cs")), 15, 40) 40 | 41 | assert isinstance(result, list) 42 | assert len(result) == 2 43 | 44 | for item in result: 45 | del item["uri"] 46 | del item["absolutePath"] 47 | 48 | assert result == [ 49 | { 50 | "relativePath": str(PurePath("src/Ryujinx.Audio/Input/AudioInputManager.cs")), 51 | "range": { 52 | "start": {"line": 176, "character": 37}, 53 | "end": {"line": 176, "character": 59}, 54 | }, 55 | }, 56 | { 57 | "relativePath": str(PurePath("src/Ryujinx.Audio/Input/AudioInputSystem.cs")), 58 | "range": { 59 | "start": {"line": 77, "character": 29}, 60 | "end": {"line": 77, "character": 51}, 61 | }, 62 | }, 63 | ] 64 | -------------------------------------------------------------------------------- /tests/multilspy/test_sync_multilspy_java.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains tests for running the Java Language Server: Eclipse JDT.LS 3 | """ 4 | 5 | from pathlib import PurePath 6 | from monitors4codegen.multilspy import SyncLanguageServer 7 | from monitors4codegen.multilspy.multilspy_config import Language 8 | from tests.test_utils import create_test_context 9 | 10 | def test_multilspy_java_clickhouse_highlevel_sinker() -> None: 11 | """ 12 | Test the working of multilspy with Java repository - clickhouse-highlevel-sinker 13 | """ 14 | code_language = Language.JAVA 15 | params = { 16 | "code_language": code_language, 17 | "repo_url": "https://github.com/Index103000/clickhouse-highlevel-sinker/", 18 | "repo_commit": "ee31d278918fe5e64669a6840c4d8fb53889e573" 19 | } 20 | with create_test_context(params) as context: 21 | lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory) 22 | 23 | # All the communication with the language server must be performed inside the context manager 24 | # The server process is started when the context manager is entered and is terminated when the context manager is exited. 25 | with lsp.start_server(): 26 | filepath = str(PurePath("src/main/java/com/xlvchao/clickhouse/component/ClickHouseSinkManager.java")) 27 | result = lsp.request_definition(filepath, 44, 59) 28 | 29 | assert isinstance(result, list) 30 | assert len(result) == 1 31 | item = result[0] 32 | assert item["relativePath"] == str( 33 | PurePath("src/main/java/com/xlvchao/clickhouse/component/ScheduledCheckerAndCleaner.java") 34 | ) 35 | assert item["range"] == { 36 | "start": {"line": 22, "character": 11}, 37 | "end": {"line": 22, "character": 37}, 38 | } 39 | 40 | # TODO: The following test is running flaky on Windows. Investigate and fix. 
41 | # On Windows, it returns the correct result sometimes and sometimes it returns the following: 42 | # incorrect_output = [ 43 | # { 44 | # "range": {"end": {"character": 86, "line": 24}, "start": {"character": 65, "line": 24}}, 45 | # "relativePath": "src\\main\\java\\com\\xlvchao\\clickhouse\\component\\ClickHouseSinkManager.java", 46 | # }, 47 | # { 48 | # "range": {"end": {"character": 61, "line": 2}, "start": {"character": 7, "line": 2}}, 49 | # "relativePath": "src\\test\\java\\com\\xlvchao\\clickhouse\\SpringbootDemo.java", 50 | # }, 51 | # { 52 | # "range": {"end": {"character": 29, "line": 28}, "start": {"character": 8, "line": 28}}, 53 | # "relativePath": "src\\test\\java\\com\\xlvchao\\clickhouse\\SpringbootDemo.java", 54 | # }, 55 | # { 56 | # "range": {"end": {"character": 69, "line": 28}, "start": {"character": 48, "line": 28}}, 57 | # "relativePath": "src\\test\\java\\com\\xlvchao\\clickhouse\\SpringbootDemo.java", 58 | # }, 59 | # ] 60 | 61 | result = lsp.request_references(filepath, 82, 27) 62 | 63 | assert isinstance(result, list) 64 | assert len(result) == 2 65 | 66 | for item in result: 67 | del item["uri"] 68 | del item["absolutePath"] 69 | 70 | assert result == [ 71 | { 72 | "relativePath": str( 73 | PurePath("src/main/java/com/xlvchao/clickhouse/component/ClickHouseSinkManager.java") 74 | ), 75 | "range": { 76 | "start": {"line": 75, "character": 66}, 77 | "end": {"line": 75, "character": 85}, 78 | }, 79 | }, 80 | { 81 | "relativePath": str( 82 | PurePath("src/main/java/com/xlvchao/clickhouse/component/ClickHouseSinkManager.java") 83 | ), 84 | "range": { 85 | "start": {"line": 71, "character": 12}, 86 | "end": {"line": 71, "character": 31}, 87 | }, 88 | }, 89 | ] 90 | -------------------------------------------------------------------------------- /tests/multilspy/test_sync_multilspy_python.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains tests for running the Python Language Server: jedi-language-server 3 | """ 4 | 5 | from monitors4codegen.multilspy import SyncLanguageServer 6 | from monitors4codegen.multilspy.multilspy_config import Language 7 | from tests.test_utils import create_test_context 8 | from pathlib import PurePath 9 | 10 | def test_multilspy_python_black() -> None: 11 | """ 12 | Test the working of multilspy with python repository - black 13 | """ 14 | code_language = Language.PYTHON 15 | params = { 16 | "code_language": code_language, 17 | "repo_url": "https://github.com/psf/black/", 18 | "repo_commit": "f3b50e466969f9142393ec32a4b2a383ffbe5f23" 19 | } 20 | with create_test_context(params) as context: 21 | lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory) 22 | 23 | # All the communication with the language server must be performed inside the context manager 24 | # The server process is started when the context manager is entered and is terminated when the context manager is exited. 
25 | with lsp.start_server(): 26 | result = lsp.request_definition(str(PurePath("src/black/mode.py")), 163, 4) 27 | 28 | assert isinstance(result, list) 29 | assert len(result) == 1 30 | item = result[0] 31 | assert item["relativePath"] == str(PurePath("src/black/mode.py")) 32 | assert item["range"] == { 33 | "start": {"line": 163, "character": 4}, 34 | "end": {"line": 163, "character": 20}, 35 | } 36 | 37 | result = lsp.request_references(str(PurePath("src/black/mode.py")), 163, 4) 38 | 39 | assert isinstance(result, list) 40 | assert len(result) == 8 41 | 42 | for item in result: 43 | del item["uri"] 44 | del item["absolutePath"] 45 | 46 | assert result == [ 47 | { 48 | "relativePath": str(PurePath("src/black/__init__.py")), 49 | "range": { 50 | "start": {"line": 71, "character": 4}, 51 | "end": {"line": 71, "character": 20}, 52 | }, 53 | }, 54 | { 55 | "relativePath": str(PurePath("src/black/__init__.py")), 56 | "range": { 57 | "start": {"line": 1105, "character": 11}, 58 | "end": {"line": 1105, "character": 27}, 59 | }, 60 | }, 61 | { 62 | "relativePath": str(PurePath("src/black/__init__.py")), 63 | "range": { 64 | "start": {"line": 1113, "character": 11}, 65 | "end": {"line": 1113, "character": 27}, 66 | }, 67 | }, 68 | { 69 | "relativePath": str(PurePath("src/black/mode.py")), 70 | "range": { 71 | "start": {"line": 163, "character": 4}, 72 | "end": {"line": 163, "character": 20}, 73 | }, 74 | }, 75 | { 76 | "relativePath": str(PurePath("src/black/parsing.py")), 77 | "range": { 78 | "start": {"line": 7, "character": 68}, 79 | "end": {"line": 7, "character": 84}, 80 | }, 81 | }, 82 | { 83 | "relativePath": str(PurePath("src/black/parsing.py")), 84 | "range": { 85 | "start": {"line": 37, "character": 11}, 86 | "end": {"line": 37, "character": 27}, 87 | }, 88 | }, 89 | { 90 | "relativePath": str(PurePath("src/black/parsing.py")), 91 | "range": { 92 | "start": {"line": 39, "character": 14}, 93 | "end": {"line": 39, "character": 30}, 94 | }, 95 | }, 96 | { 97 | "relativePath": str(PurePath("src/black/parsing.py")), 98 | "range": { 99 | "start": {"line": 44, "character": 11}, 100 | "end": {"line": 44, "character": 27}, 101 | }, 102 | }, 103 | ] 104 | -------------------------------------------------------------------------------- /tests/multilspy/test_sync_multilspy_rust.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains tests for running the Rust Language Server: rust-analyzer 3 | """ 4 | 5 | import unittest 6 | 7 | from monitors4codegen.multilspy import SyncLanguageServer 8 | from monitors4codegen.multilspy.multilspy_config import Language 9 | from tests.test_utils import create_test_context 10 | from pathlib import PurePath 11 | 12 | def test_multilspy_rust_carbonyl() -> None: 13 | """ 14 | Test the working of multilspy with rust repository - carbonyl 15 | """ 16 | code_language = Language.RUST 17 | params = { 18 | "code_language": code_language, 19 | "repo_url": "https://github.com/fathyb/carbonyl/", 20 | "repo_commit": "ab80a276b1bd1c2c8dcefc8f248415dfc61dc2bf" 21 | } 22 | with create_test_context(params) as context: 23 | lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory) 24 | 25 | # All the communication with the language server must be performed inside the context manager 26 | # The server process is started when the context manager is entered and is terminated when the context manager is exited. 
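Note that the Rust tests compare references with unittest's assertCountEqual rather than plain equality, presumably because rust-analyzer does not return references in a guaranteed order. An alternative sketch that sorts both sides into a canonical order so that == works:

    def canonical_order(items: list) -> list:
        # Sort locations by path, then start position, making the comparison
        # independent of the order in which the server returned them.
        return sorted(
            items,
            key=lambda it: (
                it["relativePath"],
                it["range"]["start"]["line"],
                it["range"]["start"]["character"],
            ),
        )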
27 | with lsp.start_server(): 28 | result = lsp.request_definition(str(PurePath("src/browser/bridge.rs")), 132, 18) 29 | 30 | assert isinstance(result, list) 31 | assert len(result) == 1 32 | item = result[0] 33 | assert item["relativePath"] == str(PurePath("src/input/tty.rs")) 34 | assert item["range"] == { 35 | "start": {"line": 43, "character": 11}, 36 | "end": {"line": 43, "character": 19}, 37 | } 38 | 39 | result = lsp.request_references(str(PurePath("src/input/tty.rs")), 43, 15) 40 | 41 | assert isinstance(result, list) 42 | assert len(result) == 2 43 | 44 | for item in result: 45 | del item["uri"] 46 | del item["absolutePath"] 47 | 48 | case = unittest.TestCase() 49 | case.assertCountEqual( 50 | result, 51 | [ 52 | { 53 | "relativePath": str(PurePath("src/browser/bridge.rs")), 54 | "range": { 55 | "start": {"line": 132, "character": 13}, 56 | "end": {"line": 132, "character": 21}, 57 | }, 58 | }, 59 | { 60 | "relativePath": str(PurePath("src/input/tty.rs")), 61 | "range": { 62 | "start": {"line": 16, "character": 13}, 63 | "end": {"line": 16, "character": 21}, 64 | }, 65 | }, 66 | ], 67 | ) 68 | -------------------------------------------------------------------------------- /tests/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | ; For running multilspy tests as seen from https://stackoverflow.com/a/72104554 3 | asyncio_mode = auto 4 | 5 | ; directories containing tests 6 | testpaths = 7 | tests 8 | 9 | ; force pattern for test content 10 | python_files = test_*.py 11 | python_functions = test_* 12 | python_classes = Test* 13 | 14 | pythonpath = 15 | ../ 16 | ../src/ 17 | tests/multilspy 18 | 19 | ; equivalent to pass the argument to pytest CLI 20 | addopts = 21 | ; increase verbosity 22 | --verbose 23 | ; do not capture output 24 | --capture=no -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import contextlib 4 | import shutil 5 | import torch 6 | 7 | from monitors4codegen.multilspy.multilspy_config import MultilspyConfig 8 | from monitors4codegen.multilspy.multilspy_logger import MultilspyLogger 9 | from tests.multilspy.multilspy_context import MultilspyContext 10 | from typing import Iterator 11 | from uuid import uuid4 12 | from monitors4codegen.multilspy.multilspy_utils import FileUtils 13 | 14 | @contextlib.contextmanager 15 | def create_test_context(params: dict) -> Iterator[MultilspyContext]: 16 | """ 17 | Creates a test context for the given parameters. 
18 | """ 19 | config = MultilspyConfig.from_dict(params) 20 | logger = MultilspyLogger() 21 | 22 | user_home_dir = os.path.expanduser("~") 23 | multilspy_home_directory = str(pathlib.Path(user_home_dir, ".multilspy")) 24 | temp_extract_directory = str(pathlib.Path(multilspy_home_directory, uuid4().hex)) 25 | try: 26 | os.makedirs(temp_extract_directory, exist_ok=False) 27 | assert params['repo_url'].endswith('/') 28 | repo_zip_url = params['repo_url'] + f"archive/{params['repo_commit']}.zip" 29 | FileUtils.download_and_extract_archive(logger, repo_zip_url, temp_extract_directory, "zip") 30 | dir_contents = os.listdir(temp_extract_directory) 31 | assert len(dir_contents) == 1 32 | source_directory_path = str(pathlib.Path(temp_extract_directory, dir_contents[0])) 33 | 34 | yield MultilspyContext(config, logger, source_directory_path) 35 | finally: 36 | if os.path.exists(temp_extract_directory): 37 | shutil.rmtree(temp_extract_directory) 38 | 39 | def is_cuda_available() -> bool: 40 | """ 41 | Returns True if CUDA is available, False otherwise 42 | """ 43 | if torch.cuda.is_available(): 44 | try: 45 | t = torch.rand(1).cuda() 46 | t = t * 2 47 | return True 48 | except RuntimeError: 49 | return False 50 | return False --------------------------------------------------------------------------------