├── .github └── workflows │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── experiments ├── __init__.py ├── main │ ├── .gitignore │ ├── __init__.py │ ├── analyze_avg_time.py │ ├── analyze_inf_res.py │ ├── analyze_many_temp.sh │ ├── analyze_multifun.py │ ├── analyze_outputs.py │ ├── analyze_pass_not_accepted.py │ ├── analyze_passatk.sh │ ├── analyze_resample.py │ ├── convert_files.py │ ├── count_syntax_errors.py │ ├── create_humaneval_repair_dataset.sh │ ├── create_mbpp_repair_dataset.sh │ ├── create_repair_dataset.py │ ├── download_models.py │ ├── figures.sh │ ├── figures │ │ ├── __init__.py │ │ ├── fig_compiler_perf_comparison.py │ │ ├── fig_compiler_perf_fc.py │ │ ├── fig_compiler_perf_repair.py │ │ ├── fig_compiler_perf_syn_tran.py │ │ └── fig_compiler_time.py │ ├── figures_revision │ │ ├── __init__.py │ │ ├── fig_compiler_perf_syn_tran_repair.py │ │ ├── fig_resample_hist.py │ │ └── fig_resample_hist.sh │ ├── filter_sensible_ts_outputs.py │ ├── fix_nc.py │ ├── inference_multiple.py │ ├── inference_multiple_repair.py │ ├── invalid_mbpp │ ├── kill_inf.sh │ ├── print_c.py │ ├── repair_datasets │ │ ├── humaneval_repair_dataset.jsonl │ │ └── mbpp_repair_dataset.jsonl │ ├── rerun_temp_inf.py │ ├── results_paper │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=0_t=1_repair-all_c.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=0_t=1_repair-all_nc.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=0_t=1_synth_c.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=0_t=1_synth_nc.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=0_t=1_translate_c.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=0_t=1_translate_nc.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=1_t=1_synth_c.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=1_t=1_synth_nc.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=1_t=1_translate_c.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=1_t=1_translate_nc.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=2_t=1_synth_c.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=2_t=1_synth_nc.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=2_t=1_translate_c.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=2_t=1_translate_nc.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=3_t=1_synth_c.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=3_t=1_synth_nc.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=3_t=1_translate_c.jsonl │ │ ├── humaneval_Qwen_Qwen2.5-32B-Instruct_s=3_t=1_translate_nc.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=0_t=1_repair-all_c.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=0_t=1_repair-all_nc.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=0_t=1_synth_c.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=0_t=1_synth_nc.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=0_t=1_translate_c.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=0_t=1_translate_nc.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=1_t=1_synth_c.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=1_t=1_synth_nc.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=1_t=1_translate_c.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=1_t=1_translate_nc.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=2_t=1_synth_c.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=2_t=1_synth_nc.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=2_t=1_translate_c.jsonl │ │ ├── 
humaneval_codellama_CodeLlama-34b-Instruct-hf_s=2_t=1_translate_nc.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=3_t=1_synth_c.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=3_t=1_synth_nc.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=3_t=1_translate_c.jsonl │ │ ├── humaneval_codellama_CodeLlama-34b-Instruct-hf_s=3_t=1_translate_nc.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=0_t=1_repair-all_c.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=0_t=1_repair-all_nc.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=0_t=1_synth_c.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=0_t=1_synth_nc.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=0_t=1_translate_c.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=0_t=1_translate_nc.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=1_t=1_synth_c.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=1_t=1_synth_nc.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=1_t=1_translate_c.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=1_t=1_translate_nc.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=2_t=1_synth_c.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=2_t=1_synth_nc.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=2_t=1_translate_c.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=2_t=1_translate_nc.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=3_t=1_synth_c.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=3_t=1_synth_nc.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=3_t=1_translate_c.jsonl │ │ ├── humaneval_deepseek-ai_deepseek-coder-33b-instruct_s=3_t=1_translate_nc.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=0_t=1_repair-all_c.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=0_t=1_repair-all_nc.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=0_t=1_synth_c.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=0_t=1_synth_nc.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=0_t=1_translate_c.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=0_t=1_translate_nc.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=1_t=1_synth_c.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=1_t=1_synth_nc.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=1_t=1_translate_c.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=1_t=1_translate_nc.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=2_t=1_synth_c.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=2_t=1_synth_nc.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=2_t=1_translate_c.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=2_t=1_translate_nc.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=3_t=1_synth_c.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=3_t=1_synth_nc.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=3_t=1_translate_c.jsonl │ │ ├── humaneval_google_gemma-2-27b-it_s=3_t=1_translate_nc.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=0_t=1_repair-all_c.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=0_t=1_repair-all_nc.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=0_t=1_synth_c.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=0_t=1_synth_nc.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=0_t=1_translate_c.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=0_t=1_translate_nc.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=1_t=1_synth_c.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=1_t=1_synth_nc.jsonl 
│ │ ├── humaneval_google_gemma-2-2b-it_s=1_t=1_translate_c.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=1_t=1_translate_nc.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=2_t=1_synth_c.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=2_t=1_synth_nc.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=2_t=1_translate_c.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=2_t=1_translate_nc.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=3_t=1_synth_c.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=3_t=1_synth_nc.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=3_t=1_translate_c.jsonl │ │ ├── humaneval_google_gemma-2-2b-it_s=3_t=1_translate_nc.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=0_t=1_repair-all_c.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=0_t=1_repair-all_nc.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=0_t=1_synth_c.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=0_t=1_synth_nc.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=0_t=1_translate_c.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=0_t=1_translate_nc.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=1_t=1_synth_c.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=1_t=1_synth_nc.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=1_t=1_translate_c.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=1_t=1_translate_nc.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=2_t=1_synth_c.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=2_t=1_synth_nc.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=2_t=1_translate_c.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=2_t=1_translate_nc.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=3_t=1_synth_c.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=3_t=1_synth_nc.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=3_t=1_translate_c.jsonl │ │ ├── humaneval_google_gemma-2-9b-it_s=3_t=1_translate_nc.jsonl │ │ ├── mbpp_Qwen_Qwen2.5-32B-Instruct_s=0_t=1_repair-all_c.jsonl │ │ ├── mbpp_Qwen_Qwen2.5-32B-Instruct_s=0_t=1_repair-all_nc.jsonl │ │ ├── mbpp_Qwen_Qwen2.5-32B-Instruct_s=0_t=1_synth_c.jsonl │ │ ├── mbpp_Qwen_Qwen2.5-32B-Instruct_s=0_t=1_synth_nc.jsonl │ │ ├── mbpp_Qwen_Qwen2.5-32B-Instruct_s=0_t=1_translate_c.jsonl │ │ ├── mbpp_Qwen_Qwen2.5-32B-Instruct_s=0_t=1_translate_nc.jsonl │ │ ├── mbpp_codellama_CodeLlama-34b-Instruct-hf_s=0_t=1_repair-all_c.jsonl │ │ ├── mbpp_codellama_CodeLlama-34b-Instruct-hf_s=0_t=1_repair-all_nc.jsonl │ │ ├── mbpp_codellama_CodeLlama-34b-Instruct-hf_s=0_t=1_synth_c.jsonl │ │ ├── mbpp_codellama_CodeLlama-34b-Instruct-hf_s=0_t=1_synth_nc.jsonl │ │ ├── mbpp_codellama_CodeLlama-34b-Instruct-hf_s=0_t=1_translate_c.jsonl │ │ ├── mbpp_codellama_CodeLlama-34b-Instruct-hf_s=0_t=1_translate_nc.jsonl │ │ ├── mbpp_deepseek-ai_deepseek-coder-33b-instruct_s=0_t=1_repair-all_c.jsonl │ │ ├── mbpp_deepseek-ai_deepseek-coder-33b-instruct_s=0_t=1_repair-all_nc.jsonl │ │ ├── mbpp_deepseek-ai_deepseek-coder-33b-instruct_s=0_t=1_synth_c.jsonl │ │ ├── mbpp_deepseek-ai_deepseek-coder-33b-instruct_s=0_t=1_synth_nc.jsonl │ │ ├── mbpp_deepseek-ai_deepseek-coder-33b-instruct_s=0_t=1_translate_c.jsonl │ │ ├── mbpp_deepseek-ai_deepseek-coder-33b-instruct_s=0_t=1_translate_nc.jsonl │ │ ├── mbpp_google_gemma-2-27b-it_s=0_t=1_repair-all_c.jsonl │ │ ├── mbpp_google_gemma-2-27b-it_s=0_t=1_repair-all_nc.jsonl │ │ ├── mbpp_google_gemma-2-27b-it_s=0_t=1_synth_c.jsonl │ │ ├── mbpp_google_gemma-2-27b-it_s=0_t=1_synth_nc.jsonl │ │ ├── mbpp_google_gemma-2-27b-it_s=0_t=1_translate_c.jsonl │ │ ├── mbpp_google_gemma-2-27b-it_s=0_t=1_translate_nc.jsonl │ │ ├── mbpp_google_gemma-2-2b-it_s=0_t=1_repair-all_c.jsonl │ │ ├── 
mbpp_google_gemma-2-2b-it_s=0_t=1_repair-all_nc.jsonl │ │ ├── mbpp_google_gemma-2-2b-it_s=0_t=1_synth_c.jsonl │ │ ├── mbpp_google_gemma-2-2b-it_s=0_t=1_synth_nc.jsonl │ │ ├── mbpp_google_gemma-2-2b-it_s=0_t=1_translate_c.jsonl │ │ ├── mbpp_google_gemma-2-2b-it_s=0_t=1_translate_nc.jsonl │ │ ├── mbpp_google_gemma-2-9b-it_s=0_t=1_repair-all_c.jsonl │ │ ├── mbpp_google_gemma-2-9b-it_s=0_t=1_repair-all_nc.jsonl │ │ ├── mbpp_google_gemma-2-9b-it_s=0_t=1_synth_c.jsonl │ │ ├── mbpp_google_gemma-2-9b-it_s=0_t=1_synth_nc.jsonl │ │ ├── mbpp_google_gemma-2-9b-it_s=0_t=1_translate_c.jsonl │ │ └── mbpp_google_gemma-2-9b-it_s=0_t=1_translate_nc.jsonl │ ├── run_experiments.sh │ ├── run_experiments_repair.py │ ├── run_experiments_syn_tran.py │ ├── translate_canonical_humaneval.py │ └── util.py └── translation │ ├── humaneval-x │ ├── .gitignore │ ├── dataset.json │ ├── execute.py │ ├── metric.py │ ├── process_input.py │ └── translate.py │ └── mbpp │ ├── dataset.json │ └── generate.py ├── incremental_tsc.py ├── package.json ├── poetry.lock ├── pyproject.toml ├── setup_conda.sh ├── setup_env.sh ├── test ├── __init__.py ├── data │ ├── __init__.py │ ├── manually_fixed │ │ ├── HumanEval_1.ts │ │ ├── HumanEval_10.ts │ │ ├── HumanEval_6.ts │ │ ├── HumanEval_7.ts │ │ ├── HumanEval_8.ts │ │ └── HumanEval_9.ts │ ├── openai_openai_humaneval_ts_gpt-4o-2024-05-13.jsonl │ ├── openai_openai_humaneval_ts_gpt-4o-2024-05-13_filtered.jsonl │ └── print_instance.py ├── test_parser_base.py ├── test_parser_ts.py └── utils.py ├── ts_parser ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── README.md ├── build.sh ├── install_rust.sh └── src │ └── main.rs └── typesafe_llm ├── __init__.py ├── parser ├── __init__.py ├── parser_base.py ├── parser_shared.py ├── parser_ts.py ├── parser_ts_types.py ├── types_base.py ├── types_ts.py └── util.py ├── sampling.py ├── trie.py └── util.py /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: QA & Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v3 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: 3.11 21 | 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install -e . 
26 | pip install pytest pre-commit 27 | 28 | - name: Run pre-commit checks 29 | run: pre-commit run --all-files 30 | 31 | - name: Run tests 32 | run: pytest test -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | .vscode 3 | /test*.py 4 | /test*.js 5 | /test*.ts 6 | profile.json 7 | profile.html 8 | cache 9 | secret.sh 10 | secret.py 11 | node_modules 12 | package-lock.json 13 | nohup.out 14 | .DS_Store 15 | test_go/ 16 | # Created by https://www.toptal.com/developers/gitignore/api/python,pycharm+all 17 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm+all 18 | 19 | ### PyCharm+all ### 20 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 21 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 22 | 23 | # User-specific stuff 24 | .idea/**/workspace.xml 25 | .idea/**/tasks.xml 26 | .idea/**/usage.statistics.xml 27 | .idea/**/dictionaries 28 | .idea/**/shelf 29 | 30 | # AWS User-specific 31 | .idea/**/aws.xml 32 | 33 | # Generated files 34 | .idea/**/contentModel.xml 35 | 36 | # Sensitive or high-churn files 37 | .idea/**/dataSources/ 38 | .idea/**/dataSources.ids 39 | .idea/**/dataSources.local.xml 40 | .idea/**/sqlDataSources.xml 41 | .idea/**/dynamic.xml 42 | .idea/**/uiDesigner.xml 43 | .idea/**/dbnavigator.xml 44 | 45 | # Gradle 46 | .idea/**/gradle.xml 47 | .idea/**/libraries 48 | 49 | # Gradle and Maven with auto-import 50 | # When using Gradle or Maven with auto-import, you should exclude module files, 51 | # since they will be recreated, and may cause churn. Uncomment if using 52 | # auto-import. 53 | # .idea/artifacts 54 | # .idea/compiler.xml 55 | # .idea/jarRepositories.xml 56 | # .idea/modules.xml 57 | # .idea/*.iml 58 | # .idea/modules 59 | # *.iml 60 | # *.ipr 61 | 62 | # CMake 63 | cmake-build-*/ 64 | 65 | # Mongo Explorer plugin 66 | .idea/**/mongoSettings.xml 67 | 68 | # File-based project format 69 | *.iws 70 | 71 | # IntelliJ 72 | out/ 73 | 74 | # mpeltonen/sbt-idea plugin 75 | .idea_modules/ 76 | 77 | # JIRA plugin 78 | atlassian-ide-plugin.xml 79 | 80 | # Cursive Clojure plugin 81 | .idea/replstate.xml 82 | 83 | # SonarLint plugin 84 | .idea/sonarlint/ 85 | 86 | # Crashlytics plugin (for Android Studio and IntelliJ) 87 | com_crashlytics_export_strings.xml 88 | crashlytics.properties 89 | crashlytics-build.properties 90 | fabric.properties 91 | 92 | # Editor-based Rest Client 93 | .idea/httpRequests 94 | 95 | # Android studio 3.1+ serialized cache file 96 | .idea/caches/build_file_checksums.ser 97 | 98 | ### PyCharm+all Patch ### 99 | # Ignore everything but code style settings and run configurations 100 | # that are supposed to be shared within teams. 
101 | 102 | .idea/* 103 | 104 | !.idea/codeStyles 105 | !.idea/runConfigurations 106 | 107 | ### Python ### 108 | # Byte-compiled / optimized / DLL files 109 | __pycache__/ 110 | *.py[cod] 111 | *$py.class 112 | 113 | # C extensions 114 | *.so 115 | 116 | # Distribution / packaging 117 | .Python 118 | build/ 119 | develop-eggs/ 120 | dist/ 121 | downloads/ 122 | eggs/ 123 | .eggs/ 124 | lib/ 125 | lib64/ 126 | parts/ 127 | sdist/ 128 | var/ 129 | wheels/ 130 | share/python-wheels/ 131 | *.egg-info/ 132 | .installed.cfg 133 | *.egg 134 | MANIFEST 135 | 136 | # PyInstaller 137 | # Usually these files are written by a python script from a template 138 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 139 | *.manifest 140 | *.spec 141 | 142 | # Installer logs 143 | pip-log.txt 144 | pip-delete-this-directory.txt 145 | 146 | # Unit test / coverage reports 147 | htmlcov/ 148 | .tox/ 149 | .nox/ 150 | .coverage 151 | .coverage.* 152 | .cache 153 | nosetests.xml 154 | coverage.xml 155 | *.cover 156 | *.py,cover 157 | .hypothesis/ 158 | .pytest_cache/ 159 | cover/ 160 | 161 | # Translations 162 | *.mo 163 | *.pot 164 | 165 | # Django stuff: 166 | *.log 167 | local_settings.py 168 | db.sqlite3 169 | db.sqlite3-journal 170 | 171 | # Flask stuff: 172 | instance/ 173 | .webassets-cache 174 | 175 | # Scrapy stuff: 176 | .scrapy 177 | 178 | # Sphinx documentation 179 | docs/_build/ 180 | 181 | # PyBuilder 182 | .pybuilder/ 183 | target/ 184 | 185 | # Jupyter Notebook 186 | .ipynb_checkpoints 187 | 188 | # IPython 189 | profile_default/ 190 | ipython_config.py 191 | 192 | # pyenv 193 | # For a library or package, you might want to ignore these files since the code is 194 | # intended to run in multiple environments; otherwise, check them in: 195 | # .python-version 196 | 197 | # pipenv 198 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 199 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 200 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 201 | # install all needed dependencies. 202 | #Pipfile.lock 203 | 204 | # poetry 205 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 206 | # This is especially recommended for binary packages to ensure reproducibility, and is more 207 | # commonly ignored for libraries. 208 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 209 | #poetry.lock 210 | 211 | # pdm 212 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 213 | #pdm.lock 214 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 215 | # in version control. 216 | # https://pdm.fming.dev/#use-with-ide 217 | .pdm.toml 218 | 219 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 220 | __pypackages__/ 221 | 222 | # Celery stuff 223 | celerybeat-schedule 224 | celerybeat.pid 225 | 226 | # SageMath parsed files 227 | *.sage.py 228 | 229 | # Environments 230 | .env 231 | .venv 232 | env/ 233 | venv/ 234 | ENV/ 235 | env.bak/ 236 | venv.bak/ 237 | 238 | # Spyder project settings 239 | .spyderproject 240 | .spyproject 241 | 242 | # Rope project settings 243 | .ropeproject 244 | 245 | # mkdocs documentation 246 | /site 247 | 248 | # mypy 249 | .mypy_cache/ 250 | .dmypy.json 251 | dmypy.json 252 | 253 | # Pyre type checker 254 | .pyre/ 255 | 256 | # pytype static type analyzer 257 | .pytype/ 258 | 259 | # Cython debug symbols 260 | cython_debug/ 261 | 262 | # PyCharm 263 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 264 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 265 | # and can be added to the global gitignore or merged into this file. For a more nuclear 266 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 267 | #.idea/ 268 | 269 | ### Python Patch ### 270 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 271 | poetry.toml 272 | 273 | 274 | # End of https://www.toptal.com/developers/gitignore/api/python,pycharm+all 275 | 276 | experiments/compiler_testing/ts_compilers/ 277 | experiments/compiler_testing/results/ 278 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.7.0 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | types_or: [ python, pyi ] 9 | args: [ --fix ] 10 | # Run the formatter. 11 | - id: ruff-format 12 | types_or: [ python, pyi ] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Type-Constrained Code Generation with Language Models 2 | ===================================================== 3 | [![arXiv](https://img.shields.io/badge/arXiv-2504.09246-b31b1b.svg)](https://arxiv.org/abs/2504.09246) 4 | [![QA & Tests](https://github.com/eth-sri/type-constrained-code-generation/actions/workflows/tests.yml/badge.svg)](https://github.com/eth-sri/type-constrained-code-generation/actions/workflows/tests.yml) 5 | 6 | 7 | This is an implementation of a completion engine that parses type-safe programs incrementally, guaranteeing that intermediate outputs can be completed to type-safe programs. 8 | The completion engine can be used to constrain sampling from an LLM to only type-safe programs. 9 | The implementation currently only handles TypeScript. 10 | 11 | More details on the properties of the completion engine and supported features can be found in the paper [Type-Constrained Code Generation with Language Models](https://arxiv.org/abs/2504.09246). 12 | 13 | ### Overview 14 | When set up correctly, the package can be used to sample type-safe TypeScript programs from a language model.
15 | The following will incrementally generate the code for a TypeScript merge sort function, while ensuring that the generated code is type-safe: 16 | 17 | ```python 18 | from typesafe_llm.sampling import sample_constrained 19 | 20 | sample_constrained( 21 | prompt="function merge_sort(x:number[]):number[] {", 22 | max_tokens=100, 23 | device="cuda", 24 | model_name="google/gemma-2-2b-it", 25 | temperature=0, 26 | do_sample=False, 27 | trace=True, 28 | ) 29 | print("Generation completed") 30 | ``` 31 | 32 | The project contains two main parts: 33 | - The sampling algorithm, which is used to sample type-safe TypeScript programs from a language model. 34 | - The parser, which is used to parse TypeScript programs and check their completability to type-safe programs. 35 | 36 | ### Setup 37 | 38 | To install the package, we recommend setting up a conda environment; the instructions below assume NVIDIA GPUs. 39 | 40 | ```bash 41 | git clone https://github.com/eth-sri/type-constrained-code-generation.git 42 | cd type-constrained-code-generation 43 | conda create -n typesafe_llm python=3.11 44 | conda activate typesafe_llm 45 | 46 | # for LLM inference 47 | # set up torch 48 | conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia -y 49 | # install flash-attention 50 | pip install flash-attn==2.7.3 --no-build-isolation 51 | 52 | # install package 53 | pip install -e . 54 | ``` 55 | 56 | If you only want to use the parser and do not want to sample from a language model, you can skip the installation of `torch` and `flash-attention`. 57 | 58 | ### Programmatic Usage 59 | 60 | #### LLM Sampling 61 | 62 | To sample type-safe TypeScript programs from a language model, you can use the `sample_constrained` function from the `typesafe_llm.sampling` module. 63 | 64 | ```python 65 | from typesafe_llm.sampling import sample_constrained 66 | 67 | sample = sample_constrained( 68 | prompt="function merge_sort(x:number[]):number[] {", 69 | max_tokens=100, 70 | device="cuda", 71 | model_name="google/gemma-2-2b-it", 72 | temperature=0.1, 73 | do_sample=True, 74 | ) 75 | print(sample) 76 | ``` 77 | 78 | If a GPU is available, set the device to "cuda"; on a MacBook Pro, set it to "mps" (when a PyTorch nightly build is installed). 79 | Setting the device to "cpu" always works. 80 | `trace` controls a debugging output for live debugging of the generation process. 81 | Set it to False for programmatic use. 82 | 83 | #### Incremental TypeScript parsing 84 | 85 | You can also independently use the parser to parse TypeScript programs and check their completability. 86 | 87 | ```python 88 | from typesafe_llm.parser.parser_ts import parse_ts_program 89 | 90 | states = parse_ts_program("let x:number = 1;x;") 91 | print(list(states)) 92 | # only one accepting state 93 | 94 | states = parse_ts_program('let x:number = "he') 95 | print(list(states)) 96 | # some accepting states, could be saved by y".length 97 | 98 | states = parse_ts_program('let x:boolean = 1 < "hey" +') 99 | print(list(states)) 100 | # no states, cannot turn "hey" + ...
into a number, but a number is needed for the < operator 101 | 102 | states = parse_ts_program('let x:number = 1;let y') 103 | print(list(states)) 104 | # two partial states, one where the second variable has name "y" and one where it is not completed yet 105 | ``` 106 | 107 | ### Tests 108 | 109 | To run the tests, you can use the following command: 110 | 111 | ```bash 112 | pytest test 113 | ``` 114 | 115 | ## Reproducing experiments 116 | 117 | In this section we provide an overview of how to reproduce the experiments presented in our [paper](https://arxiv.org/abs/2504.09246). 118 | 119 | ### Requirements 120 | 121 | Reproducing our experiments locally requires higher-end GPUs, e.g., an NVIDIA A100 80GB. The package includes setup scripts that install all software requirements using miniconda. Required hardware/software: 122 | 123 | - x86/64 architecture CPUs 124 | - 80GB GPU VRAM 125 | - CUDA 12.4 or newer 126 | 127 | Further, the Gemma 2 model family requires accepting an EULA. Please create a Hugging Face account and visit the model websites to accept the EULA. 128 | - https://huggingface.co/google/gemma-2b-it 129 | - https://huggingface.co/google/gemma-9b-it 130 | - https://huggingface.co/google/gemma-27b-it 131 | 132 | You will later be asked for a Hugging Face access token. Log in with the account with which you accepted the EULA and generate an access token on [the Access Token page](https://huggingface.co/settings/tokens). 133 | 134 | ### Setup 135 | 136 | Follow the installation instructions to install conda and all dependencies for the experiments: 137 | 138 | ```bash 139 | bash ./setup_conda.sh 140 | # Restart your shell 141 | bash ./setup_env.sh 142 | # NOTE: Some models are gated on Hugging Face, so you will need to visit their model pages, accept the EULA, and enter your Hugging Face access token when prompted. See the "Requirements" section for more details. 143 | ``` 144 | 145 | > Important note: Before running the experiments, you need to download the models and datasets used for the experiments. 146 | 147 | We provide a script to download the required datasets and models for our experiments. This script must be run before starting the experiments. 148 | You may specify which models to download by passing the `--models` parameter. 149 | 150 | ```bash 151 | python3 experiments/main/download_models.py --models google/gemma-2-2b-it,google/gemma-2-9b-it 152 | ``` 153 | 154 | To download all required models and datasets, run the following command: 155 | 156 | ```bash 157 | python3 experiments/main/download_models.py 158 | ``` 159 | 160 | 161 | ### Warming up 162 | 163 | To warm up, we start by reproducing the synthesis result for the smallest model (Gemma 2 2B) on the MBPP dataset. To avoid using busy GPUs in a shared setting, use the `nvidia-smi` command to check which GPUs are free. Then specify the IDs of the GPUs you want to use by setting the `CUDA_VISIBLE_DEVICES` environment variable. If you want to use GPUs 0 and 1, run the following command: 164 | 165 | ```bash 166 | CUDA_VISIBLE_DEVICES=0,1 python3 experiments/main/run_experiments_syn_tran.py --models google/gemma-2-2b-it --tasks synth --subsets mbpp 167 | ``` 168 | 169 | This reproduces the results for Gemma 2 2B on the synthesis task on MBPP. 170 | The experiment should finish within approximately 4 hours on a single GPU. 171 | The results of the experiment (and all other results) will be stored in `experiments/main/results` in an appropriately named `jsonl` file.
The general schema is `experiments/main/results/<subset>_<model>_s=<seed>_t=<temperature>_<task>_<nc|c>.jsonl`. In this concrete example, the files are `experiments/main/results/mbpp_google_gemma-2-2b-it_s=0_t=1_synth_nc.jsonl` and `..._c.jsonl` for the unconstrained and type-constrained variants, respectively. 172 | 173 | > The experiment runs can be cancelled at any time; intermediate results are stored in the `jsonl` files. Upon restarting, the script will automatically pick up the last completed instance and continue from there. It may happen that running tasks daemonize and continue running (check `nvidia-smi`). Make sure to kill them manually before restarting. 174 | 175 | Our experiment script automatically distributes jobs over the indicated GPUs. 176 | The script then repeatedly checks whether running jobs have completed and new GPUs are available. You will therefore see something like the following output: 177 | ``` 178 | + CUDA_VISIBLE_DEVICES=0 python3 inference_multiple.py --max-tokens 1000 --timeout 300 --model_name google/gemma-2-2b-it --seed 0 --temp 1 --subset mbpp --try_top_k 10000000000000000 --constrained False --output_file 'results/mbpp_google_gemma-2-2b-it_s=0_t=1_synth_nc.jsonl' 179 | + CUDA_VISIBLE_DEVICES=1 python3 inference_multiple.py --max-tokens 1000 --timeout 300 --model_name google/gemma-2-2b-it --seed 0 --temp 1 --subset mbpp --try_top_k 10000000000000000 --constrained True --output_file 'results/mbpp_google_gemma-2-2b-it_s=0_t=1_synth_c.jsonl' 180 | Total jobs: 2, Running jobs: 2, Remaining jobs: 0. Waiting for running jobs to finish... 181 | ``` 182 | 183 | The following commands reproduce the results for the translation and repair tasks on MBPP; each should take around 4 hours: 184 | 185 | ```bash 186 | CUDA_VISIBLE_DEVICES=0,1 python3 experiments/main/run_experiments_syn_tran.py --models google/gemma-2-2b-it --tasks translate --subsets mbpp 187 | CUDA_VISIBLE_DEVICES=0,1 python3 experiments/main/run_experiments_repair.py --models google/gemma-2-2b-it --subsets mbpp 188 | ``` 189 | 190 | 191 | ### Running more experiments 192 | 193 | You can then run more experiments for synthesis and translation by providing different models (`--models`), tasks (`--tasks`), and benchmarks (`--subsets`). Remember to use `CUDA_VISIBLE_DEVICES`. 194 | Note that a single 80 GB GPU provides sufficient VRAM to host any model used in our experiments. 195 | 196 | ```bash 197 | CUDA_VISIBLE_DEVICES=0 python3 experiments/main/run_experiments_syn_tran.py --models google/gemma-2-2b-it,google/gemma-2-9b-it --tasks synth --subsets mbpp,humaneval 198 | CUDA_VISIBLE_DEVICES=0 python3 experiments/main/run_experiments_syn_tran.py --models Qwen/Qwen2.5-32B-Instruct --tasks translate --subsets mbpp 199 | ``` 200 | 201 | You can similarly start the repair task: 202 | 203 | ```bash 204 | CUDA_VISIBLE_DEVICES=0 python3 experiments/main/run_experiments_repair.py --models google/gemma-2-2b-it,google/gemma-2-9b-it --subsets mbpp,humaneval 205 | CUDA_VISIBLE_DEVICES=0 python3 experiments/main/run_experiments_repair.py --models Qwen/Qwen2.5-32B-Instruct --subsets mbpp 206 | ``` 207 | 208 | Below is the list of all options for these parameters. Running all these options will cover all our experiments but can take several days. For the sake of time, reviewers may check a subset they are interested in.
209 | 210 | ```bash 211 | FLAGS 212 | --models=MODELS 213 | Default: google/gemma-2-2b-it,google/gemma-2-9b-it,google/gemma-2-27b-it,deepseek-ai/deepseek-coder-33b-instruct,codellama/CodeLlama-34b-Instruct-hf,Qwen/Qwen2.5-32B-Instruct 214 | --tasks=TASKS (only for experiments/main/run_experiments_syn_tran.py) 215 | Default: synth,translate 216 | --subsets=SUBSETS 217 | Default: humaneval,mbpp 218 | ``` 219 | 220 | You can also list all available parameters with: 221 | 222 | ```bash 223 | python3 experiments/main/run_experiments_syn_tran.py --help 224 | python3 experiments/main/run_experiments_repair.py --help 225 | ``` 226 | 227 | ### Execution Time of Benchmarks 228 | 229 | The runtime of our main experiments depends on the choice of datasets, tasks, and models. Generally, larger datasets and larger models result in longer execution times. 230 | 231 | Our benchmark features the MBPP and HumanEval datasets, adapted for three tasks: synthesis, translation, and repair. 232 | Taking into account additional instances due to running on several seeds, the experiments can be ordered in increasing order of runtime as: MBPP-repair, HumanEval-repair, MBPP-{synthesis,translate}, and HumanEval-{synthesis,translate}. 233 | 234 | Our evaluation further features six models, in order of increasing parameter count: Gemma 2 2B, Gemma 2 9B, Gemma 2 27B, Qwen 2.5 32B, DeepSeek Coder 33B, and CodeLlama 34B. 235 | 236 | Thus, the quickest experiment is computing the performance of Gemma 2 2B synthesis on MBPP, taking approximately 4 hours on a single GPU. The longest experiment is computing the performance of CodeLlama 34B synthesis on HumanEval. 237 | 238 | ### Recreating Figures 239 | 240 | You can run the following command to produce the figures for the paper. You may run this script with partial results, in which case you will receive a printout of the missing results, and their positions in the table will be substituted with "-1". 241 | 242 | ```bash 243 | bash experiments/main/figures.sh 244 | ``` 245 | 246 | The results map to the corresponding figures in the paper as follows: 247 | - Tables 2 and 3: all models, all tasks, all datasets, i.e., `[mbpp|humaneval]_*_s=[0|1|2|3]_t=1_[synth|translate|repair-all]_[c|nc].jsonl`. Vanilla and Syntax can be computed based on the non-constrained (`nc`) variants. 248 | - Table 4: all models, synthesis, all datasets, i.e., `[mbpp|humaneval]_*_s=[0|1|2|3]_t=1_synth_[c|nc].jsonl` 249 | - Figure 8: Gemma 2 2B, synthesis, HumanEval, i.e., `humaneval_google_gemma-2-2b-it_s=[0|1|2|3]_t=1_synth_[c|nc].jsonl` 250 | 251 | Since running the entire pipeline takes several days using 8 GPUs, we have included our raw data in the `experiments/main/results_paper` directory. You can run the figures script directly on the submitted results, without re-running the experiments, like this: 252 | 253 | ```bash 254 | bash experiments/main/figures.sh results_paper 255 | ``` 256 | 257 | > Note: Table 4 is a runtime table. You should expect the runtime per instance to differ based on the CPU and GPU used; however, the *runtime increase* should be consistent with our findings. 258 | 259 | ## Project Structure 260 | 261 | The core part of our work is the implementation of a completion engine that incrementally parses type-safe TypeScript programs. 262 | The completion engine can then be used to constrain the sampling from an LLM to only generate type-safe programs. 263 | 264 | This project is organized as a Python package.
265 | The relevant code for the implementation of type-constrained decoding and sampling is located in the `typesafe_llm` directory. 266 | The experiments are located in the `experiments` directory. 267 | We further provide a test suite in the `test` directory. 268 | The usage of the latter two is described above. 269 | In the following sections we describe the structure of the `typesafe_llm` package. 270 | 271 | ### (Constrained) Sampling (Algorithm 1) 272 | 273 | The sampling algorithm presented in Section 2.1 of the paper is located in `typesafe_llm/sampling.py`. 274 | It uses the `transformers` library to obtain predictions from a language model, samples from them, and, if constraining is enabled, runs a parser in parallel to reject invalid programs (`sample_constrained`); see the illustrative sketch at the end of this section. 275 | 276 | ### Prefix Automaton Definition and Base Automata (Section 3.2) 277 | 278 | The prefix automaton is defined in `typesafe_llm/parser/parser_base.py`. 279 | The automaton is implicitly defined by specifying the transition function and acceptance status of each state, subclassing from `IncrementalParserState`; a toy instance of this interface is sketched at the end of this section. 280 | A state indicates that it is an accepting state by setting the field `accept` to True. 281 | The transition function is invoked by the method `parse_char` and returns a list of new states that can be reached by parsing the given character. 282 | The file further contains the definitions of concatenation, union, Kleene plus and terminal automata. 283 | 284 | ### Identifiers, Literals and Types (Section 3.3) 285 | 286 | The automaton for identifiers (`ExistingIdentifierParserState`) is the first automaton defined in `typesafe_llm/parser/parser_ts.py`. 287 | The following automata parse literals (`LiteralParserState` and its subclasses), including more advanced literals such as regular expressions and template strings. 288 | 289 | The automaton for types is defined separately in `typesafe_llm/parser/parser_ts_types.py`. 290 | 291 | ### Expressions (Section 3.4) 292 | 293 | The expression automaton is defined in `typesafe_llm/parser/parser_ts.py` in the class `ExpressionParserState`. 294 | It implements the extension logic and the pruning of invalid transitions due to operator precedence and type constraints. 295 | The derivability algorithm is implemented for each state individually in the method `derivable`. It determines the directly derivable types and calls the reachability algorithm with them. 296 | The type reachability algorithm is implemented in `typesafe_llm/parser/types_ts.py` in the method `reachable`, leveraging `_reachable_bfs`, a straightforward breadth-first-search translation of the presented reachability algorithm. 297 | 298 | ### Statements and the entire Program (Section 3.5) 299 | 300 | The automaton for statements is defined in `typesafe_llm/parser/parser_ts.py` in the class `StatementParserState`. 301 | It enforces the constraints on valid return types. 302 | The automaton for the entire program is defined in `typesafe_llm/parser/parser_ts.py` in the class `ProgramParserState`.
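### Illustrative Sketch

The following toy example illustrates how the pieces described above fit together. It is a simplified sketch, not the repository's actual implementation: `ToyState`, `parse_prefix`, `sample_constrained_toy`, and `propose_tokens` are hypothetical names, the toy automaton only accepts digit strings, and the real `IncrementalParserState` hierarchy and `sample_constrained` loop are considerably more involved (e.g., they track types and extend parser states incrementally instead of re-parsing from scratch).

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ToyState:
    """State of a toy prefix automaton that accepts non-empty digit strings."""

    digits: str = ""

    @property
    def accept(self) -> bool:
        # A state is accepting once at least one digit has been parsed.
        return len(self.digits) > 0

    def parse_char(self, char: str) -> List["ToyState"]:
        # Transition function: all successor states reachable by parsing `char`.
        if char.isdigit():
            return [ToyState(self.digits + char)]
        return []  # no successor: this prefix can no longer be completed


def parse_prefix(text: str) -> List[ToyState]:
    """Feed a string character by character, keeping all live states."""
    states: List[ToyState] = [ToyState()]
    for char in text:
        states = [nxt for state in states for nxt in state.parse_char(char)]
    return states


def sample_constrained_toy(propose_tokens: Callable[[str], List[str]], max_tokens: int) -> str:
    """Schematic rejection loop for constrained decoding.

    `propose_tokens` stands in for the LLM: given the text generated so far,
    it returns candidate continuations in decreasing order of probability.
    A candidate is kept only if at least one parser state survives it.
    """
    text = ""
    for _ in range(max_tokens):
        for token in propose_tokens(text):
            if parse_prefix(text + token):  # some parser state is still alive
                text += token
                break
        else:
            break  # no candidate keeps the prefix completable
    return text


# Example: the non-digit candidate "a" is always rejected, so only digits appear.
print(sample_constrained_toy(lambda text: ["a", "1", "2"], max_tokens=3))  # prints "111"
```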
303 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eth-sri/type-constrained-code-generation/0cba132d35d9277d0538292d1bdbabb88a5a447a/experiments/__init__.py -------------------------------------------------------------------------------- /experiments/main/.gitignore: -------------------------------------------------------------------------------- 1 | results* 2 | multiple_outputs.jsonl 3 | humanevalx_outputs.jsonl 4 | test.js 5 | test.ts 6 | test*.py 7 | -------------------------------------------------------------------------------- /experiments/main/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eth-sri/type-constrained-code-generation/0cba132d35d9277d0538292d1bdbabb88a5a447a/experiments/main/__init__.py -------------------------------------------------------------------------------- /experiments/main/analyze_avg_time.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | from experiments.main.util import invalid_mbpp_instances 5 | 6 | 7 | def main( 8 | outputs_files, 9 | field, 10 | condition, 11 | ): 12 | # mbpp or main size 13 | target_size = ( 14 | ( 15 | 159 16 | if any("humaneval" in s for s in outputs_files) 17 | else (390 - len(invalid_mbpp_instances)) 18 | ) 19 | if not any("repair" in s for s in outputs_files) 20 | else (292 if any("humaneval" in s for s in outputs_files) else 248) 21 | ) 22 | outputs_by_instance = {} 23 | instances = set() 24 | for outputs_file in outputs_files: 25 | outputs_by_instance[outputs_file] = {} 26 | try: 27 | with open(outputs_file, "r") as f: 28 | outputs = [] 29 | for i, line in enumerate(f): 30 | # print(i) 31 | outputs.append(json.loads(line)) 32 | except Exception: 33 | outputs = [] 34 | for output in outputs: 35 | if output["instance_id"] in invalid_mbpp_instances: 36 | continue 37 | outputs_by_instance[outputs_file][output["instance_id"]] = output 38 | instances.add(output["instance_id"]) 39 | if ( 40 | len(instances) < target_size 41 | and not any("repair" in s for s in outputs_files) 42 | or any( 43 | len(outputs_by_instance_f) < target_size 44 | for filename, outputs_by_instance_f in outputs_by_instance.items() 45 | ) 46 | ): 47 | res = "incomplete" 48 | else: 49 | i = [] 50 | for instance_id in sorted(instances): 51 | res = 0 52 | for file_name in outputs_files: 53 | outputs_by_instance_f = outputs_by_instance.get(file_name, {}) 54 | cond = ( 55 | outputs_by_instance_f.get(instance_id, {}).get(condition) 56 | if condition != "none" 57 | else False 58 | ) 59 | if outputs_by_instance_f.get(instance_id, {}) and not cond: 60 | res = max(outputs_by_instance_f[instance_id][field], res) 61 | i.append(res) 62 | res = i 63 | return res, len(instances), outputs_by_instance 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument("files", nargs="*") 69 | parser.add_argument( 70 | "-f", 71 | "--field", 72 | choices=["time_taken"], 73 | default="time_taken", 74 | ) 75 | parser.add_argument( 76 | "-c", 77 | "--condition", 78 | choices=["none"], 79 | default="none", 80 | ) 81 | args = parser.parse_args() 82 | res, n, outputs_by_instance = main( 83 | args.files, 84 | args.field, 85 | args.condition, 86 | ) 87 | print(res) 88 | 
-------------------------------------------------------------------------------- /experiments/main/analyze_inf_res.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | from experiments.main.util import cutoff, invalid_mbpp_instances, extract_code 5 | 6 | 7 | def main( 8 | outputs_files, 9 | field, 10 | condition, 11 | show_list, 12 | intersection, 13 | ): 14 | # mbpp or main size 15 | target_size = ( 16 | ( 17 | 159 18 | if any("humaneval" in s for s in outputs_files) 19 | else (390 - len(invalid_mbpp_instances)) 20 | ) 21 | if not any("repair" in s for s in outputs_files) 22 | else (292 if any("humaneval" in s for s in outputs_files) else 248) 23 | ) 24 | outputs_by_instance = {} 25 | instances = set() 26 | for outputs_file in outputs_files: 27 | outputs_by_instance[outputs_file] = {} 28 | try: 29 | with open(outputs_file, "r") as f: 30 | outputs = [] 31 | for i, line in enumerate(f): 32 | # print(i) 33 | outputs.append(json.loads(line)) 34 | except Exception: 35 | outputs = [] 36 | # print("Error", e) 37 | for output in outputs: 38 | # this correctly handles repair-all 39 | if any(id in output["instance_id"] for id in invalid_mbpp_instances): 40 | continue 41 | outputs_by_instance[outputs_file][output["instance_id"]] = output 42 | instances.add(output["instance_id"]) 43 | if len(instances) < target_size or any( 44 | len(outputs_by_instance_f) < target_size 45 | for filename, outputs_by_instance_f in outputs_by_instance.items() 46 | ): 47 | res = "incomplete" 48 | else: 49 | i = 0 50 | for instance_id in sorted(instances): 51 | res = True if intersection else False 52 | for file_name in outputs_files: 53 | outputs_by_instance_f = outputs_by_instance.get(file_name, {}) 54 | cond = ( 55 | outputs_by_instance_f.get(instance_id, {}).get(condition) 56 | if condition != "none" 57 | else False 58 | ) 59 | if outputs_by_instance_f.get(instance_id, {}) and not cond: 60 | if field == "syntax_error": 61 | res = ( 62 | "SyntaxError" 63 | in outputs_by_instance_f[instance_id]["compiler_output"] 64 | or "syntax error" 65 | in outputs_by_instance_f[instance_id]["compiler_output"] 66 | ) 67 | else: 68 | res = bool(outputs_by_instance_f[instance_id].get(field, True)) 69 | break 70 | res_pos = res 71 | if res_pos and show_list: 72 | print(instance_id) 73 | i += res_pos 74 | res = i / len(instances) 75 | return res, len(instances), outputs_by_instance 76 | 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("files", nargs="*") 81 | parser.add_argument("-n", "--non_interactive", action="store_true") 82 | parser.add_argument( 83 | "-f", 84 | "--field", 85 | choices=[ 86 | "syntax_error", 87 | "compiler_output", 88 | "tests_passed", 89 | "compiled", 90 | "syntax_ok", 91 | ], 92 | default="compiler_output", 93 | ) 94 | parser.add_argument( 95 | "-c", 96 | "--condition", 97 | choices=["compiler_output", "tests_passed", "none", "compiled", "syntax_ok"], 98 | default="compiler_output", 99 | ) 100 | parser.add_argument("-l", action="store_true", help="list all negative cases") 101 | parser.add_argument( 102 | "-i", 103 | action="store_true", 104 | help="intersection instead of union (use for compiler_output)", 105 | ) 106 | args = parser.parse_args() 107 | res, n, outputs_by_instance = main( 108 | args.files, 109 | args.field, 110 | args.condition, 111 | args.l, 112 | args.i, 113 | ) 114 | print(res) 115 | if args.non_interactive: 116 | exit() 117 | input() 118 | for file_name, 
outputs_by_instance_f in outputs_by_instance.items(): 119 | for instance_id, output in outputs_by_instance_f.items(): 120 | if not output[args.field]: 121 | code = cutoff(extract_code(output["code"], "Go", 0)).splitlines() 122 | code = "\n".join([f"{i+1:03d} {s}" for i, s in enumerate(code)]) 123 | print(code) 124 | print(output["crashed"]) 125 | print(output["compiler_output"]) 126 | print(output["instance_id"]) 127 | input() 128 | -------------------------------------------------------------------------------- /experiments/main/analyze_many_temp.sh: -------------------------------------------------------------------------------- 1 | params=("compiler_output -i" "tests_passed") 2 | temp="1" 3 | suffixs=("" "_translate" "_repair") 4 | models=( 5 | # "microsoft/Phi-3.5-mini-instruct" 6 | # "codellama/CodeLlama-7b-Instruct-hf" 7 | # "meta-llama/Llama-3.1-8B-Instruct" 8 | # "google/gemma-2b-it" 9 | "google/gemma-2-2b-it" 10 | "google/gemma-2-9b-it" 11 | "deepseek-ai/deepseek-coder-33b-instruct" 12 | # "deepseek-ai/deepseek-coder-7b-instruct-v1.5" 13 | # "deepseek-ai/deepseek-coder-1.3b-instruct" 14 | # "meta-llama/Llama-3.1-70B-Instruct" 15 | # "codellama/CodeLlama-70b-Instruct-hf" 16 | "codellama/CodeLlama-34b-Instruct-hf" 17 | # "codellama/CodeLlama-13b-Instruct-hf" 18 | # "google/codegemma-7b-it" 19 | # "bigcode/octocoder" 20 | "google/gemma-2-27b-it" 21 | "Qwen/Qwen2.5-32B-Instruct" 22 | ) 23 | subset="humaneval" 24 | for param in "${params[@]}" 25 | do 26 | for suffix in "${suffixs[@]}" 27 | do 28 | echo "$param unconstrained $suffix" 29 | for model in "${models[@]}" 30 | do 31 | python3 analyze_inf_res.py -f ${param} "results/${subset}_${model//\//_}_s=0_t=${temp}${suffix}_nc.jsonl" --non_interactive 32 | done 33 | echo "$param constrained $suffix" 34 | for model in "${models[@]}" 35 | do 36 | python3 analyze_inf_res.py -f ${param} "results/${subset}_${model//\//_}_s=0_t=${temp}${suffix}_c.jsonl" --non_interactive 37 | done 38 | echo "$param union $suffix" 39 | for model in "${models[@]}" 40 | do 41 | python3 analyze_inf_res.py -f ${param} "results/${subset}_${model//\//_}_s=0_t=${temp}${suffix}_nc.jsonl" "results/${subset}_${model//\//_}_s=0_t=${temp}${suffix}_c.jsonl" --non_interactive 42 | done 43 | done 44 | done 45 | -------------------------------------------------------------------------------- /experiments/main/analyze_multifun.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import fire 4 | 5 | 6 | def main( 7 | outputs_files, 8 | ): 9 | if isinstance(outputs_files, str): 10 | outputs_files = [outputs_files] 11 | outputs_by_instance = {} 12 | instances = set() 13 | for outputs_file in outputs_files: 14 | outputs_by_instance[outputs_file] = {} 15 | try: 16 | with open(outputs_file, "r") as f: 17 | outputs = [] 18 | for i, line in enumerate(f): 19 | # print(i) 20 | outputs.append(json.loads(line)) 21 | except Exception as e: 22 | raise e 23 | outputs = [] 24 | for output in outputs: 25 | outputs_by_instance[outputs_file][output["instance_id"]] = output 26 | instances.add(output["instance_id"]) 27 | i = 0 28 | for instance_id in instances: 29 | res_pos = False 30 | for file_name in outputs_files: 31 | outputs_by_instance_f = outputs_by_instance.get(file_name, {}) 32 | output = outputs_by_instance_f[instance_id] 33 | res = re.findall( 34 | r"error TS2304: Cannot find name '(.+)'", output["compiler_output"] 35 | ) 36 | if res: 37 | for r in res: 38 | if r in output["code"]: 39 | res_pos = True 40 | break 41 | if 
res_pos: 42 | break 43 | i += res_pos 44 | print(i) 45 | 46 | 47 | fire.Fire(main) 48 | -------------------------------------------------------------------------------- /experiments/main/analyze_outputs.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | 4 | import fire 5 | 6 | from typesafe_llm.parser.parser_ts import parse_ts_program 7 | from experiments.main.util import cutoff, tsx_compiles, passes_tests_js 8 | 9 | 10 | def main(outputs_file, recompile=False): 11 | with open(outputs_file, "r") as f: 12 | outputs = [json.loads(line) for line in f] 13 | outputs_by_instance_constrained = {} 14 | for output in outputs: 15 | outputs_by_instance_constrained[ 16 | (output["instance_id"], output["constrained"]) 17 | ] = output 18 | nums = [ 19 | [0, 0, 0, 0, 0], 20 | [0, 0, 0, 0, 0], 21 | ] # unconstrained: total, compiled, compiled in sublang, passed, time 22 | solved = [set(), set()] 23 | for (instance_id, constrained), output in outputs_by_instance_constrained.items(): 24 | if "initial prompt" in output["crashed"]: 25 | continue 26 | compiled, err = ( 27 | tsx_compiles(output["tsc_code"]) 28 | if recompile 29 | else ( 30 | output["compiled"], 31 | False, 32 | ) 33 | ) 34 | compiled_sub = output["compiled_in_sublang"] if not constrained else compiled 35 | start = time.time() 36 | tests_passed = ( 37 | (passes_tests_js(compiled) if recompile else output["tests_passed"]) 38 | if compiled 39 | else False 40 | ) 41 | end = time.time() 42 | nums[constrained][0] += 1 43 | nums[constrained][1] += bool(compiled) 44 | nums[constrained][2] += bool(compiled_sub) 45 | nums[constrained][3] += bool(tests_passed) and bool(compiled_sub) 46 | nums[constrained][4] += end - start if recompile else output["time_taken"] 47 | if compiled_sub and compiled: 48 | solved[constrained].add(output["instance_id"]) 49 | if ( 50 | False 51 | and not compiled 52 | and constrained 53 | and outputs_by_instance_constrained[(instance_id, False)]["compiled"] 54 | ): 55 | new_sub_compiled = parse_ts_program( 56 | cutoff(outputs_by_instance_constrained[(instance_id, False)]["code"]) 57 | ) 58 | if new_sub_compiled: 59 | continue 60 | print(instance_id) 61 | print(cutoff(outputs_by_instance_constrained[(instance_id, False)]["code"])) 62 | print("-" * 80) 63 | print(cutoff(output["code"])) 64 | print(output["compiled_in_sublang"]) 65 | print(output["tests_passed"]) 66 | print(output["crashed"]) 67 | print(output["time_taken"]) 68 | print("-" * 80) 69 | input() 70 | print(nums) 71 | print([[x / total for x in num for total in [num[0]]] for num in nums]) 72 | print(solved[0] & solved[1]) 73 | print(solved[0] - solved[1]) 74 | print(solved[1] - solved[0]) 75 | 76 | 77 | if __name__ == "__main__": 78 | fire.Fire(main) 79 | -------------------------------------------------------------------------------- /experiments/main/analyze_pass_not_accepted.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | from experiments.main.util import cutoff 5 | from typesafe_llm.parser.parser_ts import parse_ts_program 6 | 7 | outputs_file = sys.argv[1] 8 | with open(outputs_file, "r") as f: 9 | outputs = [json.loads(line) for line in f] 10 | outputs_by_instance_constrained = {} 11 | for output in outputs: 12 | outputs_by_instance_constrained[(output["instance_id"], output["constrained"])] = ( 13 | output 14 | ) 15 | i = 0 16 | for (instance_id, constrained), output in outputs_by_instance_constrained.items(): 17 | if 
output["tests_passed"]: 18 | i += 1 19 | print(i / len(outputs)) 20 | for (instance_id, constrained), output in outputs_by_instance_constrained.items(): 21 | if output["tests_passed"] and not any( 22 | s.accept for s in parse_ts_program(cutoff(output["compilable"])) 23 | ): 24 | print(output["instance_id"]) 25 | -------------------------------------------------------------------------------- /experiments/main/analyze_passatk.sh: -------------------------------------------------------------------------------- 1 | param="tests_passed" 2 | temp="1" 3 | suffixs=("" "_translate") 4 | seeds=(0 1 2 3 4) 5 | models=( 6 | # "microsoft/Phi-3.5-mini-instruct" 7 | # "codellama/CodeLlama-7b-Instruct-hf" 8 | # "meta-llama/Llama-3.1-8B-Instruct" 9 | # "google/gemma-2b-it" 10 | # "google/gemma-2-2b-it" 11 | # "google/gemma-2-9b-it" 12 | "deepseek-ai/deepseek-coder-33b-instruct" 13 | # "deepseek-ai/deepseek-coder-7b-instruct-v1.5" 14 | # "deepseek-ai/deepseek-coder-1.3b-instruct" 15 | # "meta-llama/Llama-3.1-70B-Instruct" 16 | # "codellama/CodeLlama-70b-Instruct-hf" 17 | "codellama/CodeLlama-34b-Instruct-hf" 18 | # "codellama/CodeLlama-13b-Instruct-hf" 19 | # "google/codegemma-7b-it" 20 | # "bigcode/octocoder" 21 | "google/gemma-2-27b-it" 22 | "Qwen/Qwen2.5-32B-Instruct" 23 | ) 24 | for suffix in "${suffixs[@]}" 25 | do 26 | echo "unconstrained $suffix" 27 | for model in "${models[@]}" 28 | do 29 | files=$(printf "results/humaneval_${model//\//_}_s=%s_t=${temp}${suffix}_nc.jsonl " "${seeds[@]}") 30 | python3 analyze_inf_res.py -c ${param} -f ${param} ${files} --non_interactive 31 | done 32 | echo "constrained $suffix" 33 | for model in "${models[@]}" 34 | do 35 | files=$(printf "results/humaneval_${model//\//_}_s=%s_t=${temp}${suffix}_c.jsonl " "${seeds[@]}") 36 | python3 analyze_inf_res.py -c ${param} -f ${param} ${files} --non_interactive 37 | done 38 | echo "union $suffix" 39 | for model in "${models[@]}" 40 | do 41 | files=$(printf "results/humaneval_${model//\//_}_s=%s_t=${temp}${suffix}_nc.jsonl results/humaneval_${model//\//_}_s=%s_t=${temp}${suffix}_c.jsonl " "${seeds[@]}" "${seeds[@]}") 42 | python3 analyze_inf_res.py -c ${param} -f ${param} ${files} --non_interactive 43 | done 44 | done 45 | # suffix="_repair" 46 | # echo "unconstrained $suffix" 47 | # for model in "${models[@]}" 48 | # do 49 | # python3 analyze_inf_res.py -f ${param} "results/humaneval_${model//\//_}_s=0_t=${temp}_nc.jsonl" "results/humaneval_${model//\//_}_s=0_t=${temp}${suffix}_nc.jsonl" --non_interactive 50 | # done 51 | # echo "constrained $suffix" 52 | # for model in "${models[@]}" 53 | # do 54 | # python3 analyze_inf_res.py -f ${param} "results/humaneval_${model//\//_}_s=0_t=${temp}_nc.jsonl" "results/humaneval_${model//\//_}_s=0_t=${temp}${suffix}_c.jsonl" --non_interactive 55 | # done 56 | # echo "union $suffix" 57 | # for model in "${models[@]}" 58 | # do 59 | # python3 analyze_inf_res.py -f ${param} "results/humaneval_${model//\//_}_s=0_t=${temp}_nc.jsonl" "results/humaneval_${model//\//_}_s=0_t=${temp}${suffix}_nc.jsonl" "results/humaneval_${model//\//_}_s=0_t=${temp}${suffix}_c.jsonl" --non_interactive 60 | # done 61 | -------------------------------------------------------------------------------- /experiments/main/analyze_resample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from collections import defaultdict 4 | 5 | from experiments.main.util import cutoff, invalid_mbpp_instances, extract_code 6 | 7 | 8 | def main( 9 | outputs_files, 
10 | ): 11 | outputs_by_instance = {} 12 | instances = set() 13 | for outputs_file in outputs_files: 14 | outputs_by_instance[outputs_file] = {} 15 | try: 16 | with open(outputs_file, "r") as f: 17 | outputs = [] 18 | for i, line in enumerate(f): 19 | # print(i) 20 | outputs.append(json.loads(line)) 21 | except Exception: 22 | outputs = [] 23 | for output in outputs: 24 | if output["instance_id"] in invalid_mbpp_instances: 25 | continue 26 | outputs_by_instance[outputs_file][output["instance_id"]] = output 27 | instances.add(output["instance_id"]) 28 | chart_max_resample_vs_syntax_error_total = defaultdict(int) 29 | chart_max_resample_vs_syntax_error_se = defaultdict(int) 30 | for file_name in outputs_files: 31 | for instance_id in sorted(instances): 32 | outputs_by_instance_f = outputs_by_instance.get(file_name, {}) 33 | instance = outputs_by_instance_f.get(instance_id, {}) 34 | if not instance: 35 | continue 36 | resamples = instance.get("resamples", None) 37 | if resamples is None: 38 | continue 39 | did_not_terminate = "SyntaxError" in instance["compiler_output"] 40 | cutoffed_code = cutoff(extract_code(instance["code"], "TypeScript", 0)) 41 | len_of_cutoffed = len(cutoffed_code) 42 | resamples = [x for x in resamples if x[0] < len_of_cutoffed] 43 | if not resamples: 44 | max_resample = 0 45 | else: 46 | max_resample = max(x[1] for x in resamples) 47 | if max_resample > 15: 48 | max_resample -= max_resample % 10 49 | chart_max_resample_vs_syntax_error_total[max_resample] += 1 50 | chart_max_resample_vs_syntax_error_se[max_resample] += did_not_terminate 51 | for max_resample, total in sorted(chart_max_resample_vs_syntax_error_total.items()): 52 | print( 53 | f"({max_resample}, {chart_max_resample_vs_syntax_error_se[max_resample]*100/total})" 54 | ) 55 | 56 | 57 | if __name__ == "__main__": 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument("files", nargs="*") 60 | args = parser.parse_args() 61 | main( 62 | args.files, 63 | ) 64 | -------------------------------------------------------------------------------- /experiments/main/convert_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | 5 | 6 | def convert_files(directory): 7 | # Define the pattern to match files 8 | pattern = os.path.join(directory, "*_nc.jsonl") 9 | 10 | # Find all files matching the pattern 11 | files = glob.glob(pattern) 12 | 13 | for file_path in files: 14 | # Extract the base name of the file 15 | base_name = os.path.basename(file_path) 16 | 17 | # Check if the file does not end with _repair-all or _translate before _c.jsonl or _nc.jsonl 18 | if not any(s in file_path for s in ("translate", "repair")): 19 | # Find the position of the last underscore before _c.jsonl or _nc.jsonl 20 | underscore_pos = base_name.rfind("_nc.jsonl") 21 | 22 | if underscore_pos != -1: 23 | # Extract the part before the underscore 24 | prefix = base_name[:underscore_pos] 25 | 26 | # Create the new file name 27 | new_file_name = f"{prefix}_synth_nc.jsonl" 28 | 29 | # Define the new file path 30 | new_file_path = os.path.join(directory, new_file_name) 31 | 32 | # Rename the file 33 | os.rename(file_path, new_file_path) 34 | print(f"Renamed '{file_path}' to '{new_file_path}'") 35 | 36 | 37 | # Specify the directory containing the files 38 | directory = sys.argv[1] 39 | 40 | # Call the function to convert files 41 | convert_files(directory) 42 | -------------------------------------------------------------------------------- 
/experiments/main/count_syntax_errors.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | import sys 4 | import tempfile 5 | 6 | from experiments.main.util import cutoff 7 | 8 | path_to_ts_parser = "../../ts_parser/target/release/ts_parser" 9 | outputs_file = sys.argv[1] 10 | try: 11 | with open(outputs_file, "r") as f: 12 | outputs = [] 13 | for i, line in enumerate(f): 14 | # print(i) 15 | outputs.append(json.loads(line)) 16 | except Exception: 17 | outputs = [] 18 | outputs_by_instance_constrained = {} 19 | for output in outputs: 20 | outputs_by_instance_constrained[(output["instance_id"], output["constrained"])] = ( 21 | output 22 | ) 23 | total = len(outputs) 24 | syntax = 0 25 | other = 0 26 | for (instance_id, constrained), output in outputs_by_instance_constrained.items(): 27 | with tempfile.NamedTemporaryFile(suffix=".ts") as f: 28 | code = cutoff(output["compilable"]).encode() 29 | f.write(code) 30 | f.flush() 31 | res = subprocess.run([path_to_ts_parser, f.name], capture_output=True) 32 | if res.returncode != 0: 33 | syntax += 1 34 | elif not output["compiled"]: 35 | other += 1 36 | print((other + syntax) / total, other / (syntax + other) if syntax + other != 0 else 1) 37 | -------------------------------------------------------------------------------- /experiments/main/create_humaneval_repair_dataset.sh: -------------------------------------------------------------------------------- 1 | models=( 2 | "google/gemma-2-2b-it" 3 | "google/gemma-2-9b-it" 4 | "google/gemma-2-27b-it" 5 | "deepseek-ai/deepseek-coder-33b-instruct" 6 | "codellama/CodeLlama-34b-Instruct-hf" 7 | "Qwen/Qwen2.5-32B-Instruct" 8 | ) 9 | seeds=(0 1 2 3) 10 | temp=1 11 | suffix="_synth" 12 | subset=humaneval 13 | ds_name="repair_datasets/${subset}_repair_dataset.jsonl" 14 | # empty 15 | : > $ds_name 16 | # append outputs for models and seeds 17 | for seed in "${seeds[@]}" 18 | do 19 | for model in "${models[@]}" 20 | do 21 | python3 create_repair_dataset.py "results/${subset}_${model//\//_}_s=${seed}_t=${temp}${suffix}_nc.jsonl" >> $ds_name 22 | done 23 | done 24 | -------------------------------------------------------------------------------- /experiments/main/create_mbpp_repair_dataset.sh: -------------------------------------------------------------------------------- 1 | models=( 2 | "google/gemma-2-2b-it" 3 | "google/gemma-2-9b-it" 4 | "google/gemma-2-27b-it" 5 | "deepseek-ai/deepseek-coder-33b-instruct" 6 | "codellama/CodeLlama-34b-Instruct-hf" 7 | "Qwen/Qwen2.5-32B-Instruct" 8 | ) 9 | seeds=(0 ) 10 | temp=1 11 | suffix="_synth" 12 | subset=mbpp 13 | ds_name="repair_datasets/${subset}_repair_dataset.jsonl" 14 | # empty 15 | : > $ds_name 16 | # append outputs for models and seeds 17 | for seed in "${seeds[@]}" 18 | do 19 | for model in "${models[@]}" 20 | do 21 | python3 create_repair_dataset.py "results/${subset}_${model//\//_}_s=${seed}_t=${temp}${suffix}_nc.jsonl" >> $ds_name 22 | done 23 | done 24 | -------------------------------------------------------------------------------- /experiments/main/create_repair_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import json 4 | 5 | 6 | def main(input_files: list[Path]): 7 | for file in input_files: 8 | with open(file) as f: 9 | for line in f: 10 | line = line.strip() 11 | instance = json.loads(line) 12 | instance["repair_id"] = instance["instance_id"] + str(file) 13 | if 
instance["compiler_output"]: 14 | print(json.dumps(instance)) 15 | 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("files", nargs="*", type=Path) 20 | args = parser.parse_args() 21 | main(args.files) 22 | -------------------------------------------------------------------------------- /experiments/main/download_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from datasets import load_dataset 5 | from transformers import AutoModelForCausalLM 6 | 7 | 8 | def main(): 9 | # Default models 10 | default_models = [ 11 | "google/gemma-2-2b-it", 12 | "google/gemma-2-9b-it", 13 | "google/gemma-2-27b-it", 14 | "deepseek-ai/deepseek-coder-33b-instruct", 15 | "codellama/CodeLlama-34b-Instruct-hf", 16 | "Qwen/Qwen2.5-32B-Instruct", 17 | ] 18 | 19 | # Set up argument parser 20 | parser = argparse.ArgumentParser( 21 | description="Download and load models for causal language modeling." 22 | ) 23 | parser.add_argument( 24 | "--models", 25 | type=str, 26 | default=",".join(default_models), 27 | help=f"Comma-separated list of model names to load.\nDefault: {','.join(default_models)}", 28 | ) 29 | parser.add_argument( 30 | "--device-map", 31 | type=str, 32 | default="auto", 33 | help="Device to load models on temporarily (e.g., 'cpu', 'auto', 'cuda:0').\nDefault: 'cpu'", 34 | ) 35 | args = parser.parse_args() 36 | 37 | # Load datasets 38 | dataset_name = "nuprl/MultiPL-E" 39 | load_dataset(dataset_name, "humaneval-ts")["test"] 40 | load_dataset(dataset_name, "mbpp-ts")["test"] 41 | 42 | # Parse models 43 | models = [x.strip() for x in args.models.split(",")] 44 | 45 | # Load models 46 | for model in models: 47 | x = AutoModelForCausalLM.from_pretrained(model, device_map=args.device_map) 48 | del x 49 | torch.cuda.empty_cache() 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /experiments/main/figures.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Description: Run the experiment for generating samples from the model constrained and unconstrained 3 | set -e 4 | # cd into the directory of this file 5 | cd "$(dirname "${BASH_SOURCE[0]}")" 6 | # DIR is either the parameter passed to this script or the result directory 7 | DIR="${1:-results}" 8 | python3 -m pip install tabulate 9 | echo "" 10 | echo "The following Table/Figure references refer to the revised paper, as attached to the comment thread of the discussion." 
11 | echo "" 12 | echo "Table 2" 13 | echo "Humaneval" 14 | python3 figures_revision/fig_compiler_perf_syn_tran_repair.py --subset humaneval --directory "$DIR" 15 | echo "MBPP" 16 | python3 figures_revision/fig_compiler_perf_syn_tran_repair.py --subset mbpp --directory "$DIR" 17 | echo "" 18 | echo "Table 3" 19 | echo "Humaneval" 20 | python3 figures/fig_compiler_perf_fc.py --subset humaneval --directory "$DIR" 21 | echo "MBPP" 22 | python3 figures/fig_compiler_perf_fc.py --subset mbpp --directory "$DIR" 23 | echo "" 24 | echo "Table 4" 25 | python3 figures/fig_compiler_time.py --directory "$DIR" 26 | echo "" 27 | echo "Figure 8" 28 | bash figures_revision/fig_resample_hist.sh "$DIR" 29 | -------------------------------------------------------------------------------- /experiments/main/figures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eth-sri/type-constrained-code-generation/0cba132d35d9277d0538292d1bdbabb88a5a447a/experiments/main/figures/__init__.py -------------------------------------------------------------------------------- /experiments/main/figures/fig_compiler_perf_comparison.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | from collections import defaultdict 4 | 5 | import fire 6 | from tabulate import tabulate 7 | 8 | from experiments.main.analyze_inf_res import main as inf_res 9 | 10 | SUFFIXES = ["_synth", "_translate"] 11 | SUBSETS = ["main", "mbpp"] 12 | SUBSET_SIZE = { 13 | "main": 1, 14 | "mbpp": 1, 15 | } 16 | MODEL_NAME_MAP = { 17 | "google/gemma-2-2b-it": "Gemma 2 2B", 18 | "google/gemma-2-9b-it": "Gemma 2 9B", 19 | "google/gemma-2-27b-it": "Gemma 2 27B", 20 | "deepseek-ai/deepseek-coder-33b-instruct": "DeepSeek Coder 33B", 21 | "codellama/CodeLlama-34b-Instruct-hf": "CodeLlama 34B", 22 | "Qwen/Qwen2.5-32B-Instruct": "Qwen2.5 32B", 23 | } 24 | 25 | 26 | def main(format="github", field="compiler_output", subset="main"): 27 | models = ( 28 | "google/gemma-2-2b-it", 29 | "google/gemma-2-9b-it", 30 | "google/gemma-2-27b-it", 31 | "deepseek-ai/deepseek-coder-33b-instruct", 32 | "codellama/CodeLlama-34b-Instruct-hf", 33 | "Qwen/Qwen2.5-32B-Instruct", 34 | ) 35 | temp = 1 36 | condition = "compiler_output" 37 | unconstrained = [] 38 | constrained = [] 39 | ideal_syntax = [] 40 | totals = [] 41 | for model in models: 42 | du = defaultdict(float) 43 | dc = defaultdict(float) 44 | id = defaultdict(float) 45 | total = defaultdict(int) 46 | for suffix in SUFFIXES: 47 | seeds = [0] if subset == "mbpp" else [1, 2, 3, 4] 48 | for seed in seeds: 49 | res, n, _ = inf_res( 50 | [ 51 | f"results_bak_04122024/{subset}_{model.replace('/','_')}_s={seed}_t={temp}{suffix}_nc.jsonl", 52 | f"results/{subset}_{model.replace('/', '_')}_s={seed}_t={temp}{suffix}_c.jsonl", 53 | ], 54 | field=field, 55 | condition=condition, 56 | show_list=False, 57 | intersection=True, 58 | ) 59 | total[suffix] += n 60 | du[suffix] += res * n if res != "incomplete" else -float("inf") 61 | res, n, _ = inf_res( 62 | [ 63 | f"results_bak_04122024/{subset}_{model.replace('/','_')}_s={seed}_t={temp}{suffix}_nc.jsonl", 64 | f"results/{subset}_{model.replace('/', '_')}_s={seed}_t={temp}{suffix}_c.jsonl", 65 | ], 66 | field=field, 67 | condition=condition, 68 | show_list=False, 69 | intersection=True, 70 | ) 71 | dc[suffix] += res * n if res != "incomplete" else -float("inf") 72 | unconstrained.append(du) 73 | constrained.append(dc) 74 | ideal_syntax.append(id) 75 | 
totals.append(total) 76 | 77 | headers = ["Limited", "Unlimited"] * len(SUFFIXES) 78 | rows = [] 79 | for model, unc, con, id, total in zip( 80 | models, unconstrained, constrained, ideal_syntax, totals 81 | ): 82 | row = [MODEL_NAME_MAP[model]] 83 | for suffix in SUFFIXES: 84 | row.extend( 85 | ( 86 | "{:.1f}".format(unc[suffix] * 100 / total[suffix]) 87 | if total[suffix] 88 | else -1, 89 | "{:.1f}".format(con[suffix] * 100 / total[suffix]) 90 | if total[suffix] 91 | else -1, 92 | ) 93 | ) 94 | rows.append(row) 95 | if format == "csv": 96 | writer = csv.writer(sys.stdout) 97 | writer.writerow([""] + headers) 98 | writer.writerows(rows) 99 | else: 100 | print(tabulate(rows, headers=headers, tablefmt=format)) 101 | 102 | 103 | if __name__ == "__main__": 104 | fire.Fire(main) 105 | -------------------------------------------------------------------------------- /experiments/main/figures/fig_compiler_perf_fc.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | from collections import defaultdict 4 | 5 | import fire 6 | from tabulate import tabulate 7 | 8 | from experiments.main.analyze_inf_res import main as inf_res 9 | 10 | SUFFIXES = ["_synth", "_translate", "_repair-all"] 11 | SUBSETS = ["humaneval", "mbpp"] 12 | SUBSET_SIZE = { 13 | "humaneval": 1, 14 | "mbpp": 1, 15 | } 16 | MODEL_NAME_MAP = { 17 | "google/gemma-2-2b-it": "Gemma 2 2B", 18 | "google/gemma-2-9b-it": "Gemma 2 9B", 19 | "google/gemma-2-27b-it": "Gemma 2 27B", 20 | "deepseek-ai/deepseek-coder-33b-instruct": "DS Coder 33B", 21 | "codellama/CodeLlama-34b-Instruct-hf": "CodeLlama 34B", 22 | "Qwen/Qwen2.5-32B-Instruct": "Qwen2.5 32B", 23 | } 24 | 25 | 26 | def main( 27 | format="github", field="tests_passed", subset="humaneval", directory="results" 28 | ): 29 | models = ( 30 | "google/gemma-2-2b-it", 31 | "google/gemma-2-9b-it", 32 | "google/gemma-2-27b-it", 33 | "deepseek-ai/deepseek-coder-33b-instruct", 34 | "codellama/CodeLlama-34b-Instruct-hf", 35 | "Qwen/Qwen2.5-32B-Instruct", 36 | ) 37 | temp = 1 38 | condition = "compiler_output" 39 | unconstrained = [] 40 | constrained = [] 41 | ideal_syntax = [] 42 | totals = [] 43 | for model in models: 44 | du = defaultdict(float) 45 | dc = defaultdict(float) 46 | id = defaultdict(float) 47 | total = defaultdict(int) 48 | for suffix in SUFFIXES: 49 | seeds = [0] if subset == "mbpp" or "repair-all" in suffix else [0, 1, 2, 3] 50 | discard_unconstrained, discard_constrained = False, False 51 | for seed in seeds: 52 | res, n, _ = inf_res( 53 | [ 54 | f"{directory}/{subset}_{model.replace('/','_')}_s={seed}_t={temp}{suffix}_nc.jsonl" 55 | ], 56 | field=field, 57 | condition=condition, 58 | show_list=False, 59 | intersection=field == "compiler_output", 60 | ) 61 | if res == "incomplete": 62 | discard_unconstrained = True 63 | discard_constrained = True 64 | continue 65 | total[suffix] += n 66 | du[suffix] += res * n 67 | res, n, _ = inf_res( 68 | [ 69 | f"{directory}/{subset}_{model.replace('/', '_')}_s={seed}_t={temp}{suffix}_nc.jsonl", 70 | ], 71 | field="syntax_error", 72 | condition="none", 73 | show_list=False, 74 | intersection=True, 75 | ) 76 | if res == "incomplete": 77 | discard_unconstrained = True 78 | continue 79 | id[suffix] += res * n 80 | res, n, _ = inf_res( 81 | [ 82 | f"{directory}/{subset}_{model.replace('/','_')}_s={seed}_t={temp}{suffix}_nc.jsonl", 83 | f"{directory}/{subset}_{model.replace('/', '_')}_s={seed}_t={temp}{suffix}_c.jsonl", 84 | ], 85 | field=field, 86 | condition=condition, 87 | 
show_list=False, 88 | intersection=field == "compiler_output", 89 | ) 90 | if res == "incomplete": 91 | discard_constrained = True 92 | continue 93 | dc[suffix] += res * n 94 | if discard_unconstrained: 95 | du[suffix] = -1 96 | id[suffix] = -1 97 | if discard_constrained: 98 | dc[suffix] = -1 99 | unconstrained.append(du) 100 | constrained.append(dc) 101 | ideal_syntax.append(id) 102 | totals.append(total) 103 | 104 | # print("Average improvement functional correctness (over Models)") 105 | # for suffix in SUFFIXES: 106 | # improvements_syntax = [] 107 | # improvements_constrained = [] 108 | # for model, unc, con, id in zip( 109 | # models, unconstrained, constrained, ideal_syntax 110 | # ): 111 | # improvements_syntax.append( 112 | # id[suffix] * 100 / unc[suffix] if unc[suffix] != 0 else 0 113 | # ) 114 | # improvements_constrained.append( 115 | # (con[suffix] - unc[suffix]) * 100 / unc[suffix] 116 | # if unc[suffix] != 0 117 | # else 0 118 | # ) 119 | # print( 120 | # f"{suffix}, Syntax: {sum(improvements_syntax) / len(improvements_syntax):.1f}%" 121 | # ) 122 | # print( 123 | # f"{suffix}, Types: {sum(improvements_constrained) / len(improvements_constrained):.1f}%" 124 | # ) 125 | 126 | headers = ["", "Model"] + ["Standard", "Types", ""] * len(SUFFIXES) 127 | rows = [] 128 | for model, unc, con, id, total in zip( 129 | models, unconstrained, constrained, ideal_syntax, totals 130 | ): 131 | row = ["", MODEL_NAME_MAP[model]] 132 | for suffix in SUFFIXES: 133 | row.extend( 134 | ( 135 | "{:.1f}".format(unc[suffix] * 100 / total[suffix]) 136 | if unc[suffix] != -1 and total[suffix] != -1 137 | else -1, 138 | "\\textbf{{{:.1f}}}".format(con[suffix] * 100 / total[suffix]) 139 | if con[suffix] != -1 and total[suffix] != -1 140 | else -1, 141 | "", 142 | ) 143 | ) 144 | row.pop(-1) 145 | rows.append(row) 146 | if format == "csv": 147 | writer = csv.writer(sys.stdout) 148 | writer.writerow([""] + headers) 149 | writer.writerows(rows) 150 | else: 151 | print(tabulate(rows, headers=headers, tablefmt=format, floatfmt=".1f")) 152 | 153 | 154 | if __name__ == "__main__": 155 | fire.Fire(main) 156 | -------------------------------------------------------------------------------- /experiments/main/figures/fig_compiler_perf_repair.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | from collections import defaultdict 4 | 5 | import fire 6 | from tabulate import tabulate 7 | 8 | from experiments.main.analyze_inf_res import main as inf_res 9 | 10 | SUFFIXES = ["_repair-all"] 11 | SUBSETS = ["humaneval", "mbpp"] 12 | SUBSET_SIZE = { 13 | "humaneval": 1, 14 | "mbpp": 1, 15 | } 16 | MODEL_NAME_MAP = { 17 | "google/gemma-2-2b-it": "Gemma 2 2B", 18 | "google/gemma-2-9b-it": "Gemma 2 9B", 19 | "google/gemma-2-27b-it": "Gemma 2 27B", 20 | "deepseek-ai/deepseek-coder-33b-instruct": "DeepSeek Coder 33B", 21 | "codellama/CodeLlama-34b-Instruct-hf": "CodeLlama 34B", 22 | "Qwen/Qwen2.5-32B-Instruct": "Qwen2.5 32B", 23 | } 24 | 25 | 26 | def main(format="github", field="compiler_output", directory="results"): 27 | models = ( 28 | "google/gemma-2-2b-it", 29 | "google/gemma-2-9b-it", 30 | "google/gemma-2-27b-it", 31 | "deepseek-ai/deepseek-coder-33b-instruct", 32 | "codellama/CodeLlama-34b-Instruct-hf", 33 | "Qwen/Qwen2.5-32B-Instruct", 34 | ) 35 | suffix = "_repair-all" 36 | temp = 1 37 | condition = "compiler_output" 38 | unconstrained = [] 39 | constrained = [] 40 | ideal_syntax = [] 41 | totals = [] 42 | for model in models: 43 | total = 
defaultdict(float) 44 | du = defaultdict(float) 45 | dc = defaultdict(float) 46 | id = defaultdict(float) 47 | for subset in SUBSETS: 48 | seeds = [0] 49 | for seed in seeds: 50 | res, n, _ = inf_res( 51 | [ 52 | f"{directory}/{subset}_{model.replace('/','_')}_s={seed}_t={temp}{suffix}_nc.jsonl" 53 | ], 54 | field=field, 55 | condition=condition, 56 | show_list=False, 57 | intersection=field == "compiler_output", 58 | ) 59 | total[subset] += n if res != "incomplete" else 0 60 | du[subset] += res * n if res != "incomplete" else 0 61 | res, n, _ = inf_res( 62 | [ 63 | f"{directory}/{subset}_{model.replace('/','_')}_s={seed}_t={temp}{suffix}_nc.jsonl", 64 | f"{directory}/{subset}_{model.replace('/', '_')}_s={seed}_t={temp}{suffix}_c.jsonl", 65 | ], 66 | field=field, 67 | condition=condition, 68 | show_list=False, 69 | intersection=field == "compiler_output", 70 | ) 71 | dc[subset] += res * n if res != "incomplete" else 0 72 | res, n, _ = inf_res( 73 | [ 74 | f"{directory}/{subset}_{model.replace('/', '_')}_s={seed}_t={temp}{suffix}_nc.jsonl", 75 | ], 76 | field="syntax_error", 77 | condition="none", 78 | show_list=False, 79 | intersection=True, 80 | ) 81 | id[subset] += res * n if res != "incomplete" else 0 82 | unconstrained.append(du) 83 | constrained.append(dc) 84 | ideal_syntax.append(id) 85 | totals.append(total) 86 | 87 | headers = ["Model"] + ["Standard", "Types"] * len(SUBSETS) 88 | rows = [] 89 | for model, unc, con, id, total in zip( 90 | models, unconstrained, constrained, ideal_syntax, totals 91 | ): 92 | row = [MODEL_NAME_MAP[model]] 93 | for subset in SUBSETS: 94 | row.extend( 95 | ( 96 | # int(total[subset]), 97 | "${}$&$_{{\\downarrow {:.1f}\\%}}$".format( 98 | int(unc[subset] * SUBSET_SIZE[subset]), 99 | (total[subset] - unc[subset]) * 100 / total[subset] 100 | if unc[subset] != 0 101 | else 0, 102 | ), 103 | "$\\textbf{{{}}}$&$_{{\\downarrow {:.1f}\\%}}$".format( 104 | int(con[subset] * SUBSET_SIZE[subset]), 105 | (total[subset] - con[subset]) * 100 / total[subset] 106 | if total[subset] != 0 107 | else 0, 108 | ), 109 | ) 110 | ) 111 | rows.append(row) 112 | if format == "csv": 113 | writer = csv.writer(sys.stdout) 114 | writer.writerow([""] + headers) 115 | writer.writerows(rows) 116 | else: 117 | print(tabulate(rows, headers=headers, tablefmt=format)) 118 | 119 | 120 | if __name__ == "__main__": 121 | fire.Fire(main) 122 | -------------------------------------------------------------------------------- /experiments/main/figures/fig_compiler_perf_syn_tran.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | from collections import defaultdict 4 | 5 | import fire 6 | from tabulate import tabulate 7 | 8 | from experiments.main.analyze_inf_res import main as inf_res 9 | 10 | SUFFIXES = ["_synth", "_translate"] 11 | SUBSETS = ["humaneval", "mbpp"] 12 | SUBSET_SIZE = { 13 | "humaneval": 1, 14 | "mbpp": 1, 15 | } 16 | MODEL_NAME_MAP = { 17 | "google/gemma-2-2b-it": "Gemma 2 2B", 18 | "google/gemma-2-9b-it": "Gemma 2 9B", 19 | "google/gemma-2-27b-it": "Gemma 2 27B", 20 | "deepseek-ai/deepseek-coder-33b-instruct": "DeepSeek Coder 33B", 21 | "codellama/CodeLlama-34b-Instruct-hf": "CodeLlama 34B", 22 | "Qwen/Qwen2.5-32B-Instruct": "Qwen2.5 32B", 23 | } 24 | 25 | 26 | def main( 27 | format="github", field="compiler_output", subset="humaneval", directory="results" 28 | ): 29 | models = ( 30 | "google/gemma-2-2b-it", 31 | "google/gemma-2-9b-it", 32 | "google/gemma-2-27b-it", 33 | 
"deepseek-ai/deepseek-coder-33b-instruct", 34 | "codellama/CodeLlama-34b-Instruct-hf", 35 | "Qwen/Qwen2.5-32B-Instruct", 36 | ) 37 | temp = 1 38 | condition = "compiler_output" 39 | unconstrained = [] 40 | constrained = [] 41 | ideal_syntax = [] 42 | for model in models: 43 | du = defaultdict(float) 44 | dc = defaultdict(float) 45 | id = defaultdict(float) 46 | for suffix in SUFFIXES: 47 | seeds = [0] if subset == "mbpp" else [0, 1, 2, 3] 48 | for seed in seeds: 49 | res, n, _ = inf_res( 50 | [ 51 | f"{directory}/{subset}_{model.replace('/','_')}_s={seed}_t={temp}{suffix}_nc.jsonl" 52 | ], 53 | field=field, 54 | condition=condition, 55 | show_list=False, 56 | intersection=field == "compiler_output", 57 | ) 58 | du[suffix] += res * n if res != "incomplete" else 0 59 | res, n, _ = inf_res( 60 | [ 61 | f"{directory}/{subset}_{model.replace('/','_')}_s={seed}_t={temp}{suffix}_nc.jsonl", 62 | f"{directory}/{subset}_{model.replace('/', '_')}_s={seed}_t={temp}{suffix}_c.jsonl", 63 | ], 64 | field=field, 65 | condition=condition, 66 | show_list=False, 67 | intersection=field == "compiler_output", 68 | ) 69 | dc[suffix] += res * n if res != "incomplete" else 0 70 | res, n, _ = inf_res( 71 | [ 72 | f"{directory}/{subset}_{model.replace('/', '_')}_s={seed}_t={temp}{suffix}_nc.jsonl", 73 | ], 74 | field="syntax_error", 75 | condition="none", 76 | show_list=False, 77 | intersection=True, 78 | ) 79 | id[suffix] += res * n if res != "incomplete" else 0 80 | unconstrained.append(du) 81 | constrained.append(dc) 82 | ideal_syntax.append(id) 83 | 84 | headers = ["", "Model"] + ["Standard", "Syntax", "Types", ""] * len(SUFFIXES) 85 | rows = [] 86 | for model, unc, con, id in zip(models, unconstrained, constrained, ideal_syntax): 87 | row = ["", MODEL_NAME_MAP[model]] 88 | for suffix in SUFFIXES: 89 | row.extend( 90 | ( 91 | int(unc[suffix] * SUBSET_SIZE[subset]), 92 | "${}$&$_{{\\downarrow {:.1f}\\%}}$".format( 93 | int((unc[suffix] - id[suffix]) * SUBSET_SIZE[subset]), 94 | id[suffix] * 100 / unc[suffix] if unc[suffix] != 0 else 0, 95 | ), 96 | "$\\textbf{{{}}}$&$_{{\\downarrow {:.1f}\\%}}$".format( 97 | int(con[suffix] * SUBSET_SIZE[subset]), 98 | (unc[suffix] - con[suffix]) * 100 / unc[suffix] 99 | if unc[suffix] != 0 100 | else 0, 101 | ), 102 | "", 103 | ) 104 | ) 105 | row.pop(-1) 106 | rows.append(row) 107 | if format == "csv": 108 | writer = csv.writer(sys.stdout) 109 | writer.writerow([""] + headers) 110 | writer.writerows(rows) 111 | else: 112 | print(tabulate(rows, headers=headers, tablefmt=format)) 113 | 114 | 115 | if __name__ == "__main__": 116 | fire.Fire(main) 117 | -------------------------------------------------------------------------------- /experiments/main/figures/fig_compiler_time.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | from collections import defaultdict 4 | from statistics import median 5 | 6 | import fire 7 | from tabulate import tabulate 8 | 9 | from experiments.main.analyze_avg_time import main as inf_res 10 | 11 | SUFFIXES = ["_synth", "_translate"] 12 | SUBSETS = ["humaneval", "mbpp"] 13 | SUBSET_SIZE = { 14 | "humaneval": 1, 15 | "mbpp": 1, 16 | } 17 | MODEL_NAME_MAP = { 18 | "google/gemma-2-2b-it": "Gemma 2 2B", 19 | "google/gemma-2-9b-it": "Gemma 2 9B", 20 | "google/gemma-2-27b-it": "Gemma 2 27B", 21 | "deepseek-ai/deepseek-coder-33b-instruct": "DeepSeek Coder 33B", 22 | "codellama/CodeLlama-34b-Instruct-hf": "CodeLlama 34B", 23 | "Qwen/Qwen2.5-32B-Instruct": "Qwen2.5 32B", 24 | } 25 | 26 | 27 | 
def main(format="github", field="time_taken", suffix="_synth", directory="results"): 28 | models = ( 29 | "google/gemma-2-2b-it", 30 | "google/gemma-2-9b-it", 31 | "google/gemma-2-27b-it", 32 | "deepseek-ai/deepseek-coder-33b-instruct", 33 | "codellama/CodeLlama-34b-Instruct-hf", 34 | "Qwen/Qwen2.5-32B-Instruct", 35 | ) 36 | temp = 1 37 | condition = "none" 38 | unconstrained = [] 39 | constrained = [] 40 | totals = [] 41 | for model in models: 42 | du = defaultdict(list) 43 | dc = defaultdict(list) 44 | total = defaultdict(int) 45 | for subset in SUBSETS: 46 | seeds = [0] if subset != "humaneval" or "repair" in suffix else [0, 1, 2, 3] 47 | for seed in seeds: 48 | res, n, _ = inf_res( 49 | [ 50 | f"{directory}/{subset}_{model.replace('/','_')}_s={seed}_t={temp}{suffix}_nc.jsonl" 51 | ], 52 | field=field, 53 | condition=condition, 54 | ) 55 | if res == "incomplete": 56 | total[subset] = -1 57 | break 58 | total[subset] += n 59 | du[subset] += res if res != "incomplete" else [] 60 | res, n, _ = inf_res( 61 | [ 62 | f"{directory}/{subset}_{model.replace('/','_')}_s={seed}_t={temp}{suffix}_nc.jsonl", 63 | f"{directory}/{subset}_{model.replace('/', '_')}_s={seed}_t={temp}{suffix}_c.jsonl", 64 | ], 65 | field=field, 66 | condition=condition, 67 | ) 68 | if res == "incomplete": 69 | total[subset] = -1 70 | break 71 | dc[subset] += res if res != "incomplete" else [] 72 | unconstrained.append(du) 73 | constrained.append(dc) 74 | totals.append(total) 75 | 76 | # print( 77 | # "Average additional time taken for constrained synthesis compared to unconstrained synthesis" 78 | # ) 79 | # for subset in SUBSETS: 80 | # adds = [] 81 | # for model, unc, con, total in zip(models, unconstrained, constrained, totals): 82 | # if con[subset] and unc[subset]: 83 | # adds.append(median(con[subset]) * 100 / median(unc[subset]) - 100) 84 | # else: 85 | # adds.append(-1) 86 | # print(f"{subset}: {sum(adds)/len(adds):.1f}%") 87 | 88 | headers = ["Model", "HumanEval", "MBPP"] 89 | rows = [] 90 | for model, unc, con, total in zip(models, unconstrained, constrained, totals): 91 | row = [MODEL_NAME_MAP[model]] 92 | for subset in SUBSETS: 93 | row.extend( 94 | ( 95 | # "{:.1f}".format(median(unc[suffix])) if total[suffix] else -1, 96 | "${:.1f}$&$_{{\\uparrow {:.1f}\\%}}$".format( 97 | median(con[subset]), 98 | median(con[subset]) * 100 / median(unc[subset]) - 100, 99 | ) 100 | if total[subset] != -1 101 | else -1, 102 | ) 103 | ) 104 | rows.append(row) 105 | if format == "csv": 106 | writer = csv.writer(sys.stdout) 107 | writer.writerow([""] + headers) 108 | writer.writerows(rows) 109 | else: 110 | print(tabulate(rows, headers=headers, tablefmt=format)) 111 | 112 | 113 | if __name__ == "__main__": 114 | fire.Fire(main) 115 | -------------------------------------------------------------------------------- /experiments/main/figures_revision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eth-sri/type-constrained-code-generation/0cba132d35d9277d0538292d1bdbabb88a5a447a/experiments/main/figures_revision/__init__.py -------------------------------------------------------------------------------- /experiments/main/figures_revision/fig_compiler_perf_syn_tran_repair.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | from collections import defaultdict 4 | 5 | import fire 6 | from tabulate import tabulate 7 | 8 | from experiments.main.analyze_inf_res import main as inf_res 9 | 10 | 
SUFFIXES = ["_synth", "_translate", "_repair-all"] 11 | SUBSETS = ["humaneval", "mbpp"] 12 | SUBSET_SIZE_REPAIR = { 13 | "humaneval": 309, 14 | "mbpp": 317, 15 | } 16 | MODEL_NAME_MAP = { 17 | "google/gemma-2-2b-it": "Gemma 2 2B", 18 | "google/gemma-2-9b-it": "Gemma 2 9B", 19 | "google/gemma-2-27b-it": "Gemma 2 27B", 20 | "deepseek-ai/deepseek-coder-33b-instruct": "DS Coder 33B", 21 | "codellama/CodeLlama-34b-Instruct-hf": "CodeLlama 34B", 22 | "Qwen/Qwen2.5-32B-Instruct": "Qwen2.5 32B", 23 | } 24 | 25 | 26 | def main( 27 | format="github", field="compiler_output", subset="humaneval", directory="results" 28 | ): 29 | models = ( 30 | "google/gemma-2-2b-it", 31 | "google/gemma-2-9b-it", 32 | "google/gemma-2-27b-it", 33 | "deepseek-ai/deepseek-coder-33b-instruct", 34 | "codellama/CodeLlama-34b-Instruct-hf", 35 | "Qwen/Qwen2.5-32B-Instruct", 36 | ) 37 | temp = 1 38 | condition = "compiler_output" 39 | unconstrained = [] 40 | constrained = [] 41 | ideal_syntax = [] 42 | for model in models: 43 | du = defaultdict(float) 44 | dc = defaultdict(float) 45 | id = defaultdict(float) 46 | for suffix in SUFFIXES: 47 | seeds = [0] if subset == "mbpp" or suffix == "_repair-all" else [0, 1, 2, 3] 48 | discard_unconstrained, discard_constrained = False, False 49 | for seed in seeds: 50 | res, n, _ = inf_res( 51 | [ 52 | f"{directory}/{subset}_{model.replace('/','_')}_s={seed}_t={temp}{suffix}_nc.jsonl" 53 | ], 54 | field=field, 55 | condition=condition, 56 | show_list=False, 57 | intersection=field == "compiler_output", 58 | ) 59 | if res == "incomplete": 60 | discard_unconstrained = True 61 | discard_constrained = True 62 | break 63 | du[suffix] += res * n 64 | res, n, _ = inf_res( 65 | [ 66 | f"{directory}/{subset}_{model.replace('/', '_')}_s={seed}_t={temp}{suffix}_nc.jsonl", 67 | ], 68 | field="syntax_error", 69 | condition="none", 70 | show_list=False, 71 | intersection=True, 72 | ) 73 | if res == "incomplete": 74 | discard_unconstrained = True 75 | break 76 | id[suffix] += res * n 77 | res, n, _ = inf_res( 78 | [ 79 | f"{directory}/{subset}_{model.replace('/','_')}_s={seed}_t={temp}{suffix}_nc.jsonl", 80 | f"{directory}/{subset}_{model.replace('/', '_')}_s={seed}_t={temp}{suffix}_c.jsonl", 81 | ], 82 | field=field, 83 | condition=condition, 84 | show_list=False, 85 | intersection=field == "compiler_output", 86 | ) 87 | if res == "incomplete": 88 | discard_constrained = True 89 | break 90 | dc[suffix] += res * n if res != "incomplete" else 0 91 | if discard_unconstrained: 92 | du[suffix] = -1 93 | id[suffix] = -1 94 | if discard_constrained: 95 | dc[suffix] = -1 96 | unconstrained.append(du) 97 | constrained.append(dc) 98 | ideal_syntax.append(id) 99 | 100 | # print("Average improvement compiler errors (over Models)") 101 | # for suffix in SUFFIXES: 102 | # improvements_syntax = [] 103 | # improvements_constrained = [] 104 | # for model, unc, con, id in zip( 105 | # models, unconstrained, constrained, ideal_syntax 106 | # ): 107 | # improvements_syntax.append( 108 | # id[suffix] * 100 / unc[suffix] if unc[suffix] != 0 else 0 109 | # ) 110 | # improvements_constrained.append( 111 | # (unc[suffix] - con[suffix]) * 100 / unc[suffix] 112 | # if unc[suffix] != 0 113 | # else 0 114 | # ) 115 | # print( 116 | # f"{suffix}, Syntax: {sum(improvements_syntax) / len(improvements_syntax):.1f}%" 117 | # ) 118 | # print( 119 | # f"{suffix}, Types: {sum(improvements_constrained) / len(improvements_constrained):.1f}%" 120 | # ) 121 | 122 | # print("Repair improvements") 123 | # suffix = "_repair-all" 124 | # 
subset_size = SUBSET_SIZE_REPAIR[subset] 125 | # for model, unc, con, id in zip(models, unconstrained, constrained, ideal_syntax): 126 | # print( 127 | # f"{model}, Standard: {(subset_size - unc[suffix]) * 100 / subset_size:.1f}%, Syntax: {(subset_size - (unc[suffix] - id[suffix])) * 100 / subset_size:.1f}%, Types: {(subset_size - con[suffix]) * 100 / subset_size:.1f}%" 128 | # ) 129 | 130 | headers = ["", "Model"] + ["Standard", "Syntax", r"Types", ""] * len(SUFFIXES) 131 | rows = [] 132 | for model, unc, con, id in zip(models, unconstrained, constrained, ideal_syntax): 133 | row = ["", MODEL_NAME_MAP[model]] 134 | for suffix in SUFFIXES: 135 | row.extend( 136 | ( 137 | int(unc[suffix]) if unc[suffix] != 0 else -1, 138 | "${}$&$_{{\\downarrow {:.1f}\\%}}$".format( 139 | int((unc[suffix] - id[suffix])), 140 | id[suffix] * 100 / unc[suffix] if unc[suffix] != 0 else 0, 141 | ) 142 | if unc[suffix] != -1 and id[suffix] != -1 143 | else -1, 144 | "$\\textbf{{{}}}$&$_{{\\downarrow {:.1f}\\%}}$".format( 145 | int(con[suffix]), 146 | (unc[suffix] - con[suffix]) * 100 / unc[suffix] 147 | if unc[suffix] != 0 148 | else 0, 149 | ) 150 | if unc[suffix] != -1 and con[suffix] != -1 151 | else -1, 152 | "", 153 | ) 154 | ) 155 | row.pop(-1) 156 | rows.append(row) 157 | if format == "csv": 158 | writer = csv.writer(sys.stdout) 159 | writer.writerow([""] + headers) 160 | writer.writerows(rows) 161 | else: 162 | print(tabulate(rows, headers=headers, tablefmt=format)) 163 | 164 | 165 | if __name__ == "__main__": 166 | fire.Fire(main) 167 | -------------------------------------------------------------------------------- /experiments/main/figures_revision/fig_resample_hist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import json 4 | from math import ceil, log2 5 | from statistics import median 6 | from typing import Literal 7 | 8 | import numpy as np 9 | from transformers import AutoTokenizer 10 | 11 | from experiments.main.util import cutoff, invalid_mbpp_instances, extract_code 12 | 13 | 14 | @functools.lru_cache(maxsize=None) 15 | def load_output_file(outputs_file): 16 | res_dict = {} 17 | try: 18 | with open(outputs_file, "r") as f: 19 | outputs = [] 20 | for i, line in enumerate(f): 21 | # print(i) 22 | outputs.append(json.loads(line)) 23 | except Exception: 24 | outputs = [] 25 | for output in outputs: 26 | if output["instance_id"] in invalid_mbpp_instances: 27 | continue 28 | res_dict[output["instance_id"]] = output 29 | return res_dict 30 | 31 | 32 | def main( 33 | outputs_files, 34 | mode: Literal["resample", "correction"] = "resample", 35 | style: Literal["ascii", "latex", "plain", "matplotlib"] = "latex", 36 | ): 37 | if len(outputs_files) != 4: 38 | print("Not all results have completed generation, some seeds appear missing.") 39 | exit() 40 | outputs_by_instance = {} 41 | instances = set() 42 | for outputs_file in outputs_files: 43 | outputs = load_output_file(outputs_file) 44 | if len(outputs) != 159: 45 | print( 46 | f"File {outputs_file} has {len(outputs)} outputs, not 159, likely did not finish generation yet!." 
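# the hard-coded 159 appears to correspond to the number of problems in MultiPL-E's humaneval-ts split; adjust this check when analyzing a different subset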
47 | ) 48 | exit() 49 | outputs_by_instance[outputs_file] = outputs 50 | instances.update(outputs.keys()) 51 | tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it") 52 | counted_instances = 0 53 | instances_needing_resample = 0 54 | amounts_of_resamples = [] 55 | for file_name in outputs_files: 56 | if file_name.endswith("_nc.jsonl"): 57 | continue 58 | for instance_id in sorted(instances): 59 | outputs_by_instance_f = outputs_by_instance.get(file_name, {}) 60 | instance = outputs_by_instance_f.get(instance_id, {}) 61 | if not instance: 62 | continue 63 | resamples = instance.get("resamples", None) 64 | cutoffed_code = cutoff(extract_code(instance["code"], "TypeScript", 0)) 65 | pos_of_cutoff = instance["code"].find(cutoffed_code) 66 | len_of_cutoffed = pos_of_cutoff + len(cutoffed_code) 67 | num_tokens_in_cutoffed = len( 68 | tokenizer.encode(cutoffed_code, add_special_tokens=False) 69 | ) 70 | if resamples is None: 71 | amounts_of_resamples.extend([0] * num_tokens_in_cutoffed) 72 | continue 73 | counted_instances += 1 74 | # check if unconstrained version resolved correctly -> doesnt need resample 75 | unconstrained_file_version = file_name.replace("_c.jsonl", "_nc.jsonl") 76 | if ( 77 | not load_output_file(unconstrained_file_version) 78 | .get(instance_id, {}) 79 | .get("compiler_output", True) 80 | ): 81 | amounts_of_resamples.extend([0] * num_tokens_in_cutoffed) 82 | continue 83 | # otherwise determine if resample occurred in relevant part of code 84 | resamples = [x for x in resamples if x[0] <= len_of_cutoffed] 85 | amounts_of_resamples.extend( 86 | [x[1] for x in resamples] 87 | + [0] * (num_tokens_in_cutoffed - len(resamples)) 88 | ) 89 | if resamples: 90 | instances_needing_resample += len(resamples) > 0 91 | 92 | print( 93 | f"Instances needing {mode}: {instances_needing_resample/counted_instances*100:.2f}% ({instances_needing_resample}/{counted_instances})" 94 | ) 95 | print( 96 | f"Tokens needing {mode}: {len(tuple(x for x in amounts_of_resamples if x > 0))/len(amounts_of_resamples)*100:.2f}% ({len(tuple(x for x in amounts_of_resamples if x > 0))}/{len(amounts_of_resamples)})" 97 | ) 98 | print( 99 | f"Average amount of {mode}s: {sum(amounts_of_resamples)/len(amounts_of_resamples):.2f}" 100 | ) 101 | print(f"Median amount of {mode}s: {median(amounts_of_resamples):.2f}") 102 | print(f"Max amount of {mode}s: {max(amounts_of_resamples)}") 103 | print(f"Histogram of {mode} amounts:") 104 | 105 | if style == "ascii": 106 | asciihist( 107 | amounts_of_resamples, 108 | bins=10, 109 | minmax="auto", 110 | str_tag="Resample", 111 | scale_output=30, 112 | ) 113 | elif style == "latex": 114 | latex_hist(amounts_of_resamples) 115 | elif style == "plain": 116 | notsosimplehist(amounts_of_resamples) 117 | elif style == "matplotlib": 118 | matplotlib_hist(amounts_of_resamples) 119 | 120 | 121 | def matplotlib_hist(amounts_of_resamples): 122 | import matplotlib.pyplot as plt 123 | 124 | # show as histogram with logarithmic scale 125 | 126 | plt.hist(amounts_of_resamples, bins=30) 127 | plt.yscale("log") 128 | plt.show() 129 | 130 | 131 | def latex_hist(amounts_of_resamples): 132 | hist_template = r""" 133 | \begin{figure} 134 | \centering 135 | \resizebox{\textwidth}{!}{ 136 | \begin{tikzpicture} 137 | \begin{axis}[ 138 | ybar, 139 | width=7cm, 140 | height=6cm, 141 | ylabel={Count}, 142 | xlabel={Amount of Resamples}, 143 | ymin=0, 144 | ymax=1300, 145 | bar width=0.1cm, 146 | x=0.1cm, 147 | ] 148 | \addplot coordinates {%s}; 149 | \end{axis} 150 | \end{tikzpicture} 151 | } 152 
| \caption{Histogram of Resample Amounts} 153 | \label{fig:resample_histogram} 154 | \end{figure} 155 | """ 156 | hist = {} 157 | bucket_size = 1 158 | for amount in amounts_of_resamples: 159 | hist[amount // bucket_size] = hist.get(amount // bucket_size, 0) + 1 160 | min_key = min(hist.keys()) 161 | max_key = max(hist.keys()) 162 | for i in range(min_key, max_key): 163 | hist[i] = hist.get(i, 0) 164 | # x_coords = ", ".join(str(k) for k in sorted(hist)) 165 | y_coords = " ".join( 166 | str((k + 1, v if v > 1 else 1.1)) for k, v in sorted(hist.items()) if v > 0 167 | ) 168 | print(hist_template % (y_coords)) 169 | 170 | 171 | def simplehist(amounts_of_resamples): 172 | hist = {} 173 | bucket_size = 1 174 | for amount in amounts_of_resamples: 175 | hist[amount // bucket_size] = hist.get(amount // bucket_size, 0) + 1 176 | for k, v in sorted(hist.items()): 177 | print(f"| {k} | {v} | `{'-' * v}` |") 178 | 179 | 180 | def notsosimplehist( 181 | amounts_of_resamples, 182 | scale=(lambda x: log2(x)), 183 | ticks=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), 184 | max_key=30, 185 | ): 186 | hist = {} 187 | bucket_size = 1 188 | for amount in amounts_of_resamples: 189 | hist[amount // bucket_size] = hist.get(amount // bucket_size, 0) + 1 190 | max_height = int(ceil(scale(max(hist.values())))) 191 | min_key = min(hist.keys()) 192 | actual_max_key = max(hist.keys()) 193 | for i in range(min_key, max_key): 194 | hist[i] = hist.get(i, 0.0001) 195 | for i in range(max_key, actual_max_key + 1): 196 | if i in hist: 197 | del hist[i] 198 | max_x_width = max(max(len(str(k)) for k in hist), max(len(str(t)) for t in ticks)) 199 | lines = [] 200 | # draw lines 201 | for i in range(max_height, 0, -1): 202 | line = "" 203 | if i in ticks: 204 | line += "2^{:<{}d}-".format(i, max_x_width) 205 | else: 206 | line += " " * max_x_width + " " 207 | line += "| " 208 | plot = [] 209 | for k, v in sorted(hist.items()): 210 | if int(scale(v)) >= i: 211 | plot.append("❚") 212 | elif int(ceil(scale(v))) >= i or (v == 1 and i == 1): 213 | plot.append(".") 214 | else: 215 | plot.append(" ") 216 | lines.append(line + ((" ") * max_x_width).join(plot)) 217 | # draw x-axis 218 | lines.append("-" * (len(hist) * (max_x_width + 1)) + "------") 219 | # draw x-ticks 220 | lines.append( 221 | (" ") * max_x_width 222 | + " " 223 | + " ".join("{:{}d}".format(k, max_x_width) for k in sorted(hist)) 224 | ) 225 | 226 | for line in lines: 227 | print(line) 228 | 229 | 230 | def gnuplothist(amounts_of_resamples, mode): 231 | print(amounts_of_resamples) 232 | import gnuplotlib as gp 233 | import numpy as np 234 | 235 | gp.plot( 236 | ( 237 | np.array(amounts_of_resamples), 238 | { 239 | "histogram": True, 240 | "binwidth": 1, 241 | }, 242 | ), 243 | _with="boxes", 244 | unset=["grid"], 245 | terminal="dumb 180,20", 246 | set=["boxwidth 0.25", "style fill solid"], 247 | _xmin=0, 248 | _xmax=60, 249 | title=f"Histogram of {mode} amounts", 250 | xlabel="Amount", 251 | ylabel="Count", 252 | ) 253 | 254 | 255 | def asciihist( 256 | it, 257 | bins=10, 258 | minmax=None, 259 | str_tag="", 260 | scale_output=30, 261 | generate_only=False, 262 | print_function=print, 263 | ): 264 | """Create an ASCII histogram from an interable of numbers. 265 | Author: Boris Gorelik boris@gorelik.net. 
based on http://econpy.googlecode.com/svn/trunk/pytrix/pytrix.py 266 | License: MIT 267 | """ 268 | ret = [] 269 | itarray = np.asanyarray(it) 270 | if minmax == "auto": 271 | minmax = np.percentile(it, [5, 95]) 272 | if minmax[0] == minmax[1]: 273 | # for very ugly distributions 274 | minmax = None 275 | if minmax is not None: 276 | # discard values that are outside minmax range 277 | mn = minmax[0] 278 | mx = minmax[1] 279 | itarray = itarray[itarray >= mn] 280 | itarray = itarray[itarray <= mx] 281 | if itarray.size: 282 | total = len(itarray) 283 | counts, cutoffs = np.histogram(itarray, bins=bins) 284 | cutoffs = cutoffs[1:] 285 | if str_tag: 286 | str_tag = "%s " % str_tag 287 | else: 288 | str_tag = "" 289 | if scale_output is not None: 290 | scaled_counts = counts.astype(float) / counts.sum() * scale_output 291 | else: 292 | scaled_counts = counts 293 | 294 | if minmax is not None: 295 | ret.append("Trimmed to range (%s - %s)" % (str(minmax[0]), str(minmax[1]))) 296 | for cutoff, original_count, scaled_count in zip(cutoffs, counts, scaled_counts): 297 | ret.append( 298 | "{:s}{:>8.2f} |{:<7,d} | {:s}".format( 299 | str_tag, cutoff, original_count, "*" * int(scaled_count) 300 | ) 301 | ) 302 | ret.append("{:s}{:s} |{:s} | {:s}".format(str_tag, "-" * 8, "-" * 7, "-" * 7)) 303 | ret.append("{:s}{:>8s} |{:<7,d}".format(str_tag, "N=", total)) 304 | else: 305 | ret = [] 306 | if not generate_only: 307 | for line in ret: 308 | print_function(line) 309 | ret = "\n".join(ret) 310 | return ret 311 | 312 | 313 | if __name__ == "__main__": 314 | parser = argparse.ArgumentParser() 315 | parser.add_argument("files", nargs="*") 316 | parser.add_argument( 317 | "--mode", choices=["resample", "correction"], default="resample" 318 | ) 319 | parser.add_argument( 320 | "--style", 321 | choices=["ascii", "latex", "plain", "matplotlib"], 322 | default="matplotlib", 323 | ) 324 | args = parser.parse_args() 325 | main( 326 | args.files, 327 | args.mode, 328 | args.style, 329 | ) 330 | -------------------------------------------------------------------------------- /experiments/main/figures_revision/fig_resample_hist.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # DIR is either the parameter passed to this script or the result directory 3 | DIR="${1:-results}" 4 | python3 figures_revision/fig_resample_hist.py "${DIR}"/humaneval_google_gemma-2-2b-it_s=*_t=1_synth_c.jsonl --style latex 5 | -------------------------------------------------------------------------------- /experiments/main/filter_sensible_ts_outputs.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | import sys 4 | from tempfile import NamedTemporaryFile 5 | 6 | from tqdm import tqdm 7 | 8 | target_file = sys.argv[1] 9 | 10 | with open(target_file, "r") as f: 11 | lines = f.readlines() 12 | lines = [json.loads(line) for line in lines] 13 | 14 | stats = { 15 | "total": 0, 16 | "found_code": 0, 17 | "compiled": 0, 18 | } 19 | for line in tqdm(lines): 20 | stats["total"] += 1 21 | # extract code from code block 22 | code = line["translation"] 23 | code = code.split("```typescript")[1].strip() 24 | code = code.split("```")[0].strip() 25 | if not code: 26 | continue 27 | stats["found_code"] += 1 28 | # try to compile the code 29 | with NamedTemporaryFile("w", suffix=".ts") as f: 30 | f.write(code) 31 | try: 32 | subprocess.run( 33 | ["npx", "tsc", "--noEmit", f.name], 34 | text=True, 35 | check=True, 36 | 
capture_output=True, 37 | ) 38 | except subprocess.CalledProcessError: 39 | continue 40 | stats["compiled"] += 1 41 | print(json.dumps({"task_id": line["task_id"], "translation": code})) 42 | -------------------------------------------------------------------------------- /experiments/main/fix_nc.py: -------------------------------------------------------------------------------- 1 | import json 2 | import shutil 3 | import sys 4 | from concurrent.futures import ThreadPoolExecutor 5 | 6 | from datasets import load_dataset 7 | 8 | from experiments.main.util import ( 9 | extract_code, 10 | cutoff, 11 | go_compiles_passes, 12 | ) 13 | 14 | 15 | def process_file(file_path, tests_by_inst): 16 | try: 17 | if str(file_path).endswith(".bak"): 18 | return 19 | backup_path = file_path + ".bak" 20 | shutil.copy(file_path, backup_path) 21 | 22 | with open(file_path, "r") as f: 23 | lines = f.readlines() 24 | 25 | results = [] 26 | for line in lines: 27 | data = json.loads(line) 28 | if not data["compiled"]: 29 | extracted = cutoff(extract_code(data["code"], "Go", 0)) 30 | tests = tests_by_inst[data["instance_id"]] 31 | test_result = go_compiles_passes(extracted, tests, timeout=300) 32 | data["syntax_ok"] = test_result.syntax_ok 33 | data["compiled"] = test_result.compiled 34 | data["compiler_output"] = ( 35 | test_result.error_message if not test_result.compiled else None 36 | ) 37 | data["tests_passed"] = test_result.passed 38 | data["test_output"] = ( 39 | test_result.error_message if not test_result.passed else None 40 | ) 41 | results.append(json.dumps(data)) 42 | 43 | with open(file_path, "w") as f: 44 | f.write("\n".join(results)) 45 | except Exception as e: 46 | print(e) 47 | 48 | 49 | def main(): 50 | files = sys.argv[1:] 51 | dataset_name = "THUDM/humaneval-x" 52 | dataset = load_dataset(dataset_name, "go")["test"] 53 | tests_by_inst = {x["task_id"]: x["test"] for x in dataset} 54 | 55 | with ThreadPoolExecutor() as executor: 56 | executor.map(process_file, files, [tests_by_inst] * len(files)) 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /experiments/main/inference_multiple.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample from the model constrained or unconstrained 3 | 4 | Can also simulate a repair setting 5 | """ 6 | 7 | import json 8 | import multiprocessing 9 | import os 10 | import time 11 | import traceback 12 | 13 | import fire 14 | import torch 15 | from tqdm import tqdm 16 | from transformers import AutoModelForCausalLM, AutoTokenizer 17 | 18 | from experiments.main.util import ( 19 | extract_code, 20 | cutoff, 21 | tsx_compiles, 22 | passes_tests_js, 23 | ) 24 | from typesafe_llm.parser.parser_ts import parse_ts_program 25 | from typesafe_llm.sampling import sample_constrained 26 | from datasets import load_dataset 27 | import re 28 | 29 | LANGUAGE_SUBSET_MAP = { 30 | ("typescript", "humaneval"): "humaneval-ts", 31 | ("rust", "humaneval"): "humaneval-rs", 32 | ("typescript", "mbpp"): "mbpp-ts", 33 | ("rust", "mbpp"): "mbpp-rs", 34 | } 35 | TRANSLATION_SUBSET_MAP = { 36 | "python": "python", 37 | "typescript": "ts", 38 | "rust": "rs", 39 | "c++": "cpp", 40 | "cpp": "cpp", 41 | } 42 | 43 | LANGUAGE_PARSER_MAP = { 44 | "typescript": parse_ts_program, 45 | "rust": None, 46 | } 47 | LANGUAGE_COMPILER_MAP = { 48 | "typescript": tsx_compiles, 49 | } 50 | LANGUAGE_TEST_MAP = { 51 | "typescript": passes_tests_js, 52 | } 53 | LANGUAGE_PREFIX_MAP 
= { 54 | "typescript": "", 55 | "rust": "", 56 | } 57 | 58 | 59 | def TRANSLATION_SYSTEM_PROMPT( 60 | human_readable_source_lang: str, human_readable_target_lang: str 61 | ): 62 | return f""" 63 | You are a helpful and expert programmer in {human_readable_source_lang} and {human_readable_target_lang}. You will be given an input program in {human_readable_source_lang} and your task is to translate this program into {human_readable_target_lang}. You may assume that the input program is correct and that the translation should be semantically equivalent. Do not translate word by word and be careful about difference of language features between {human_readable_source_lang} and {human_readable_target_lang}. 64 | When answering, insert the solution code in a ```{human_readable_target_lang.lower()}...``` block. 65 | """ 66 | 67 | 68 | def TRANSLATION_PROMPT( 69 | human_readable_source_lang, src_prog, human_readable_target_lang 70 | ): 71 | return f"The following is the source program in {human_readable_source_lang}:\n```{human_readable_source_lang.lower()}\n{src_prog}\n```\n\nPlease translate the source program to {human_readable_target_lang}." 72 | 73 | 74 | def SYNTHESIS_SYSTEM_PROMPT(human_readable_target_lang: str): 75 | return f""" 76 | You are an expert in {human_readable_target_lang} programming. Solve the given problem by writing solution code in {human_readable_target_lang}. 77 | When answering, insert the solution code in a ```{human_readable_target_lang.lower()}...``` block. 78 | """ 79 | 80 | 81 | def format_prompt_to_question(prompt: str): 82 | user_input = [] 83 | split = prompt.splitlines() 84 | for i, line in enumerate(split): 85 | if line.startswith("//"): 86 | user_input.append(line[len("//") :].strip()) 87 | else: 88 | break 89 | first_code_line = "\n".join(split[i:]) 90 | return "\n".join(user_input), first_code_line 91 | 92 | 93 | def main( 94 | model_name="microsoft/Phi-3.5-mini-instruct", 95 | device="cuda", 96 | language="TypeScript", 97 | subset="humaneval", 98 | split="test", 99 | temp=0, 100 | seed=0, 101 | max_tokens=1000, 102 | timeout=300, 103 | output_file="multiple_outputs.jsonl", 104 | trace=False, 105 | constrained=False, 106 | limit=1000, 107 | input_file=None, 108 | repair=False, 109 | task_id=None, 110 | translate=False, 111 | translation_source_lang=None, 112 | reraise=False, 113 | ): 114 | try_top_k = 10000000000000 115 | if isinstance(task_id, int): 116 | task_id = str(task_id) 117 | if isinstance(task_id, str): 118 | # task ids always converted to a tuple of taskids 119 | task_id = (task_id,) 120 | dataset_name = "nuprl/MultiPL-E" 121 | human_readable_target_lang = language 122 | language = language.lower() 123 | dataset = load_dataset(dataset_name, LANGUAGE_SUBSET_MAP[language, subset])[split] 124 | assert not (repair and translate), "Must either choose translate or repair" 125 | 126 | # load code to repair in repair setting 127 | last_iteration = dict() 128 | if repair: 129 | assert os.path.exists(input_file), "Must provide an input file for repair" 130 | if os.path.exists(input_file): 131 | with open(input_file, "r") as f: 132 | for line in f: 133 | output = json.loads(line) 134 | last_iteration[output["instance_id"]] = output 135 | # load original code for translation in translate setting 136 | if translate: 137 | assert os.path.exists( 138 | input_file 139 | ), "Must provide an input file for translation (contains source language)" 140 | assert ( 141 | translation_source_lang is not None 142 | ), "Must provide a source language for translation" 143 | 
with open(input_file) as f: 144 | raw_translation_dataset = json.load(f) 145 | translation_src_dataset = raw_translation_dataset[ 146 | TRANSLATION_SUBSET_MAP[translation_source_lang.lower()] 147 | ] 148 | 149 | # load already inferred stuff 150 | already_done = set() 151 | if os.path.exists(output_file) and output_file not in ("/dev/stdout", "-"): 152 | with open(output_file, "r") as f: 153 | for i, line in enumerate(f): 154 | output = json.loads(line) 155 | already_done.add(output["instance_id"]) 156 | 157 | tokenizer = None 158 | model = None 159 | system_messages = [ 160 | { 161 | "role": "system", 162 | "content": ( 163 | SYNTHESIS_SYSTEM_PROMPT( 164 | human_readable_target_lang=human_readable_target_lang, 165 | ) 166 | if not translate 167 | else TRANSLATION_SYSTEM_PROMPT( 168 | human_readable_target_lang=human_readable_target_lang, 169 | human_readable_source_lang=translation_source_lang, 170 | ) 171 | ) 172 | + ( 173 | "\nDo not include test cases in the code." 174 | if "Qwen" in model_name 175 | else "" 176 | ), 177 | }, 178 | ] 179 | subset_prefix = "HumanEval" if subset == "humaneval" else "mbpp" 180 | with multiprocessing.Pool(1) as pool: 181 | # run through all instances 182 | for instance in tqdm(sorted(dataset, key=lambda x: x["name"])[:limit]): 183 | if instance["name"] in already_done and task_id is None: 184 | continue 185 | if task_id is not None and not any( 186 | f"{subset_prefix}_{tid}_" in instance["name"] for tid in task_id 187 | ): 188 | continue 189 | if tokenizer is None or model is None: 190 | tokenizer = AutoTokenizer.from_pretrained(model_name) 191 | kwargs = ( 192 | { 193 | "device_map": "auto", 194 | "torch_dtype": torch.bfloat16, 195 | "attn_implementation": "flash_attention_2", 196 | } 197 | if device == "cuda" 198 | else {"device_map": device} 199 | ) 200 | 201 | model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) 202 | model.eval() 203 | user, first_code_line = format_prompt_to_question(instance["prompt"]) 204 | if translate: 205 | instance_num = re.findall( 206 | rf"{subset_prefix}_(\d+)_*", instance["name"] 207 | )[0] 208 | user = TRANSLATION_PROMPT( 209 | human_readable_source_lang=translation_source_lang, 210 | src_prog=translation_src_dataset[instance_num]["prompt"], 211 | human_readable_target_lang=human_readable_target_lang, 212 | ) 213 | messages = system_messages + [ 214 | {"role": "user", "content": user}, 215 | ] 216 | if "octocoder" in model_name: 217 | chat_template = """\ 218 | {%- for message in messages %} 219 | {%- if message['role'] == 'user' %} 220 | {{- 'Question: ' + message['content'].strip() + '\\n\\n' }} 221 | {%- elif message['role'] == 'system' %} 222 | {{- 'System: ' + message['content'].strip() + '\\n\\n' }} 223 | {%- elif message['role'] == 'assistant' %} 224 | {{- 'Answer: ' + message['content'] + '\\n\\n' }} 225 | {%- endif %} 226 | {%- endfor %}""" 227 | tokenizer.chat_template = chat_template 228 | try: 229 | tokenizer.apply_chat_template( 230 | messages, tokenize=False, add_generation_prompt=True 231 | ) 232 | except Exception: 233 | messages[1]["content"] = ( 234 | messages[0]["content"] + "\n\n" + messages[1]["content"] 235 | ) 236 | messages.pop(0) 237 | if repair: 238 | last_iteration_res = last_iteration[instance["name"]] 239 | if not last_iteration_res["compiler_output"]: 240 | continue 241 | formatted_last_iteration_res = "\n".join( 242 | f"{i+1:03d}: {s}" 243 | for i, s in enumerate( 244 | cutoff(last_iteration_res["compilable"]).split("\n") 245 | ) 246 | ) 247 | messages.extend( 248 | [ 249 | { 
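# repair setting: the previous, non-compiling completion is replayed as an assistant turn and the compiler error is quoted in a follow-up user turn, prompting the model to fix its own output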
250 | "role": "assistant", 251 | "content": f"```\n{formatted_last_iteration_res}\n```", 252 | }, 253 | { 254 | "role": "user", 255 | "content": f"This output produced an error:\n{last_iteration_res['compiler_output']}\n\nWrite the program again, and make sure to fix the error this time.", 256 | }, 257 | ] 258 | ) 259 | prompt = tokenizer.apply_chat_template( 260 | messages, tokenize=False, add_generation_prompt=True 261 | ) 262 | suffix = f"```{human_readable_target_lang.lower()}\n" 263 | prompt += suffix + first_code_line 264 | start = time.time() 265 | with torch.no_grad(): 266 | code, eos, crashed, resamples = sample_constrained( 267 | device=device, 268 | model_name=model_name, 269 | prompt=prompt, 270 | max_tokens=max_tokens, 271 | temperature=temp, 272 | do_sample=temp != 0, 273 | seed=seed, 274 | constrain_from=(suffix if constrained else None), 275 | constrain_until="```", 276 | trace=trace, 277 | model=model, 278 | tokenizer=tokenizer, 279 | timeout=timeout, 280 | try_top_k=try_top_k, 281 | reraise=reraise, 282 | ) 283 | end = time.time() 284 | time_taken = end - start 285 | extracted = extract_code(code, human_readable_target_lang, 0) 286 | extracted = cutoff(extracted) 287 | tests: str = instance["tests"] 288 | if tests.strip().startswith("}") and extracted.strip().endswith("}"): 289 | tests = tests[tests.find("}") + 1 :] 290 | compilable = extracted + "\n\n" + tests 291 | pool.apply_async( 292 | compile_test_and_dump, 293 | ( 294 | output_file, 295 | { 296 | "dataset": dataset_name, 297 | "language": language, 298 | "split": split, 299 | "instance_id": instance["name"], 300 | "prompt": instance["prompt"], 301 | "constrained": constrained, 302 | "eos": eos, 303 | "crashed": str(crashed), 304 | "model_name": model_name, 305 | "temp": temp, 306 | "max_tokens": max_tokens, 307 | "time_taken": time_taken, 308 | "code": code, 309 | "compilable": compilable, 310 | "trace": trace, 311 | "resamples": resamples, 312 | "timeout": timeout, 313 | }, 314 | ), 315 | ) 316 | pool.close() 317 | pool.join() 318 | 319 | 320 | def compile_test_and_dump(output_file: str, specs: dict): 321 | try: 322 | compiled, compiler_output = LANGUAGE_COMPILER_MAP[specs["language"]]( 323 | specs["compilable"], specs["timeout"] 324 | ) 325 | if compiled is not None: 326 | tests_passed, test_output = LANGUAGE_TEST_MAP[specs["language"]]( 327 | compiled, specs["timeout"] 328 | ) 329 | else: 330 | tests_passed, test_output = None, None 331 | specs["compiled"] = compiled 332 | specs["compiler_output"] = compiler_output 333 | specs["tests_passed"] = tests_passed 334 | specs["test_output"] = test_output 335 | with open(output_file, "a") as f: 336 | print( 337 | json.dumps( 338 | specs, 339 | ), 340 | flush=True, 341 | file=f, 342 | ) 343 | if specs["trace"]: 344 | print("compiler_output:", compiler_output) 345 | print("tests_passed:", tests_passed) 346 | except Exception: 347 | print("WARNING CATASROPHIC FAILURE") 348 | print("RESULTS ARE NOT WRITTEN TO FILE") 349 | traceback.print_exc() 350 | print("", flush=True) 351 | 352 | 353 | if __name__ == "__main__": 354 | fire.Fire(main) 355 | -------------------------------------------------------------------------------- /experiments/main/inference_multiple_repair.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample from the model constrained or unconstrained 3 | 4 | Can also simulate a repair setting 5 | """ 6 | 7 | import json 8 | import multiprocessing 9 | import os 10 | import time 11 | import traceback 12 | 13 | 
import fire 14 | import torch 15 | from tqdm import tqdm 16 | from transformers import AutoModelForCausalLM, AutoTokenizer 17 | 18 | from experiments.main.util import ( 19 | extract_code, 20 | cutoff, 21 | tsx_compiles, 22 | passes_tests_js, 23 | ) 24 | from typesafe_llm.parser.parser_ts import parse_ts_program 25 | from typesafe_llm.sampling import sample_constrained 26 | from datasets import load_dataset 27 | 28 | LANGUAGE_SUBSET_MAP = { 29 | ("typescript", "humaneval"): "humaneval-ts", 30 | ("rust", "humaneval"): "humaneval-rs", 31 | ("typescript", "mbpp"): "mbpp-ts", 32 | ("rust", "mbpp"): "mbpp-rs", 33 | } 34 | TRANSLATION_SUBSET_MAP = { 35 | "python": "python", 36 | "typescript": "ts", 37 | "rust": "rs", 38 | "c++": "cpp", 39 | "cpp": "cpp", 40 | } 41 | 42 | LANGUAGE_PARSER_MAP = { 43 | "typescript": parse_ts_program, 44 | "rust": None, 45 | } 46 | LANGUAGE_COMPILER_MAP = { 47 | "typescript": tsx_compiles, 48 | } 49 | LANGUAGE_TEST_MAP = { 50 | "typescript": passes_tests_js, 51 | } 52 | LANGUAGE_PREFIX_MAP = { 53 | "typescript": "", 54 | "rust": "", 55 | } 56 | 57 | 58 | def TRANSLATION_SYSTEM_PROMPT( 59 | human_readable_source_lang: str, human_readable_target_lang: str 60 | ): 61 | return f""" 62 | You are a helpful and expert programmer in {human_readable_source_lang} and {human_readable_target_lang}. You will be given an input program in {human_readable_source_lang} and your task is to translate this program into {human_readable_target_lang}. You may assume that the input program is correct and that the translation should be semantically equivalent. 63 | When answering, insert the solution code in a ```{human_readable_target_lang.lower()}...``` block. 64 | """ 65 | 66 | 67 | def TRANSLATION_PROMPT( 68 | human_readable_source_lang, src_prog, human_readable_target_lang 69 | ): 70 | return f"The following is the source program in {human_readable_source_lang}:\n```{human_readable_source_lang.lower()}\n{src_prog}\n```\n\nPlease translate the source program to {human_readable_target_lang}." 71 | 72 | 73 | def SYNTHESIS_SYSTEM_PROMPT(human_readable_target_lang: str): 74 | return f""" 75 | You are an expert in {human_readable_target_lang} programming. Solve the given problem by writing solution code in {human_readable_target_lang}. 76 | When answering, insert the solution code in a ```{human_readable_target_lang.lower()}...``` block. 
77 | """ 78 | 79 | 80 | def format_prompt_to_question(prompt: str): 81 | user_input = [] 82 | split = prompt.splitlines() 83 | for i, line in enumerate(split): 84 | if line.startswith("//"): 85 | user_input.append(line[len("//") :].strip()) 86 | else: 87 | break 88 | first_code_line = "\n".join(split[i:]) 89 | return "\n".join(user_input), first_code_line 90 | 91 | 92 | def main( 93 | model_name="microsoft/Phi-3.5-mini-instruct", 94 | device="cuda", 95 | language="TypeScript", 96 | subset="humaneval", 97 | split="test", 98 | temp=0, 99 | seed=0, 100 | max_tokens=1000, 101 | timeout=300, 102 | output_file="multiple_outputs.jsonl", 103 | trace=False, 104 | constrained=False, 105 | limit=1000, 106 | input_file=None, 107 | task_id=None, 108 | reraise=False, 109 | ): 110 | try_top_k = 10000000000000 111 | if isinstance(task_id, str): 112 | task_id = (task_id,) 113 | dataset_name = "nuprl/MultiPL-E" 114 | human_readable_target_lang = language 115 | language = language.lower() 116 | orig_dataset = load_dataset(dataset_name, LANGUAGE_SUBSET_MAP[language, subset])[ 117 | split 118 | ] 119 | dataset_by_instance_id = {instance["name"]: instance for instance in orig_dataset} 120 | 121 | # load code to repair in repair setting 122 | assert os.path.exists(input_file), "Must provide an input file for repair" 123 | dataset = [] 124 | if os.path.exists(input_file): 125 | with open(input_file, "r") as f: 126 | for line in f: 127 | output = json.loads(line) 128 | dataset.append(output) 129 | 130 | # load already inferred stuff 131 | already_done = set() 132 | if os.path.exists(output_file) and output_file not in ("/dev/stdout", "-"): 133 | with open(output_file, "r") as f: 134 | for i, line in enumerate(f): 135 | output = json.loads(line) 136 | already_done.add(output["instance_id"]) 137 | 138 | tokenizer = None 139 | model = None 140 | system_messages = [ 141 | { 142 | "role": "system", 143 | "content": ( 144 | SYNTHESIS_SYSTEM_PROMPT( 145 | human_readable_target_lang=human_readable_target_lang, 146 | ) 147 | ), 148 | }, 149 | ] 150 | subset_prefix = "HumanEval" if subset == "humaneval" else "mbpp" 151 | # run through all instances 152 | with multiprocessing.Pool(1) as pool: 153 | for repair_instance in tqdm( 154 | sorted(dataset, key=lambda x: x["instance_id"])[:limit] 155 | ): 156 | orig_instance = dataset_by_instance_id[repair_instance["instance_id"]] 157 | if repair_instance["repair_id"] in already_done and task_id is None: 158 | continue 159 | if task_id is not None and not any( 160 | f"{subset_prefix}_{tid}_" in repair_instance["repair_id"] 161 | for tid in task_id 162 | ): 163 | continue 164 | if tokenizer is None or model is None: 165 | tokenizer = AutoTokenizer.from_pretrained(model_name) 166 | kwargs = ( 167 | { 168 | "device_map": "auto", 169 | "torch_dtype": torch.bfloat16, 170 | "attn_implementation": "flash_attention_2", 171 | } 172 | if device == "cuda" 173 | else {"device_map": device} 174 | ) 175 | 176 | model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) 177 | user, first_code_line = format_prompt_to_question(orig_instance["prompt"]) 178 | messages = system_messages + [ 179 | {"role": "user", "content": user}, 180 | ] 181 | if "octocoder" in model_name: 182 | chat_template = """\ 183 | {%- for message in messages %} 184 | {%- if message['role'] == 'user' %} 185 | {{- 'Question: ' + message['content'].strip() + '\\n\\n' }} 186 | {%- elif message['role'] == 'system' %} 187 | {{- 'System: ' + message['content'].strip() + '\\n\\n' }} 188 | {%- elif message['role'] == 'assistant' 
%} 189 | {{- 'Answer: ' + message['content'] + '\\n\\n' }} 190 | {%- endif %} 191 | {%- endfor %}""" 192 | tokenizer.chat_template = chat_template 193 | try: 194 | tokenizer.apply_chat_template( 195 | messages, tokenize=False, add_generation_prompt=True 196 | ) 197 | except Exception: 198 | messages[1]["content"] = ( 199 | messages[0]["content"] + "\n\n" + messages[1]["content"] 200 | ) 201 | messages.pop(0) 202 | formatted_last_iteration_res = "\n".join( 203 | f"{i+1:03d}: {s}" 204 | for i, s in enumerate( 205 | cutoff( 206 | extract_code( 207 | repair_instance["code"], human_readable_target_lang, 0 208 | ) 209 | ).split("\n") 210 | ) 211 | ) 212 | messages.extend( 213 | [ 214 | { 215 | "role": "assistant", 216 | "content": f"```\n{formatted_last_iteration_res}\n```", 217 | }, 218 | { 219 | "role": "user", 220 | "content": f"Compiling this code produced an error:\n{repair_instance['compiler_output']}\n\nWrite the program again, and make sure to fix the error this time.", 221 | }, 222 | ] 223 | ) 224 | prompt = tokenizer.apply_chat_template( 225 | messages, tokenize=False, add_generation_prompt=True 226 | ) 227 | suffix = f"```{human_readable_target_lang.lower()}\n" 228 | prompt += suffix + first_code_line 229 | start = time.time() 230 | with torch.no_grad(): 231 | code, eos, crashed, resamples = sample_constrained( 232 | device=device, 233 | model_name=model_name, 234 | prompt=prompt, 235 | max_tokens=max_tokens, 236 | temperature=temp, 237 | do_sample=temp != 0, 238 | seed=seed, 239 | constrain_from=(suffix if constrained else None), 240 | constrain_until="```", 241 | trace=trace, 242 | model=model, 243 | tokenizer=tokenizer, 244 | timeout=timeout, 245 | try_top_k=try_top_k, 246 | reraise=reraise, 247 | ) 248 | end = time.time() 249 | time_taken = end - start 250 | extracted = extract_code(code, human_readable_target_lang, 0) 251 | extracted = cutoff(extracted) 252 | tests: str = orig_instance["tests"] 253 | if tests.strip().startswith("}") and extracted.strip().endswith("}"): 254 | tests = tests[tests.find("}") + 1 :] 255 | compilable = extracted + "\n\n" + tests 256 | pool.apply_async( 257 | compile_test_and_dump, 258 | ( 259 | output_file, 260 | { 261 | "dataset": dataset_name, 262 | "language": language, 263 | "split": split, 264 | "instance_id": repair_instance["repair_id"], 265 | "orig_instance_id": orig_instance["name"], 266 | "prompt": prompt, 267 | "constrained": constrained, 268 | "eos": eos, 269 | "crashed": str(crashed), 270 | "model_name": model_name, 271 | "temp": temp, 272 | "max_tokens": max_tokens, 273 | "time_taken": time_taken, 274 | "code": code, 275 | "compilable": compilable, 276 | "trace": trace, 277 | "resamples": resamples, 278 | "timeout": timeout, 279 | }, 280 | ), 281 | ) 282 | pool.close() 283 | pool.join() 284 | 285 | 286 | def compile_test_and_dump(output_file: str, specs: dict): 287 | try: 288 | compiled, compiler_output = LANGUAGE_COMPILER_MAP[specs["language"]]( 289 | specs["compilable"], specs["timeout"] 290 | ) 291 | if compiled is not None: 292 | tests_passed, test_output = LANGUAGE_TEST_MAP[specs["language"]]( 293 | compiled, specs["timeout"] 294 | ) 295 | else: 296 | tests_passed, test_output = None, None 297 | specs["compiled"] = compiled 298 | specs["compiler_output"] = compiler_output 299 | specs["tests_passed"] = tests_passed 300 | specs["test_output"] = test_output 301 | with open(output_file, "a") as f: 302 | print( 303 | json.dumps( 304 | specs, 305 | ), 306 | flush=True, 307 | file=f, 308 | ) 309 | if specs["trace"]: 310 | 
print("compiler_output:", compiler_output) 311 | print("tests_passed:", tests_passed) 312 | except Exception: 313 | print("WARNING CATASROPHIC FAILURE") 314 | print("RESULTS ARE NOT WRITTEN TO FILE") 315 | traceback.print_exc() 316 | print("", flush=True) 317 | 318 | 319 | if __name__ == "__main__": 320 | fire.Fire(main) 321 | -------------------------------------------------------------------------------- /experiments/main/invalid_mbpp: -------------------------------------------------------------------------------- 1 | mbpp_405_check_tuplex 2 | mbpp_563_extract_values 3 | mbpp_580_extract_even 4 | mbpp_612_merge 5 | mbpp_725_extract_quotation 6 | mbpp_791_remove_nested -------------------------------------------------------------------------------- /experiments/main/kill_inf.sh: -------------------------------------------------------------------------------- 1 | pkill -f run_temp_inf.sh 2 | pkill -f run_temp_inf.py 3 | pkill -f run_temp_inf_repair.py 4 | pkill -f run_inf.sh 5 | pkill -f inference_multiple.py 6 | pkill -f inference_multiple_repair.py 7 | -------------------------------------------------------------------------------- /experiments/main/print_c.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | 4 | nc_file, problem_name = sys.argv[1:] 5 | c_file = nc_file[:-8] + "c.jsonl" 6 | # results/humaneval_google_codegemma-7b-it_s\=0_t\=1_nc.jsonl 7 | with open(c_file) as f: 8 | for line in f: 9 | output = json.loads(line) 10 | if problem_name in output["instance_id"]: 11 | # print('\n'.join([f"{i+1:03d} {s}" for i, s in enumerate(output["code"].strip().split('\n'))])) 12 | print(output["code"].strip()) 13 | print(output["crashed"]) 14 | print(output["compiler_output"]) 15 | print(output["instance_id"]) 16 | print(output["tests_passed"]) 17 | -------------------------------------------------------------------------------- /experiments/main/rerun_temp_inf.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import time 3 | from math import ceil 4 | # rerun the failed instances with ASI 5 | 6 | GPUS = list(range(8)) # Example list of GPUs, adjust based on your system 7 | N = 1 # Set the max number of allowed processes per GPU 8 | GPUSIZE = 80 9 | models = [ 10 | # ("microsoft/Phi-3.5-mini-instruct", 4), 11 | # ("codellama/CodeLlama-7b-Instruct-hf", 7), 12 | # ("meta-llama/Llama-3.1-8B-Instruct", 8), 13 | # ("google/gemma-2b-it", 2), 14 | ("google/gemma-2-2b-it", 2), 15 | ("google/gemma-2-9b-it", 9), 16 | ("google/gemma-2-27b-it", 27), 17 | ("deepseek-ai/deepseek-coder-33b-instruct", 33), 18 | # ("deepseek-ai/deepseek-coder-7b-instruct-v1.5", 7), 19 | # ("deepseek-ai/deepseek-coder-1.3b-instruct", 1.3), 20 | # ("google/codegemma-2b", 2), 21 | # ("meta-llama/Llama-3.1-70B-Instruct", 70), 22 | # ("codellama/CodeLlama-70b-Instruct-hf", 70), 23 | ("codellama/CodeLlama-34b-Instruct-hf", 34), 24 | # ("codellama/CodeLlama-13b-Instruct-hf", 13), 25 | # ("google/codegemma-7b-it", 7), 26 | # ("bigcode/octocoder", 14), 27 | ("Qwen/Qwen2.5-32B-Instruct", 32), 28 | ] 29 | 30 | 31 | def compute_needed_gpus(size_model, size_gpu): 32 | return (size_model * 2 * 1.15) / size_gpu 33 | 34 | 35 | # subsets = ["main", "mbpp"] 36 | # temps = ["1"] # , "0", "0.5"] 37 | # seeds = [0, 1, 2, 3] 38 | configs = [ 39 | ("", "_synth"), 40 | ( 41 | "--input_file '../translation/{}/dataset.json' --translate True --translation_source_lang Python", 42 | "_translate", 43 | ), 44 | ] 45 | # 
constraineds = [True] 46 | timeout = 300 47 | max_tokens = 1000 48 | 49 | 50 | def find_available_gpus(gpus, n): 51 | found_gpus = [] 52 | for gpu in gpus: 53 | process_count = int( 54 | subprocess.check_output( 55 | [ 56 | "/bin/bash", 57 | "-c", 58 | f"nvidia-smi -i {gpu} --query-compute-apps=pid --format=csv,noheader | wc -l", 59 | ], 60 | ).strip() 61 | ) 62 | if process_count < n: 63 | found_gpus.append(gpu) 64 | return found_gpus 65 | 66 | 67 | constrained = True 68 | with open("failed_without_ASI") as f: 69 | failed = f.readlines() 70 | total_configs = [] 71 | subset, model, seed, temp, suffix = None, None, None, None, None 72 | current_tasks = [] 73 | config_to_tasks = {} 74 | for entry in failed: 75 | if entry.startswith("params: "): 76 | config_to_tasks[(subset, model, seed, temp, suffix)] = current_tasks 77 | current_tasks = [] 78 | if "repair-all" in entry: 79 | subset, model, seed, temp, suffix = None, None, None, None, None 80 | continue 81 | subset, model, seed, temp, suffix, _, _ = entry[len("params: ") :].split(",") 82 | subset = subset.strip() 83 | model = model.strip() 84 | seed = seed.strip() 85 | temp = temp.strip() 86 | suffix = suffix.strip() 87 | for config, name in configs: 88 | if name == suffix: 89 | total_configs.append( 90 | ( 91 | seed, 92 | temp, 93 | config, 94 | name, 95 | constrained, 96 | model, 97 | next( 98 | size for model_name, size in models if model_name == model 99 | ), 100 | subset, 101 | ) 102 | ) 103 | else: 104 | current_task = entry.strip().split("_")[1] 105 | current_tasks.append(current_task) 106 | 107 | 108 | remaining_configs = total_configs.copy() 109 | running_configs = list() 110 | while remaining_configs or running_configs: 111 | # reinsert crashed programs 112 | for config, pipe in running_configs: 113 | if pipe.poll() is not None: 114 | running_configs.remove((config, pipe)) 115 | if pipe.returncode != 0: 116 | remaining_configs.append(config) 117 | cuda_devices, needed_gpus = [], 1 118 | cuda_devices = find_available_gpus(GPUS, N) 119 | total_config = None 120 | for total_config in remaining_configs: 121 | ( 122 | seed, 123 | temp, 124 | config, 125 | name, 126 | constrained, 127 | model, 128 | model_size, 129 | subset, 130 | ) = total_config 131 | needed_gpus = compute_needed_gpus(model_size, GPUSIZE) 132 | if needed_gpus > len(GPUS): 133 | print(f"model {model} is too large, skipping") 134 | remaining_configs.remove(total_config) 135 | continue 136 | if len(cuda_devices) >= needed_gpus: 137 | break 138 | if len(cuda_devices) < needed_gpus or total_config is None: 139 | print("No available GPU found or all configs running. 
Waiting...") 140 | time.sleep(60) 141 | continue 142 | remaining_configs.remove(total_config) 143 | if subset == "mbpp" and seed != 0: 144 | continue 145 | cuda_devices = cuda_devices[: int(ceil(needed_gpus))] 146 | if "translate True" in config: 147 | config = config.format("main-x" if subset == "main" else subset) 148 | 149 | if constrained: 150 | suffix = "c" 151 | else: 152 | suffix = "nc" 153 | 154 | task_ids = config_to_tasks[(subset, model, seed, temp, name)] 155 | command = ( 156 | f"CUDA_VISIBLE_DEVICES={','.join(str(i) for i in cuda_devices)} python3 inference_multiple.py " 157 | f"--max-tokens {max_tokens} --timeout {timeout} --model_name {model} --seed {seed} --temp {temp} --subset {subset} --task_id {','.join(task_ids)} " 158 | f"--constrained {constrained} --output_file 'results/rerun_{subset}_{model.replace('/', '_')}_s={seed}_t={temp}{name}_{suffix}.jsonl' {config}" 159 | ) 160 | print("+ " + command) 161 | pipe = subprocess.Popen(["/bin/bash", "-c", command]) 162 | running_configs.append((total_config, pipe)) 163 | time.sleep(20) 164 | -------------------------------------------------------------------------------- /experiments/main/run_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Description: Run the experiment for generating samples from the model constrained and unconstrained 3 | set -ex 4 | # cd into the directory of this file 5 | cd "$(dirname "${BASH_SOURCE[0]}")" 6 | python3 run_experiments_syn_tran.py 7 | python3 run_experiments_repair.py 8 | -------------------------------------------------------------------------------- /experiments/main/run_experiments_repair.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import subprocess 4 | import time 5 | from math import ceil 6 | 7 | import fire 8 | import torch 9 | 10 | 11 | def get_gpu_memory(): 12 | command = "nvidia-smi --query-gpu=memory.total --format=csv" 13 | memory_free_info = ( 14 | subprocess.check_output(command.split()).decode("ascii").split("\n")[:-1][1:] 15 | ) 16 | memory_free_values = [ 17 | int(x.split()[0]) // 1000 for i, x in enumerate(memory_free_info) 18 | ] 19 | return memory_free_values 20 | 21 | 22 | parent = pathlib.Path(__file__).parent.absolute() 23 | GPUS = os.environ.get("CUDA_VISIBLE_DEVICES", None) 24 | if GPUS is None: 25 | GPUS = list( 26 | range(torch.cuda.device_count()) 27 | ) # list of GPUs per default, adjust based on your system 28 | N = 1 # Set the max number of allowed processes per GPU 29 | GPUSIZE = min(get_gpu_memory()) # for now assumes same memory for all GPUs 30 | MODEL_SIZE_MAP = { 31 | "google/gemma-2-2b-it": 2, 32 | "google/gemma-2-9b-it": 9, 33 | "google/gemma-2-27b-it": 27, 34 | "deepseek-ai/deepseek-coder-33b-instruct": 33, 35 | "codellama/CodeLlama-34b-Instruct-hf": 34, 36 | "Qwen/Qwen2.5-32B-Instruct": 32, 37 | } 38 | 39 | 40 | def compute_needed_gpus(size_model, size_gpu): 41 | return (size_model * 2 * 1.2) / size_gpu 42 | 43 | 44 | subsets = ["humaneval", "mbpp"] 45 | temps = ["1"] 46 | seeds = [0] 47 | constraineds = [False, True] 48 | timeout = 300 49 | max_tokens = 1000 50 | try_top_k = 10000000000000000 51 | 52 | 53 | def find_available_gpus(gpus, n): 54 | found_gpus = [] 55 | for gpu in gpus: 56 | process_count = int( 57 | subprocess.check_output( 58 | [ 59 | "/bin/bash", 60 | "-c", 61 | f"nvidia-smi -i {gpu} --query-compute-apps=pid --format=csv,noheader | wc -l", 62 | ], 63 | ).strip() 64 | ) 65 | if 
process_count < n: 66 | found_gpus.append(gpu) 67 | return found_gpus 68 | 69 | 70 | def main( 71 | subsets=subsets, 72 | seeds=seeds, 73 | temps=temps, 74 | constraineds=constraineds, 75 | models=list(MODEL_SIZE_MAP.keys()), 76 | timeout=timeout, 77 | max_tokens=max_tokens, 78 | gpu_size=GPUSIZE, 79 | gpus=GPUS, 80 | n_process_per_gpu=N, 81 | ): 82 | if isinstance(models, str): 83 | models = models.split(",") 84 | if isinstance(subsets, str): 85 | subsets = subsets.split(",") 86 | if isinstance(temps, str): 87 | temps = [float(x) for x in temps.split(",")] 88 | elif isinstance(temps, int): 89 | temps = [temps] 90 | if isinstance(seeds, str): 91 | seeds = [int(x) for x in seeds.split(",")] 92 | elif isinstance(seeds, int): 93 | seeds = [seeds] 94 | if isinstance(constraineds, str): 95 | constraineds = [constraineds == "True"] 96 | elif isinstance(constraineds, int): 97 | constraineds = [constraineds != 0] 98 | if isinstance(gpus, str): 99 | gpus = [int(x) for x in gpus.split(",")] 100 | elif isinstance(gpus, int): 101 | gpus = [gpus] 102 | 103 | assert all(subset in ["humaneval", "mbpp"] for subset in subsets) 104 | assert all(model in MODEL_SIZE_MAP for model in models) 105 | 106 | total_configs = [] 107 | for subset in subsets: 108 | for seed in seeds: 109 | for temp in temps: 110 | for constrained in constraineds: 111 | for model in models: 112 | total_configs.append( 113 | ( 114 | seed, 115 | temp, 116 | constrained, 117 | model, 118 | subset, 119 | ) 120 | ) 121 | 122 | remaining_configs = total_configs.copy() 123 | running_configs = list() 124 | while remaining_configs or running_configs: 125 | # reinsert crashed programs 126 | for config, pipe in running_configs: 127 | if pipe.poll() is not None: 128 | running_configs.remove((config, pipe)) 129 | if pipe.returncode != 0: 130 | remaining_configs.append(config) 131 | cuda_devices, needed_gpus = find_available_gpus(gpus, n_process_per_gpu), 1 132 | total_config = None 133 | for total_config in remaining_configs: 134 | ( 135 | seed, 136 | temp, 137 | constrained, 138 | model, 139 | subset, 140 | ) = total_config 141 | needed_gpus = compute_needed_gpus(MODEL_SIZE_MAP[model], gpu_size) 142 | if needed_gpus > len(gpus): 143 | print(f"Model {model} is too large to fit on available GPUs, skipping") 144 | remaining_configs.remove(total_config) 145 | continue 146 | if len(cuda_devices) >= needed_gpus: 147 | break 148 | if len(cuda_devices) < needed_gpus or total_config is None: 149 | if not remaining_configs: 150 | s = "Waiting for running jobs to finish..." 151 | else: 152 | s = f"All {len(gpus)} GPUs are busy, waiting to start new job for 60 seconds." 153 | print( 154 | f"Total jobs: {len(total_configs)}, Running jobs: {len(running_configs)}, Remaining jobs: {len(remaining_configs)}. 
{s}" 155 | ) 156 | time.sleep(60) 157 | continue 158 | remaining_configs.remove(total_config) 159 | cuda_devices = cuda_devices[: int(ceil(needed_gpus))] 160 | 161 | config = f" --input_file 'repair_datasets/{subset}_repair_dataset.jsonl'" 162 | 163 | if constrained: 164 | suffix = "c" 165 | else: 166 | suffix = "nc" 167 | command = ( 168 | f"CUDA_VISIBLE_DEVICES={','.join(str(i) for i in cuda_devices)} python3 inference_multiple_repair.py " 169 | f"--max-tokens {max_tokens} --timeout {timeout} --model_name {model} --seed {seed} --temp {temp} --subset {subset} --try_top_k {try_top_k} " 170 | f"--constrained {constrained} --output_file 'results/{subset}_{model.replace('/', '_')}_s={seed}_t={temp}_repair-all_{suffix}.jsonl' {config}" 171 | ) 172 | print("+ " + command) 173 | pipe = subprocess.Popen(["/bin/bash", "-c", command], cwd=parent) 174 | running_configs.append((total_config, pipe)) 175 | time.sleep(20) 176 | 177 | 178 | if __name__ == "__main__": 179 | fire.Fire(main) 180 | -------------------------------------------------------------------------------- /experiments/main/run_experiments_syn_tran.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import subprocess 4 | import time 5 | from math import ceil 6 | 7 | import fire 8 | import torch 9 | 10 | 11 | def get_gpu_memory(): 12 | command = "nvidia-smi --query-gpu=memory.total --format=csv" 13 | memory_free_info = ( 14 | subprocess.check_output(command.split()).decode("ascii").split("\n")[:-1][1:] 15 | ) 16 | memory_free_values = [ 17 | int(x.split()[0]) // 1000 for i, x in enumerate(memory_free_info) 18 | ] 19 | return memory_free_values 20 | 21 | 22 | parent = pathlib.Path(__file__).parent.absolute() 23 | GPUS = os.environ.get("CUDA_VISIBLE_DEVICES", None) 24 | if GPUS is None: 25 | GPUS = list( 26 | range(torch.cuda.device_count()) 27 | ) # list of GPUs per default, adjust based on your system 28 | N = 1 # Set the max number of allowed processes per GPU 29 | GPUSIZE = min(get_gpu_memory()) # for now assumes same memory for all GPUs 30 | MODEL_SIZE_MAP = { 31 | "google/gemma-2-2b-it": 2, 32 | "google/gemma-2-9b-it": 9, 33 | "google/gemma-2-27b-it": 27, 34 | "deepseek-ai/deepseek-coder-33b-instruct": 33, 35 | "codellama/CodeLlama-34b-Instruct-hf": 34, 36 | "Qwen/Qwen2.5-32B-Instruct": 32, 37 | } 38 | 39 | 40 | def compute_needed_gpus(size_model, size_gpu): 41 | return (size_model * 2 * 1.15) / size_gpu 42 | 43 | 44 | SUBSETS = ["humaneval", "mbpp"] 45 | TEMPS = ["1"] 46 | SEEDS = [0, 1, 2, 3] 47 | CONFIGS = [ 48 | ("", "_synth"), 49 | ( 50 | "--input_file '../translation/{}/dataset.json' --translate True --translation_source_lang Python", 51 | "_translate", 52 | ), 53 | ] 54 | CONSTRAINEDS = [False, True] 55 | TIMEOUT = 300 56 | MAX_TOKENS = 1000 57 | TRY_TOP_K = 10000000000000000 58 | 59 | 60 | def find_available_gpus(gpus, n): 61 | found_gpus = [] 62 | for gpu in gpus: 63 | process_count = int( 64 | subprocess.check_output( 65 | [ 66 | "/bin/bash", 67 | "-c", 68 | f"nvidia-smi -i {gpu} --query-compute-apps=pid --format=csv,noheader | wc -l", 69 | ], 70 | ).strip() 71 | ) 72 | if process_count < n: 73 | found_gpus.append(gpu) 74 | return found_gpus 75 | 76 | 77 | def main( 78 | subsets=SUBSETS, 79 | seeds=SEEDS, 80 | temps=TEMPS, 81 | constraineds=CONSTRAINEDS, 82 | models=list(MODEL_SIZE_MAP.keys()), 83 | timeout=TIMEOUT, 84 | max_tokens=MAX_TOKENS, 85 | tasks=["synth", "translate"], 86 | gpu_size=GPUSIZE, 87 | gpus=GPUS, 88 | n_process_per_gpu=N, 89 | ): 90 | 
try_top_k = TRY_TOP_K 91 | if isinstance(models, str): 92 | models = models.split(",") 93 | if isinstance(subsets, str): 94 | subsets = subsets.split(",") 95 | if isinstance(temps, str): 96 | temps = [float(x) for x in temps.split(",")] 97 | elif isinstance(temps, int): 98 | temps = [temps] 99 | if isinstance(seeds, str): 100 | seeds = [int(x) for x in seeds.split(",")] 101 | elif isinstance(seeds, int): 102 | seeds = [seeds] 103 | if isinstance(constraineds, str): 104 | constraineds = [constraineds == "True"] 105 | elif isinstance(constraineds, int): 106 | constraineds = [constraineds != 0] 107 | if isinstance(tasks, str): 108 | tasks = tasks.split(",") 109 | if isinstance(gpus, str): 110 | gpus = [int(x) for x in gpus.split(",")] 111 | elif isinstance(gpus, int): 112 | gpus = [gpus] 113 | 114 | assert all(subset in ["humaneval", "mbpp"] for subset in subsets) 115 | assert all(model in MODEL_SIZE_MAP for model in models) 116 | 117 | configs = [(config, name) for config, name in CONFIGS if name[1:] in tasks] 118 | total_configs = [] 119 | for constrained in constraineds: 120 | for subset in subsets: 121 | for seed in seeds: 122 | for temp in temps: 123 | for config, name in configs: 124 | for model in models: 125 | if subset == "mbpp" and seed != 0: 126 | continue 127 | total_configs.append( 128 | ( 129 | seed, 130 | temp, 131 | config, 132 | name, 133 | constrained, 134 | model, 135 | subset, 136 | ) 137 | ) 138 | 139 | remaining_configs = total_configs.copy() 140 | running_configs = list() 141 | while remaining_configs or running_configs: 142 | # reinsert crashed programs 143 | for config, pipe in running_configs: 144 | if pipe.poll() is not None: 145 | running_configs.remove((config, pipe)) 146 | if pipe.returncode != 0: 147 | remaining_configs.append(config) 148 | cuda_devices, needed_gpus = [], 1 149 | cuda_devices = find_available_gpus(gpus, n_process_per_gpu) 150 | total_config = None 151 | for total_config in remaining_configs: 152 | ( 153 | seed, 154 | temp, 155 | config, 156 | name, 157 | constrained, 158 | model, 159 | subset, 160 | ) = total_config 161 | needed_gpus = compute_needed_gpus(MODEL_SIZE_MAP[model], gpu_size) 162 | if needed_gpus > len(gpus): 163 | print(f"Model {model} is too large to fit on available GPUs, skipping") 164 | remaining_configs.remove(total_config) 165 | continue 166 | if len(cuda_devices) >= needed_gpus: 167 | break 168 | if len(cuda_devices) < needed_gpus or total_config is None: 169 | if not remaining_configs: 170 | s = "Waiting for running jobs to finish..." 171 | else: 172 | s = f"All {len(gpus)} GPUs are busy, waiting to start new job for 60 seconds." 173 | print( 174 | f"Total jobs: {len(total_configs)}, Running jobs: {len(running_configs)}, Remaining jobs: {len(remaining_configs)}. 
{s}" 175 | ) 176 | time.sleep(60) 177 | continue 178 | remaining_configs.remove(total_config) 179 | cuda_devices = cuda_devices[: int(ceil(needed_gpus))] 180 | if "translate True" in config: 181 | config = config.format("humaneval-x" if subset == "humaneval" else subset) 182 | 183 | if constrained: 184 | suffix = "c" 185 | else: 186 | suffix = "nc" 187 | command = ( 188 | f"CUDA_VISIBLE_DEVICES={','.join(str(i) for i in cuda_devices)} python3 inference_multiple.py " 189 | f"--max-tokens {max_tokens} --timeout {timeout} --model_name {model} --seed {seed} --temp {temp} --subset {subset} --try_top_k {try_top_k} " 190 | f"--constrained {constrained} --output_file 'results/{subset}_{model.replace('/', '_')}_s={seed}_t={temp}{name}_{suffix}.jsonl' {config}" 191 | ) 192 | print("+ " + command) 193 | pipe = subprocess.Popen(["/bin/bash", "-c", command], cwd=parent) 194 | running_configs.append((total_config, pipe)) 195 | time.sleep(20) 196 | 197 | 198 | if __name__ == "__main__": 199 | fire.Fire(main) 200 | -------------------------------------------------------------------------------- /experiments/main/translate_canonical_humaneval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pathlib 3 | import sys 4 | 5 | import fire 6 | from datasets import load_dataset 7 | import argparse 8 | 9 | from langchain.prompts import PromptTemplate 10 | from langchain_core.messages import AIMessage 11 | from langchain_openai import ChatOpenAI 12 | 13 | import os 14 | 15 | from tqdm import tqdm 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser( 20 | description="Translate canonical solutions from the humaneval dataset" 21 | ) 22 | parser.add_argument( 23 | "--target_lang", 24 | type=str, 25 | default="ts", 26 | help="The target language to translate to", 27 | ) 28 | parser.add_argument( 29 | "--target_dir", 30 | type=str, 31 | default="data", 32 | help="The target directory to save the translations", 33 | ) 34 | parser.add_argument( 35 | "--model_name", 36 | type=str, 37 | default="gpt-4o-2024-05-13", 38 | help="The model name to use for translation", 39 | ) 40 | parser.add_argument( 41 | "--dataset_name", 42 | type=str, 43 | default="openai/openai_humaneval", 44 | help="The dataset of which to translate the canonical solution", 45 | ) 46 | args = parser.parse_args() 47 | 48 | dataset_name = args.dataset_name 49 | dataset = load_dataset(dataset_name) 50 | 51 | target_lang = args.target_lang 52 | target_dir = args.target_dir 53 | if not target_dir: 54 | print("Please provide the target directory") 55 | sys.exit(1) 56 | model_name = args.model_name 57 | openai_api_key = os.environ.get("OPENAI_API_KEY") 58 | llm = ChatOpenAI(openai_api_key=openai_api_key, model_name="gpt-4o-2024-05-13") 59 | 60 | TARGET_LANG_MAP = { 61 | "ts": "TypeScript", 62 | "rs": "Rust", 63 | } 64 | if target_lang not in TARGET_LANG_MAP: 65 | print(f"Unsupported target language: {target_lang}") 66 | sys.exit(1) 67 | 68 | human_readable_target_lang = TARGET_LANG_MAP[target_lang] 69 | prompt_template = f""" 70 | You are an expert in {human_readable_target_lang} and Python programming. 71 | Translate the following Python code to {human_readable_target_lang}. 72 | You may assume that the input code is correct and that the translation should be semantically equivalent. 73 | Just answer with the translated code in a ```{human_readable_target_lang.lower()}...``` block. 
74 | 75 | This is the program to translate: 76 | ```python 77 | {{code}} 78 | ``` 79 | """ 80 | 81 | def parse_code(ai_message: AIMessage) -> str: 82 | return ai_message.content 83 | 84 | chain = PromptTemplate.from_template(prompt_template) | llm | parse_code 85 | 86 | translations_resolved = {} 87 | target_file = f"{target_dir}/{dataset_name.replace('/', '_')}_{target_lang}_{model_name}.jsonl" 88 | pathlib.Path(target_dir).mkdir(parents=True, exist_ok=True) 89 | if os.path.exists(target_file): 90 | with open(target_file, "r") as f: 91 | for line in f: 92 | json_line = json.loads(line) 93 | translations_resolved[json_line["task_id"]] = json_line["translation"] 94 | 95 | with open(target_file, "a") as f: 96 | for instance in tqdm(dataset["test"]): 97 | task_id = instance["task_id"] 98 | if task_id in translations_resolved: 99 | continue 100 | code = instance["prompt"] + "\n" + instance["canonical_solution"] 101 | translation = chain.invoke(code) 102 | f.write(json.dumps({"task_id": task_id, "translation": translation}) + "\n") 103 | f.flush() 104 | 105 | 106 | if __name__ == "__main__": 107 | fire.Fire(main) 108 | -------------------------------------------------------------------------------- /experiments/main/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import subprocess 4 | from pathlib import Path 5 | from subprocess import TimeoutExpired 6 | from tempfile import NamedTemporaryFile, TemporaryDirectory 7 | from typing import Tuple 8 | 9 | from attr import dataclass 10 | 11 | 12 | def extract_code(output: str, humanreadable_target_language: str, nth: int): 13 | prefix = f"```{humanreadable_target_language.lower()}\n" 14 | pos = 0 15 | for _ in range(nth + 1): 16 | pos = output.find(prefix, pos) + len(prefix) 17 | code = output[pos:] 18 | code = code[: code.find("```")] 19 | return code.strip().strip("`") + "\n" 20 | 21 | 22 | def cutoff(str_program: str): 23 | """ 24 | Cutoff after the last outermost function is closed 25 | """ 26 | curly_open = 0 27 | # default to just returning the entire string 28 | last_balanced_pos = len(str_program) 29 | for i, char in enumerate(str_program): 30 | if char == "{": 31 | curly_open += 1 32 | if char == "}": 33 | if curly_open <= 0: 34 | break 35 | curly_open -= 1 36 | if curly_open == 0 and str_program[i + 1] in ("\n", ";"): 37 | last_balanced_pos = i 38 | return str_program[: last_balanced_pos + 1] 39 | 40 | 41 | def test_code_extract_2(): 42 | code = """ 43 | ```typescript 44 | function words_string(s: string): string[] { 45 | return s.replace(/,/g, ').split(' ).trim().split(' ').filter(word => word!== ''); 46 | } 47 | ``""" 48 | extracted = extract_code(code, "TypeScript", 0) 49 | cutoffed = cutoff(extracted) 50 | print(cutoffed) 51 | 52 | 53 | def test_cutoff(): 54 | code = """ 55 | function words_string(s: string): string[] { 56 | return s.replace(/,/g, ').split(' ).trim().split(' ').filter(word => word!== ''); 57 | }; 58 | """ 59 | extracted = extract_code(code, "TypeScript", 0) 60 | cutoffed = cutoff(extracted) 61 | print(cutoffed) 62 | 63 | 64 | def test_cutoff2(): 65 | code = """ 66 | function words_string(s: string): string[] { 67 | return s.replace(/,/g, ').split(' ).trim().split(' ').filter(word => word!== ''); 68 | } 69 | function abc(s: string): string[] { 70 | for (x in bla){ 71 | y.map(x => { 72 | }) 73 | } 74 | } 75 | """ 76 | extracted = extract_code(code, "TypeScript", 0) 77 | cutoffed = cutoff(extracted) 78 | print(cutoffed) 79 | 80 | 81 | def 
test_double_cutoff(): 82 | code = """ 83 | ```typescript 84 | function prod_signs(arr: number[]): number | undefined { 85 | if (arr.length === 0) { 86 | return undefined; 87 | } 88 | let product = 1; 89 | let sum = 0; 90 | for (let i = 0; i < arr.length; i++) { 91 | if (arr[i] === 0) { 92 | continue; 93 | } 94 | product *= arr[i]; 95 | } 96 | for (let i = 0; i < arr.length; i++) { 97 | if (arr[i] === 0) { 98 | continue; 99 | } 100 | sum += Math.abs(arr[i]) * product; 101 | } 102 | return sum; 103 | } 104 | ``` 105 | 106 | 107 | **Explanation:** 108 | * **Function Definition:** 109 | ```typescript 110 | function prod_signs(arr: number[]): number | undefined { 111 | // Function code goes here 112 | } 113 | ``` 114 | """ 115 | res = extract_code(code, "TypeScript", 0) 116 | print(res) 117 | 118 | 119 | def test_cutoff3(): 120 | code = """ 121 | function words_string(s: string): string[] { 122 | return s.replace(/,/g, ').split(' ).trim().split(' ').filter(word => word!== ''); 123 | """ 124 | extracted = extract_code(code, "TypeScript", 0) 125 | cutoffed = cutoff(extracted) 126 | print(cutoffed) 127 | 128 | 129 | def test_cutoff4(): 130 | code = """ 131 | function words_string(s: string): string[] { 132 | return s.replace(/,/g, ').split(' ).trim().split(' ').filter(word => word!== ''); 133 | } 134 | // example use 135 | console.log(`${words_string("blabla")}`); 136 | """ 137 | extracted = extract_code(code, "TypeScript", 0) 138 | cutoffed = cutoff(extracted) 139 | print(cutoffed) 140 | 141 | 142 | def test_code_extract(): 143 | code = """ 144 | user 145 | You are an expert in TypeScript programming. Solve the given problem by writing solution code in TypeScript. 146 | When answering, insert the solution code in a ```typescript...``` block. 147 | 148 | 149 | Check if in given array of numbers, are any two numbers closer to each other than 150 | given threshold. 
151 | >>> has_close_elements([1.0, 2.0, 3.0], 0.5) 152 | false 153 | >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3) 154 | true 155 | model 156 | ```typescript 157 | function has_close_elements(numbers: number[], threshold: number): boolean { 158 | // Sort the array in ascending order 159 | numbers.sort(); 160 | 161 | // Iterate over the array 162 | for (let i = 0; i < numbers.length - 1; i++) { 163 | // Skip the last element, as it is already considered close to the previous element 164 | if (i > 0 && Math.abs(numbers[i] - numbers[i - 1]) <= threshold) { 165 | return true; 166 | } 167 | } 168 | 169 | // If no elements are close to the threshold, return false 170 | return false; 171 | } 172 | ``` 173 | """ 174 | extracted = extract_code(code, "TypeScript", 0) 175 | cutoffed = cutoff(extracted) 176 | print(cutoffed) 177 | 178 | 179 | def tsx_compiles(ts_program, timeout=300) -> Tuple[str | None, str]: 180 | with NamedTemporaryFile(suffix=".ts") as f: 181 | f.write(ts_program.encode()) 182 | f.flush() 183 | if has_syntax_error(f.name, timeout): 184 | return (None, "SyntaxError: Abort compilation") 185 | try: 186 | res = subprocess.run( 187 | [ 188 | "npx", 189 | "tsc", 190 | "--lib", 191 | "es2024", 192 | "--target", 193 | "es2024", 194 | "--strict", 195 | f.name, 196 | "--outFile", 197 | "/dev/stderr", 198 | ], 199 | capture_output=True, 200 | timeout=timeout, 201 | ) 202 | if res.returncode != 0: 203 | return res.stderr.decode(), res.stdout.decode() 204 | return res.stderr.decode(), res.stdout.decode() 205 | except TimeoutExpired: 206 | return None, "Timeout" 207 | 208 | 209 | def rust_compiles(rust_program, timeout=300, path="a.out") -> Tuple[str | None, str]: 210 | with NamedTemporaryFile(suffix=".rs") as f: 211 | f.write(rust_program.encode()) 212 | f.flush() 213 | if os.path.exists(path): 214 | os.unlink(path) 215 | try: 216 | res = subprocess.run( 217 | ["rustc", f.name, "-o", path], capture_output=True, timeout=timeout 218 | ) 219 | if res.returncode != 0: 220 | return None, res.stderr.decode() 221 | return path, res.stderr.decode() 222 | except TimeoutExpired: 223 | return None, "Timeout" 224 | 225 | 226 | def setup_go_env(go_program: str, go_tests: str, dir: pathlib.Path, timeout): 227 | """ 228 | Write program into file (adding relevant imports) and setup go package stuff 229 | """ 230 | main_prefix = """ 231 | package main 232 | """ 233 | main_suffix = ( 234 | """ 235 | func main(){ 236 | } 237 | """ 238 | if "func main(" not in go_program 239 | else "" 240 | ) 241 | main_file = dir / "main.go" 242 | with open(main_file, "w") as f: 243 | f.write(main_prefix + go_program + main_suffix) 244 | test_file = dir / "main_test.go" 245 | tests_prefix = """ 246 | package main 247 | import ( 248 | "testing" 249 | "github.com/stretchr/testify/assert" 250 | ) 251 | """ 252 | with open(test_file, "w") as f: 253 | f.write(tests_prefix + go_tests) 254 | subprocess.run( 255 | ["go", "mod", "init", "sample"], timeout=timeout, check=True, cwd=dir 256 | ) 257 | subprocess.run(["go", "mod", "tidy"], timeout=timeout, check=True, cwd=dir) 258 | return [main_file, test_file] 259 | 260 | 261 | @dataclass 262 | class TestResult: 263 | setup_ok: bool 264 | syntax_ok: bool 265 | compiled: bool 266 | passed: bool 267 | error_message: str 268 | 269 | 270 | def go_compiles_passes(program, tests, timeout=300) -> TestResult: 271 | with TemporaryDirectory() as tmpdir: 272 | try: 273 | files = setup_go_env(program, tests, pathlib.Path(tmpdir), timeout) 274 | except TimeoutExpired: 275 | return 
TestResult(False, False, False, False, "Setup Timeout") 276 | try: 277 | res = subprocess.run( 278 | [ 279 | "goimports", 280 | "-w", 281 | *files, 282 | ], 283 | capture_output=True, 284 | timeout=timeout, 285 | cwd=tmpdir, 286 | ) 287 | if res.returncode != 0: 288 | return TestResult( 289 | True, 290 | False, 291 | False, 292 | False, 293 | res.stderr.decode() + "\n\n" + res.stdout.decode(), 294 | ) 295 | except TimeoutExpired: 296 | return TestResult(True, False, False, False, "Format/Import Timeout") 297 | try: 298 | res = subprocess.run( 299 | [ 300 | "go", 301 | "build", 302 | ], 303 | capture_output=True, 304 | timeout=timeout, 305 | cwd=tmpdir, 306 | ) 307 | if res.returncode != 0: 308 | return TestResult(True, True, False, False, res.stderr.decode()) 309 | except TimeoutExpired: 310 | return TestResult(True, True, False, False, "Compilation Timeout") 311 | try: 312 | res = subprocess.run( 313 | ["go", "test", f"-timeout={timeout}s"], 314 | capture_output=True, 315 | timeout=timeout, 316 | cwd=tmpdir, 317 | ) 318 | if res.returncode != 0: 319 | return TestResult( 320 | True, 321 | True, 322 | True, 323 | False, 324 | res.stdout.decode() + "\n\n" + res.stderr.decode(), 325 | ) 326 | except TimeoutExpired: 327 | return TestResult(True, True, True, False, "Test Timeout") 328 | return TestResult(True, True, True, True, "") 329 | 330 | 331 | def go_passes_tests(ts_program, timeout=300) -> Tuple[str | None, str]: 332 | with TemporaryDirectory() as tmpdir: 333 | file = "main_test.go" 334 | tmpfile = pathlib.Path(tmpdir) / file 335 | with open(tmpfile, "w") as f: 336 | f.write(ts_program.encode()) 337 | if has_syntax_error(f.name, timeout): 338 | return (None, "SyntaxError: Abort compilation") 339 | try: 340 | res = subprocess.run( 341 | [ 342 | "go", 343 | "build", 344 | f"-timeout={timeout}s", 345 | "main.go", 346 | ], 347 | capture_output=True, 348 | timeout=timeout, 349 | cwd=tmpdir, 350 | ) 351 | res = subprocess.run( 352 | [ 353 | "go", 354 | "build", 355 | f"-timeout={timeout}s", 356 | file, 357 | ], 358 | capture_output=True, 359 | timeout=timeout, 360 | cwd=tmpdir, 361 | ) 362 | if res.returncode != 0: 363 | return res.stderr.decode(), res.stdout.decode() 364 | return res.stderr.decode(), res.stdout.decode() 365 | except TimeoutExpired: 366 | return None, "Timeout" 367 | 368 | 369 | def passes_tests_js(js_program, timeout=300) -> Tuple[bool, str]: 370 | try: 371 | res = subprocess.run( 372 | ["node", "-e", js_program], 373 | check=False, 374 | capture_output=True, 375 | timeout=timeout, 376 | ) 377 | return res.returncode == 0, ( 378 | {"stdout": res.stdout.decode(), "stderr": res.stderr.decode()} 379 | ) 380 | except TimeoutExpired: 381 | return False, "Timeout" 382 | 383 | 384 | path_to_ts_parser = "../../ts_parser/target/release/ts_parser" 385 | 386 | 387 | def has_syntax_error(ts_program_location, timeout=300) -> bool: 388 | res = subprocess.run([path_to_ts_parser, ts_program_location], capture_output=True) 389 | return res.returncode != 0 390 | 391 | 392 | with open(Path(__file__).parent / "invalid_mbpp") as f: 393 | invalid_mbpp_instances = {line.strip() for line in f.readlines()} 394 | -------------------------------------------------------------------------------- /experiments/translation/humaneval-x/.gitignore: -------------------------------------------------------------------------------- 1 | *.jsonl.gz -------------------------------------------------------------------------------- /experiments/translation/humaneval-x/execute.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import subprocess 5 | from tqdm import tqdm 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument("--output_name", type=str, required=True) 12 | 13 | parser.add_argument("--data_path", type=str, default="dataset.json") 14 | parser.add_argument("--output_dir", type=str, default="output_translation") 15 | 16 | args = parser.parse_args() 17 | args.output_dir = os.path.join(args.output_dir, args.output_name) 18 | 19 | return args 20 | 21 | 22 | if __name__ == "__main__": 23 | args = get_args() 24 | 25 | with open(args.data_path) as f: 26 | dataset = json.load(f) 27 | tgt_dataset = dataset["ts"] 28 | 29 | for task_id in tqdm(list(sorted(os.listdir(args.output_dir)))): 30 | tests = tgt_dataset[task_id]["tests"] 31 | tgt_progs_dir = os.path.join(args.output_dir, task_id) 32 | for tgt_prog_name in os.listdir(tgt_progs_dir): 33 | if not tgt_prog_name.endswith(".ts"): 34 | continue 35 | 36 | tgt_prog_name = tgt_prog_name[:-3] 37 | exec_res = { 38 | "tgt_prog": os.path.join(tgt_progs_dir, tgt_prog_name + ".ts"), 39 | "compile": False, 40 | "test": False, 41 | } 42 | 43 | returncode = subprocess.call( 44 | ["tsc", os.path.join(tgt_progs_dir, tgt_prog_name + ".ts")] 45 | ) 46 | if returncode != 0: 47 | with open( 48 | os.path.join(tgt_progs_dir, tgt_prog_name + ".json"), "w" 49 | ) as f: 50 | f.write(json.dumps(exec_res, indent=2)) 51 | continue 52 | 53 | exec_res["compile"] = True 54 | try: 55 | returncode = subprocess.call( 56 | ["node", os.path.join(tgt_progs_dir, tgt_prog_name + ".js")], 57 | timeout=5, 58 | ) 59 | if returncode != 0: 60 | test_res = "Incorrect" 61 | else: 62 | test_res = "Correct" 63 | except subprocess.TimeoutExpired: 64 | test_res = "Timeout" 65 | except Exception as e: 66 | test_res = e.__class__.__name__ 67 | exec_res["test"] = test_res 68 | 69 | with open(os.path.join(tgt_progs_dir, tgt_prog_name + ".json"), "w") as f: 70 | json.dump(exec_res, f, indent=2) 71 | -------------------------------------------------------------------------------- /experiments/translation/humaneval-x/metric.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | from tabulate import tabulate 6 | 7 | K = [1, 5, 10] 8 | 9 | 10 | def compute(n, c, k): 11 | if n - c < k: 12 | return 1.0 13 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) 14 | 15 | 16 | def get_args(): 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument("--output_name", type=str, required=True) 20 | parser.add_argument("--output_dir", type=str, default="output_translation") 21 | 22 | args = parser.parse_args() 23 | args.output_dir = os.path.join(args.output_dir, args.output_name) 24 | 25 | return args 26 | 27 | 28 | if __name__ == "__main__": 29 | args = get_args() 30 | 31 | pass_at_ks = [[] for _ in range(len(K))] 32 | compile_at_ks = [[] for _ in range(len(K))] 33 | for tgt_progs_dir in os.listdir(args.output_dir): 34 | tgt_progs_dir = os.path.join(args.output_dir, tgt_progs_dir) 35 | n, c_pass, c_compile = 0, 0, 0 36 | for res_name in os.listdir(tgt_progs_dir): 37 | if not res_name.endswith(".json"): 38 | continue 39 | 40 | n += 1 41 | with open(os.path.join(tgt_progs_dir, res_name)) as f: 42 | j = json.load(f) 43 | if j["compile"]: 44 | c_compile += 1 45 | if j["test"] == "Correct": 46 | c_pass += 1 47 | else: 48 | 
print(j["tgt_prog"]) 49 | for i, k in enumerate(K): 50 | pass_at_ks[i].append(compute(n, c_pass, k)) 51 | compile_at_ks[i].append(compute(n, c_compile, k)) 52 | for i, k in enumerate(K): 53 | pass_at_ks[i] = np.mean(pass_at_ks[i]) * 100 54 | compile_at_ks[i] = np.mean(compile_at_ks[i]) * 100 55 | 56 | header, row = [], [] 57 | for i, k in enumerate(K): 58 | header.append(f"pass@{k}") 59 | row.append("{:.1f}".format(pass_at_ks[i])) 60 | print(tabulate([row], headers=header, stralign="right", tablefmt="orgtbl")) 61 | print() 62 | 63 | header, row = [], [] 64 | for i, k in enumerate(K): 65 | header.append(f"compile@{k}") 66 | row.append("{:.1f}".format(compile_at_ks[i])) 67 | print(tabulate([row], headers=header, stralign="right", tablefmt="orgtbl")) 68 | -------------------------------------------------------------------------------- /experiments/translation/humaneval-x/process_input.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gzip 3 | import argparse 4 | import datasets 5 | 6 | 7 | def get_args(): 8 | parser = argparse.ArgumentParser() 9 | args = parser.parse_args() 10 | return args 11 | 12 | 13 | if __name__ == "__main__": 14 | args = get_args() 15 | 16 | res_dataset = dict() 17 | task_ids = set() 18 | 19 | hf_dataset = datasets.load_dataset("nuprl/MultiPL-E", "main-ts")["test"] 20 | ts_dataset = dict() 21 | for d in hf_dataset: 22 | j = dict() 23 | task_id = d["name"][10:] 24 | task_id = int(task_id[: task_id.find("_")]) 25 | prompt = d["prompt"].strip().split("\n")[-1] 26 | task_ids.add(task_id) 27 | j["task_id"] = task_id 28 | j["prompt"] = prompt 29 | if task_id == 10: 30 | j["prompt"] = "function is_palindrome(string: string): boolean {" 31 | j["tests"] = d["tests"].replace("declare var require: any;", "").strip() 32 | ts_dataset[task_id] = j 33 | res_dataset["ts"] = ts_dataset 34 | 35 | for lang in ["python", "cpp", "go", "java", "js"]: 36 | with gzip.open(f"humaneval_{lang}.jsonl.gz") as f: 37 | lines = f.readlines() 38 | lang_dataset = dict() 39 | for line in lines: 40 | j = dict() 41 | d = json.loads(line) 42 | task_id = d["task_id"] 43 | task_id = int(task_id[task_id.find("/") + 1 :]) 44 | if task_id not in task_ids: 45 | continue 46 | j["task_id"] = task_id 47 | j["prompt"] = d["declaration"] + d["canonical_solution"] 48 | lang_dataset[task_id] = j 49 | res_dataset[lang] = lang_dataset 50 | 51 | with open("dataset.json", "w") as f: 52 | json.dump(res_dataset, f, indent=2) 53 | -------------------------------------------------------------------------------- /experiments/translation/humaneval-x/translate.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import json 4 | import torch 5 | import argparse 6 | from tqdm import tqdm 7 | from transformers import AutoTokenizer, AutoModelForCausalLM 8 | from typesafe_llm.sampling import sample_constrained 9 | 10 | 11 | def FULL_LANG(lang): 12 | if lang == "cpp": 13 | return "C++" 14 | elif lang == "go": 15 | return "Go" 16 | elif lang == "java": 17 | return "Java" 18 | elif lang == "python": 19 | return "Python" 20 | elif lang == "ts": 21 | return "TypeScript" 22 | else: 23 | assert False 24 | 25 | 26 | def SYSTEM_PROMPT(src_lang): 27 | full_lang = FULL_LANG(src_lang) 28 | return f"You are a helpful and expert programmer in {full_lang} and TypeScript. You will be given an input program in {full_lang} and your task is to translate this program into TypeScript. 
You may assume that the input program is correct and that the translation should be semantically equivalent. You will NOT return anything except for the program." 29 | 30 | 31 | def TRANSLATION_PROMPT(src_lang, src_prog): 32 | full_lang = FULL_LANG(src_lang) 33 | src_prog = src_prog.strip() 34 | return f"The following is the source program in {full_lang}:\n```{src_lang}\n{src_prog}\n```\n\nPlease translate the source program to TypeScript." 35 | 36 | 37 | REGEX = "^```typescript\n([\s\S]*?)```$" 38 | 39 | 40 | def extract_tgt_code(output): 41 | res = re.search(r"^```typescript\n([\s\S]*?)```$", output, re.M) 42 | if res is not None: 43 | span = res.span() 44 | output = output[span[0] : span[1]] 45 | output = output[14:] 46 | output = output[:-3] 47 | return output.strip() + "\n" 48 | else: 49 | output = output[output.find("```typescript\n") + 14 :] 50 | return output.strip() + "\n" 51 | 52 | 53 | class HFChatModel: 54 | def __init__(self, args, **kwargs): 55 | self.args = args 56 | self.tokenizer = AutoTokenizer.from_pretrained(args.model_name) 57 | self.model = AutoModelForCausalLM.from_pretrained(args.model_name, **kwargs) 58 | self.model.eval() 59 | 60 | def translate(self, src_prog, signature): 61 | src_lang = self.args.src_lang 62 | system_prompt = SYSTEM_PROMPT(src_lang) 63 | translation_prompt = TRANSLATION_PROMPT(src_lang, src_prog) 64 | try: 65 | messages = [ 66 | {"role": "system", "content": SYSTEM_PROMPT(src_lang)}, 67 | {"role": "user", "content": TRANSLATION_PROMPT(src_lang, src_prog)}, 68 | ] 69 | prompt = self.tokenizer.apply_chat_template( 70 | messages, tokenize=False, add_generation_prompt=True 71 | ) 72 | except Exception: 73 | messages = [ 74 | {"role": "user", "content": system_prompt + "\n\n" + translation_prompt} 75 | ] 76 | prompt = self.tokenizer.apply_chat_template( 77 | messages, tokenize=False, add_generation_prompt=True 78 | ) 79 | prompt += "```typescript\n" 80 | prompt += signature 81 | 82 | # input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) 83 | # outputs = self.model.generate(**input_ids, max_new_tokens=self.args.max_tokens, do_sample=False) 84 | # output = self.tokenizer.decode(outputs[0]) 85 | # print(output) 86 | # return [extract_tgt_code(output)] 87 | 88 | with torch.no_grad(): 89 | outputs = sample_constrained( 90 | prompt=prompt, 91 | tokenizer=self.tokenizer, 92 | model=self.model, 93 | constrain_from=None if self.args.no_constrain else "```typescript\n", 94 | constrain_until="\n```", 95 | max_tokens=self.args.max_tokens, 96 | try_top_k=self.args.try_top_k, 97 | device="cuda", 98 | ) 99 | return [extract_tgt_code(outputs[0])] 100 | 101 | 102 | def create_model(args): 103 | if args.model_name == "gemma-2b-it": 104 | args.model_name = "google/gemma-2b-it" 105 | kwargs = { 106 | "device_map": "cuda", 107 | "torch_dtype": torch.float16, 108 | "attn_implementation": "flash_attention_2", 109 | } 110 | elif args.model_name == "phi-3.5-it": 111 | args.model_name = "microsoft/Phi-3.5-mini-instruct" 112 | kwargs = { 113 | "device_map": "cuda", 114 | "torch_dtype": torch.float16, 115 | "attn_implementation": "flash_attention_2", 116 | } 117 | model = HFChatModel(args, **kwargs) 118 | return model 119 | 120 | 121 | def get_args(): 122 | parser = argparse.ArgumentParser() 123 | 124 | parser.add_argument("--output_name", type=str, default=None) 125 | parser.add_argument("--model_name", type=str, default="phi-3.5-it") 126 | parser.add_argument("--no_constrain", default=False, action="store_true") 127 | parser.add_argument("--num_gen", 
type=int, default=1) 128 | parser.add_argument("--max_tokens", type=int, default=250) 129 | parser.add_argument("--try_top_k", type=int, default=100) 130 | parser.add_argument("--task_id", type=str, default=None) 131 | 132 | parser.add_argument("--src_lang", type=str, default="python") 133 | parser.add_argument("--data_path", type=str, default="dataset.json") 134 | parser.add_argument("--output_dir", type=str, default="output_translation") 135 | 136 | args = parser.parse_args() 137 | 138 | if args.output_name is not None: 139 | args.output_dir = os.path.join(args.output_dir, args.output_name) 140 | 141 | return args 142 | 143 | 144 | if __name__ == "__main__": 145 | args = get_args() 146 | model = create_model(args) 147 | 148 | with open(args.data_path) as f: 149 | dataset = json.load(f) 150 | src_dataset = dataset[args.src_lang] 151 | tgt_dataset = dataset["ts"] 152 | 153 | task_ids = [args.task_id] if args.task_id is not None else list(sorted(src_dataset)) 154 | 155 | if args.output_name is not None: 156 | os.makedirs(os.path.join(args.output_dir), exist_ok=False) 157 | 158 | for task_id in tqdm(task_ids): 159 | src_prog = src_dataset[task_id]["prompt"] 160 | tgt_progs = model.translate(src_prog, tgt_dataset[task_id]["prompt"]) 161 | 162 | if args.output_name is not None: 163 | os.makedirs(os.path.join(args.output_dir, str(task_id))) 164 | for i in range(args.num_gen): 165 | with open( 166 | os.path.join(args.output_dir, str(task_id), f"{i}.ts"), "w" 167 | ) as f: 168 | f.write(tgt_progs[i] + "\n\n" + tgt_dataset[task_id]["tests"]) 169 | -------------------------------------------------------------------------------- /experiments/translation/mbpp/generate.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import json 3 | 4 | ds = datasets.load_dataset("google-research-datasets/mbpp", "full") 5 | dataset = {} 6 | for split in ds: 7 | for instance in ds[split]: 8 | dataset[instance["task_id"]] = {"prompt": instance["code"]} 9 | json.dump({"python": dataset}, open("dataset.json", "w"), indent=2) 10 | -------------------------------------------------------------------------------- /incremental_tsc.py: -------------------------------------------------------------------------------- 1 | import fire 2 | 3 | from typesafe_llm.parser.parser_ts import parse_ts_program 4 | 5 | 6 | def main( 7 | input_file: str, 8 | ): 9 | with open(input_file, "r") as f: 10 | file_content = f.read() 11 | states = parse_ts_program(file_content, print_failure_point=True) 12 | print("-----------------------------") 13 | if states: 14 | print("Parsed successfully") 15 | else: 16 | print("Parsing failed at given point") 17 | 18 | 19 | if __name__ == "__main__": 20 | fire.Fire(main) 21 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "dependencies": { 3 | "typescript": "^5.7.2", 4 | "undici-types": "^7.0.0" 5 | }, 6 | "devDependencies": { 7 | "@types/node": "^20.0.0" 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "typesafe-llm" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Anonymous"] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.11" 10 | transformers = {git = 
"https://github.com/huggingface/transformers.git", rev="4286586"} 11 | fire = "^0.7.0" 12 | interegular = "^0.3.3" 13 | datasets = "^2.20.0" 14 | langchain = "^0.2.11" 15 | langchain-openai = "^0.1.19" 16 | pytest-timeout = "^2.3.1" 17 | stopit = "^1.1.2" 18 | accelerate = "^0.34.2" 19 | # flash-attn = "^2.6.3" 20 | frozendict = "^2.4.6" 21 | 22 | [tool.poetry.group.dev.dependencies] 23 | hypothesis = "^6.75.6" 24 | parameterized = "^0.9.0" 25 | black = "^23.3.0" 26 | pre-commit = "^3.3.2" 27 | coverage = "<7.0" 28 | pytest = "^7.3.1" 29 | coveralls = "^3.3.1" 30 | poetry-bumpversion = "^0.3.0" 31 | tabulate = "^0.9.0" 32 | 33 | [build-system] 34 | requires = ["poetry-core"] 35 | build-backend = "poetry.core.masonry.api" 36 | -------------------------------------------------------------------------------- /setup_conda.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | # set up miniconda 3 | # Determine system architecture 4 | ARCH=$(uname -m) 5 | 6 | # Set URL based on architecture 7 | if [ "$ARCH" == "x86_64" ]; then 8 | URL="https://repo.anaconda.com/miniconda/Miniconda3-py39_24.9.2-0-Linux-x86_64.sh" 9 | else 10 | echo "Unsupported architecture: $ARCH" 11 | exit 1 12 | fi 13 | wget $URL -O miniconda.sh 14 | rm -r ~/.miniconda || echo "miniconda did not exist" 15 | bash miniconda.sh -b -p ~/.miniconda 16 | ~/.miniconda/bin/conda init bash 17 | 18 | # set up nvm 19 | curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.1/install.sh | bash 20 | 21 | echo "Restart the shell to complete the installation" 22 | -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | conda update -n base -c defaults conda -y 3 | 4 | # set up torch 5 | conda install python=3.11 -y 6 | conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia -y 7 | 8 | # install flash-attention 9 | pip install flash-attn==2.7.3 --no-build-isolation 10 | 11 | # set up huggingface 12 | pip install -U "huggingface_hub[cli]" 13 | huggingface-cli login --add-to-git-credential 14 | 15 | # set up npm/tsc 16 | export NVM_DIR=$HOME/.nvm; 17 | source $NVM_DIR/nvm.sh; 18 | nvm install 20.16.0 19 | nvm use 20.16.0 20 | npm install typescript -g 21 | 22 | # set up typesafe_llm 23 | cd "$(dirname "${BASH_SOURCE[0]}")" 24 | pip install -e . 25 | cd ts_parser 26 | bash install_rust.sh 27 | . "$HOME/.cargo/env" 28 | bash build.sh 29 | cd .. 
30 | echo "Restart the shell again to complete the installation" -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eth-sri/type-constrained-code-generation/0cba132d35d9277d0538292d1bdbabb88a5a447a/test/__init__.py -------------------------------------------------------------------------------- /test/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eth-sri/type-constrained-code-generation/0cba132d35d9277d0538292d1bdbabb88a5a447a/test/data/__init__.py -------------------------------------------------------------------------------- /test/data/manually_fixed/HumanEval_1.ts: -------------------------------------------------------------------------------- 1 | function separateParenGroups(parenString: string): string[] { 2 | /** 3 | * Input to this function is a string containing multiple groups of nested parentheses. Your goal is to 4 | * separate those group into separate strings and return the list of those. 5 | * Separate groups are balanced (each open brace is properly closed) and not nested within each other 6 | * Ignore any spaces in the input string. 7 | * 8 | * Example: 9 | * console.log(separateParenGroups('( ) (( )) (( )( ))')) 10 | * // Output: ['()', '(())', '(()())'] 11 | */ 12 | 13 | const result: string[] = []; 14 | let currentString: string[] = []; 15 | let currentDepth = 0; 16 | 17 | for (const c of parenString) { 18 | if (c === '(') { 19 | currentDepth += 1; 20 | currentString.push(c); 21 | } else { 22 | // CHANGE: fix nesting 23 | if (c === ')') { 24 | currentDepth -= 1; 25 | currentString.push(c); 26 | 27 | if (currentDepth === 0) { 28 | result.push(currentString.join('')); 29 | currentString = []; 30 | } 31 | } 32 | } 33 | } 34 | 35 | return result; 36 | } -------------------------------------------------------------------------------- /test/data/manually_fixed/HumanEval_10.ts: -------------------------------------------------------------------------------- 1 | function isPalindrome(str: string): boolean { 2 | /** Test if given string is a palindrome */ 3 | return str === str.split('').reverse().join(''); 4 | } 5 | 6 | function makePalindrome(str: string): string { 7 | /** Find the shortest palindrome that begins with a supplied string. 8 | Algorithm idea is simple: 9 | - Find the longest postfix of supplied string that is a palindrome. 10 | - Append to the end of the string reverse of a string prefix that comes before the palindromic suffix. 11 | */ 12 | 13 | if (!str) { 14 | return ''; 15 | } 16 | 17 | let beginningOfSuffix = 0; 18 | 19 | while (!isPalindrome(str.substring(beginningOfSuffix))) { 20 | beginningOfSuffix += 1; 21 | } 22 | 23 | // CHANGE: remove console.log 24 | return str + str.substring(0, beginningOfSuffix).split('').reverse().join(''); 25 | } -------------------------------------------------------------------------------- /test/data/manually_fixed/HumanEval_6.ts: -------------------------------------------------------------------------------- 1 | function parseNestedParens(parenString: string): number[] { 2 | /** 3 | * Input to this function is a string represented multiple groups for nested parentheses separated by spaces. 4 | * For each of the group, output the deepest level of nesting of parentheses. 5 | * E.g. (()()) has maximum two levels of nesting while ((())) has three. 
6 | * 7 | * >>> parseNestedParens('(()()) ((())) () ((())()())') 8 | * [2, 3, 1, 3] 9 | */ 10 | 11 | function parseParenGroup(s: string): number { 12 | let depth = 0; 13 | let maxDepth = 0; 14 | for (const c of s) { 15 | if (c === '(') { 16 | depth += 1; 17 | maxDepth = Math.max(depth, maxDepth); 18 | } else { 19 | depth -= 1; 20 | } 21 | } 22 | return maxDepth; 23 | } 24 | // CHANGE: rewrite map expression 25 | let result: number[] = []; 26 | for (const group of parenString.split(' ', 100000)) { 27 | if (group) { 28 | result.push(parseParenGroup(group)); 29 | } 30 | } 31 | return result; 32 | } 33 | -------------------------------------------------------------------------------- /test/data/manually_fixed/HumanEval_7.ts: -------------------------------------------------------------------------------- 1 | function filterBySubstring(strings: string[], substring: string): string[] { 2 | /** 3 | * Filter an input list of strings only for ones that contain given substring 4 | * 5 | * Example usage: 6 | * filterBySubstring([], 'a'); // [] 7 | * filterBySubstring(['abc', 'bacd', 'cde', 'array'], 'a'); // ['abc', 'bacd', 'array'] 8 | */ 9 | // CHANGE: add parens to filter 10 | return strings.filter((x) => x.includes(substring)); 11 | } 12 | -------------------------------------------------------------------------------- /test/data/manually_fixed/HumanEval_8.ts: -------------------------------------------------------------------------------- 1 | // CHANGE: replace tuple type with list 2 | function sumProduct(numbers: number[]): number[] { 3 | /** 4 | * For a given list of integers, return a tuple consisting of a sum and a product of all the integers in a list. 5 | * Empty sum should be equal to 0 and empty product should be equal to 1. 6 | * sumProduct([]) => [0, 1] 7 | * sumProduct([1, 2, 3, 4]) => [10, 24] 8 | */ 9 | 10 | let sumValue = 0; 11 | let prodValue = 1; 12 | 13 | for (const n of numbers) { 14 | sumValue += n; 15 | prodValue *= n; 16 | } 17 | return [sumValue, prodValue]; 18 | } 19 | -------------------------------------------------------------------------------- /test/data/manually_fixed/HumanEval_9.ts: -------------------------------------------------------------------------------- 1 | function rollingMax(numbers: number[]): number[] { 2 | /** 3 | * From a given list of integers, generate a list of rolling maximum element found until given moment 4 | * in the sequence. 
5 | * @example 6 | * rollingMax([1, 2, 3, 2, 3, 4, 2]) 7 | * // returns [1, 2, 3, 3, 3, 4, 4] 8 | */ 9 | 10 | // CHANGE: replace union number|null with number, introduce started variable 11 | let started = false; 12 | let runningMax: number = 0; 13 | let result: number[] = []; 14 | 15 | for (let n of numbers) { 16 | if (!started) { 17 | runningMax = n; 18 | started = true; 19 | } else { 20 | runningMax = Math.max(runningMax, n); 21 | } 22 | 23 | result.push(runningMax); 24 | } 25 | 26 | return result; 27 | } 28 | -------------------------------------------------------------------------------- /test/data/print_instance.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pathlib 3 | import sys 4 | 5 | multiple_e_instance_file = ( 6 | pathlib.Path(__file__).parent 7 | / "openai_openai_humaneval_ts_gpt-4o-2024-05-13_filtered.jsonl" 8 | ) 9 | instance_num = int(sys.argv[1]) 10 | 11 | with open(multiple_e_instance_file, "r") as f: 12 | for i, line in enumerate(f): 13 | if i == int(instance_num): 14 | print(json.loads(line)["translation"]) 15 | break 16 | -------------------------------------------------------------------------------- /test/test_parser_base.py: -------------------------------------------------------------------------------- 1 | from typesafe_llm.parser.parser_base import ( 2 | TerminalParserState, 3 | incremental_parse, 4 | IncrementalParserState, 5 | ) 6 | from test.utils import assert_weak_full, assert_reject, assert_strict_partial 7 | 8 | initial_ws_terminal = TerminalParserState(target_value="hello world\t !") 9 | initial_ws_terminal_state = IncrementalParserState([initial_ws_terminal], "") 10 | 11 | 12 | def test_accept_initial_char_terminal(): 13 | states = initial_ws_terminal.parse_char("h") 14 | assert_strict_partial(states) 15 | 16 | 17 | def test_dont_accept_ws_initial_char_terminal(): 18 | states = initial_ws_terminal.parse_char(" ") 19 | assert_reject(states) 20 | 21 | 22 | def test_accept_ws_optional_one(): 23 | states = incremental_parse(initial_ws_terminal_state, "hello ") 24 | assert_strict_partial(states) 25 | 26 | 27 | def test_accept_ws_optional_many(): 28 | states = incremental_parse(initial_ws_terminal_state, "hello \t\n w") 29 | assert_strict_partial(states) 30 | 31 | 32 | def test_accept_ws_optional_none(): 33 | states = incremental_parse(initial_ws_terminal_state, "hellow") 34 | assert_strict_partial(states) 35 | 36 | 37 | def test_accept_ws_required_one(): 38 | states = incremental_parse(initial_ws_terminal_state, "hello world !") 39 | assert_weak_full(states) 40 | 41 | 42 | def test_accept_ws_required_many(): 43 | states = incremental_parse(initial_ws_terminal_state, "hello world \t!") 44 | assert_weak_full(states) 45 | 46 | 47 | def test_accept_ws_required_none(): 48 | states = incremental_parse(initial_ws_terminal_state, "hello world!") 49 | assert_reject(states) 50 | -------------------------------------------------------------------------------- /test/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List 2 | 3 | from typesafe_llm.parser.parser_base import IncrementalParsingState 4 | 5 | 6 | def assert_partial(states): 7 | assert states, "States should not be empty" 8 | 9 | 10 | def assert_strict_partial(states): 11 | assert states, "States should not be empty" 12 | assert all(not state.accept for state in states), "State should not accept" 13 | 14 | 15 | def assert_weak_full(states): 16 | assert states, "States should 
not be empty" 17 | assert any(state.accept for state in states), "some State should accept" 18 | 19 | 20 | def assert_reject(states): 21 | assert not states, "States should be empty" 22 | 23 | 24 | def assert_strict_partial_or_reject(states): 25 | assert all(not state.accept for state in states), "State should not accept" 26 | 27 | 28 | def assert_just_before_reject_generic( 29 | parse_program: Callable[[str], List[IncrementalParsingState]], 30 | incremental_parse: Callable[ 31 | [List[IncrementalParsingState], str], List[IncrementalParsingState] 32 | ], 33 | program: str, 34 | ): 35 | program_before = program[:-1] 36 | states = parse_program(program_before) 37 | assert_partial(states) 38 | states = incremental_parse(states, program[-1]) 39 | assert_reject(states) 40 | -------------------------------------------------------------------------------- /ts_parser/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | .idea 3 | test.ts 4 | test.js 5 | -------------------------------------------------------------------------------- /ts_parser/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ts_parser" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | oxc_allocator = "0.33.0" 8 | oxc_parser = "0.33.0" 9 | oxc_span = "0.33.0" 10 | -------------------------------------------------------------------------------- /ts_parser/README.md: -------------------------------------------------------------------------------- 1 | # ts_parser 2 | 3 | A simple syntax checker for TypeScript. 4 | 5 | tsc offers no simple way to tell whether some code has syntax errors. This tool is a small CLI based on the oxc (Oxidation Compiler) toolkit; it exits with code 0 for syntactically valid code and 1 for syntactically invalid code. -------------------------------------------------------------------------------- /ts_parser/build.sh: -------------------------------------------------------------------------------- 1 | rustup update 2 | cargo build --release -------------------------------------------------------------------------------- /ts_parser/install_rust.sh: -------------------------------------------------------------------------------- 1 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y -------------------------------------------------------------------------------- /ts_parser/src/main.rs: -------------------------------------------------------------------------------- 1 | use oxc_allocator::Allocator; 2 | use oxc_parser::{Parser, ParserReturn}; 3 | use oxc_span::SourceType; 4 | use std::env; 5 | use std::fs; 6 | 7 | fn main() { 8 | // Get the file path from the command-line arguments 9 | let args: Vec<String> = env::args().collect(); 10 | if args.len() < 2 { 11 | eprintln!("Usage: {} <file_path>", args[0]); 12 | std::process::exit(1); 13 | } 14 | 15 | let file_path = &args[1]; 16 | 17 | // Read the file content 18 | let source_text = match fs::read_to_string(file_path) { 19 | Ok(content) => content, 20 | Err(err) => { 21 | eprintln!("Error reading file {}: {}", file_path, err); 22 | std::process::exit(1); 23 | } 24 | }; 25 | 26 | // Memory arena where AST nodes get stored 27 | let allocator = Allocator::default(); 28 | // Infer the source type based on the file extension 29 | let source_type = match SourceType::from_path(file_path) { 30 | Ok(st) => st, 31 | Err(x) => { 32 | eprintln!("Could not determine source type from file extension.
{x}"); 33 | std::process::exit(1); 34 | } 35 | }; 36 | 37 | let ParserReturn { 38 | errors, // Syntax errors 39 | panicked, // Parser encountered an error it couldn't recover from 40 | .. 41 | } = Parser::new(&allocator, &source_text, source_type).parse(); 42 | 43 | if panicked || !errors.is_empty() { 44 | // Print errors or panic details 45 | if panicked { 46 | eprintln!("Parser panicked."); 47 | } 48 | 49 | if !errors.is_empty() { 50 | eprintln!("Parsing failed with the following errors:"); 51 | for error in errors { 52 | eprintln!("{:?}", error); 53 | } 54 | } 55 | std::process::exit(1); 56 | } 57 | 58 | println!("File parsed successfully."); 59 | std::process::exit(0); 60 | } 61 | -------------------------------------------------------------------------------- /typesafe_llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eth-sri/type-constrained-code-generation/0cba132d35d9277d0538292d1bdbabb88a5a447a/typesafe_llm/__init__.py -------------------------------------------------------------------------------- /typesafe_llm/parser/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eth-sri/type-constrained-code-generation/0cba132d35d9277d0538292d1bdbabb88a5a447a/typesafe_llm/parser/__init__.py -------------------------------------------------------------------------------- /typesafe_llm/parser/parser_shared.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field, replace 2 | from functools import partial 3 | from typing import List, Type, Self 4 | 5 | 6 | from typesafe_llm.parser.parser_base import ( 7 | ConcatParserState, 8 | IncrementalParsingState, 9 | TerminalParserState, 10 | MAX_COMMENT_LENGTH, 11 | UnionParserState, 12 | ) 13 | from typesafe_llm.parser.util import fnr_dataclass 14 | 15 | 16 | @fnr_dataclass 17 | class AnyStringParser(IncrementalParsingState): 18 | terminator_chars: List[str] = field(default_factory=lambda: []) 19 | accept: bool = False 20 | length: int = 0 21 | 22 | def parse_char(self, char: str) -> List[Self]: 23 | if self.accept: 24 | return [] 25 | if char in self.terminator_chars: 26 | return [replace(self, accept=True)] 27 | if self.length >= MAX_COMMENT_LENGTH: 28 | return [] 29 | return [replace(self, length=self.length + 1)] 30 | 31 | def num_active_states(self): 32 | return 1 33 | 34 | 35 | @fnr_dataclass 36 | class AnyStringParserNoSeq(IncrementalParsingState): 37 | terminator_seq: List[str] = field(default_factory=lambda: []) 38 | accepted: str = "" 39 | accept: bool = False 40 | 41 | @property 42 | def max_bs_len(self): 43 | return max(len(bs) for bs in self.terminator_seq) 44 | 45 | def parse_char(self, char: str) -> List[Self]: 46 | new_str = self.accepted + char 47 | if self.accept: 48 | return [] 49 | if any(new_str[-len(bs) :] == bs for bs in self.terminator_seq): 50 | return [replace(self, accept=True)] 51 | return [replace(self, accepted=new_str[-self.max_bs_len :])] 52 | 53 | def num_active_states(self): 54 | return 1 55 | 56 | 57 | @fnr_dataclass() 58 | class CommentParserState(ConcatParserState): 59 | # TODO: comments are allowed basically anywhere where whitespace is allowed -> need to add it in all parsers? 
60 | parse_classes: List[Type[IncrementalParsingState]] = field( 61 | default_factory=lambda: ( 62 | partial(TerminalParserState, target_value=" //"), 63 | partial(AnyStringParser, terminator_chars=["\n"]), 64 | ), 65 | repr=False, 66 | ) 67 | 68 | 69 | @fnr_dataclass() 70 | class MultilineCommentParserState(ConcatParserState): 71 | # TODO: comments are allowed basically anywhere where whitespace is allowed -> need to add it in all parsers? 72 | parse_classes: List[Type[IncrementalParsingState]] = field( 73 | default_factory=lambda: ( 74 | partial(TerminalParserState, target_value=" /*"), 75 | partial(AnyStringParserNoSeq, terminator_seq=["*/"]), 76 | ), 77 | repr=False, 78 | ) 79 | 80 | 81 | @fnr_dataclass() 82 | class EOLParserState(UnionParserState): 83 | parse_classes: List[Type[IncrementalParsingState]] = field( 84 | default_factory=lambda: [ 85 | partial(TerminalParserState, target_value=" ;"), 86 | ] 87 | ) 88 | 89 | 90 | @fnr_dataclass() 91 | class BreakStmtParserState(TerminalParserState): 92 | target_value: str = " break ;" 93 | 94 | def parse_char(self, char: str) -> List[Self]: 95 | if not self.in_loop: 96 | return [] 97 | return super().parse_char(char) 98 | 99 | 100 | @fnr_dataclass() 101 | class ContinueStmtParserState(TerminalParserState): 102 | target_value: str = " continue ;" 103 | 104 | def parse_char(self, char: str) -> List[Self]: 105 | if not self.in_loop: 106 | return [] 107 | return super().parse_char(char) 108 | 109 | 110 | EmptyStmtParserState = EOLParserState 111 | -------------------------------------------------------------------------------- /typesafe_llm/parser/types_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generic tooling for types in a type-safe parser 3 | """ 4 | 5 | from typing import ( 6 | Self, 7 | Tuple, 8 | Set, 9 | Dict, 10 | Literal, 11 | ) 12 | 13 | from .util import fnr_dataclass 14 | 15 | OperatorPrecedence = Tuple[int, Literal["left", "right"]] 16 | 17 | 18 | @fnr_dataclass() 19 | class PType: 20 | """ 21 | Base class for types as handled by the parser i.e. ParserTypes 22 | """ 23 | 24 | @property 25 | def attributes(self) -> Dict[str, Tuple[Self, bool]]: 26 | """ 27 | Attributes of the type 28 | """ 29 | raise NotImplementedError(f"abstract function in {self.__class__.__name__}") 30 | 31 | @property 32 | def nesting_depth(self) -> Tuple[int, int]: 33 | """ 34 | Nesting depth of the type (array nesting, function nesting) 35 | """ 36 | raise NotImplementedError(f"abstract function in {self.__class__.__name__}") 37 | 38 | @property 39 | def root_values(self) -> Set[Self]: 40 | raise NotImplementedError(f"abstract function in {self.__class__.__name__}") 41 | 42 | def __ge__(self, other): 43 | """ 44 | Assignment compatibility. 
45 | self >= other iff value of type other can be assigned to a variable of type self 46 | :param other: 47 | :return: 48 | """ 49 | raise NotImplementedError() 50 | 51 | def __repr__(self): 52 | return f"{self.__class__.__name__}()" 53 | 54 | def __eq__(self, other): 55 | raise NotImplementedError() 56 | 57 | def __hash__(self): 58 | raise NotImplementedError() 59 | 60 | def instantiate_type_params(self, type_params: Dict[str, Self]) -> Self: 61 | """ 62 | Instantiate type parameters in the type 63 | """ 64 | raise NotImplementedError() 65 | 66 | def type_params(self) -> Set[str]: 67 | """ 68 | Check if the type has type parameters 69 | """ 70 | return set() 71 | 72 | 73 | @fnr_dataclass 74 | class AnyPType(PType): 75 | """ 76 | Any type 77 | """ 78 | 79 | @property 80 | def attributes(self) -> Dict[str, Tuple[Self, bool]]: 81 | return dict() 82 | 83 | def __ge__(self, other): 84 | return True 85 | 86 | @property 87 | def nesting_depth(self) -> Tuple[int, int]: 88 | return (0, 0) 89 | 90 | def reachable_types(self): 91 | return [self] 92 | 93 | @property 94 | def root_values(self) -> Set[Self]: 95 | return {self} 96 | 97 | def __eq__(self, other): 98 | return isinstance(other, AnyPType) 99 | 100 | def __hash__(self): 101 | return hash(self.__class__) 102 | 103 | def __repr__(self): 104 | return "AnyPType()" 105 | 106 | def __str__(self): 107 | return "unknown" 108 | 109 | def instantiate_type_params(self, type_params: Dict[str, Self]) -> Self: 110 | return self 111 | -------------------------------------------------------------------------------- /typesafe_llm/parser/util.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | 4 | sum_list = partial(sum, start=[]) 5 | fnr_dataclass = partial(dataclass, frozen=True, repr=False) 6 | 7 | ALPHABET = set(chr(i) for i in range(ord("a"), ord("z") + 1)) 8 | ALPHABET.update(chr(i) for i in range(ord("A"), ord("Z") + 1)) 9 | DIGITS = set(str(i) for i in range(10)) 10 | WHITESPACE = {" ", "\t", "\n", "\r"} 11 | NON_BREAKING_WHITESPACE = {" ", "\t"} 12 | 13 | 14 | def union_dict(d1, d2, *args): 15 | """Return the combination of two or more dictionaries""" 16 | if args: 17 | return union_dict(d1, union_dict(d2, *args)) 18 | return {**d1, **d2} 19 | 20 | 21 | def intersection_dict(d1, d2, *args): 22 | """Return the intersection of two or more dictionaries""" 23 | if args: 24 | return intersection_dict(d1, intersection_dict(d2, *args)) 25 | return {k: v for k, v in d1.items() if k in d2 and d1[k] == d2[k]} 26 | 27 | 28 | def update_keys(d1, d2): 29 | """Update the entries of d1 with the value of d2 iff d2 contains that value""" 30 | return {k: d2.get(k, v) for k, v in d1.items()} 31 | -------------------------------------------------------------------------------- /typesafe_llm/trie.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2022 Kanishk Gandhi 5 | Copyright (c) 2024 Anonymized (adapted) 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above 
copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | """ 25 | 26 | #!/usr/bin/env python3 27 | 28 | import collections 29 | from typing import Iterable 30 | 31 | from typesafe_llm.util import pflush 32 | 33 | 34 | # Trie representation of a vocabulary. 35 | class Trie: 36 | def __init__(self, value=None, enforce_token_maximality=True): 37 | self._children = collections.defaultdict( 38 | lambda: Trie(enforce_token_maximality=enforce_token_maximality) 39 | ) 40 | self._value = [value] if value is not None else [] 41 | self._enforce_token_maximality = enforce_token_maximality 42 | 43 | def insert(self, key, value, depth=0): 44 | if len(key) == depth: 45 | self._value.append(value) 46 | else: 47 | self._children[key[depth]].insert(key, value, depth + 1) 48 | 49 | @staticmethod 50 | def from_vocabulary(vocab: Iterable[str], enforce_token_maximality: bool = True): 51 | t = Trie(enforce_token_maximality=enforce_token_maximality) 52 | 53 | for i, token in enumerate(vocab): 54 | if token: 55 | t.insert(token, i) 56 | 57 | return t 58 | 59 | def antimonotonic_filter( 60 | self, parse_fn, states, key="", _pflush=pflush 61 | ) -> list[tuple[str, list]]: 62 | # NOTE only works when all keys are unique! 63 | # key_d = json.dumps(key)[1:-1] 64 | # _pflush(key_d) 65 | this_node_valid = parse_fn(states, key) if key else states 66 | 67 | if not this_node_valid: 68 | # Prune using anti-monotonicity: no children will be valid. 69 | # delete(key_d, _pflush) 70 | return [] 71 | 72 | children_values = [] 73 | 74 | for k, c in self._children.items(): 75 | children_values += c.antimonotonic_filter( 76 | parse_fn, this_node_valid, k, _pflush=_pflush 77 | ) 78 | 79 | this_value = [(v, this_node_valid) for v in self._value] 80 | # delete(key_d, _pflush) 81 | 82 | if self._enforce_token_maximality: 83 | # Only return maximal strings. 84 | if len(children_values) or not self._value: 85 | return children_values 86 | return this_value 87 | 88 | return this_value + children_values 89 | -------------------------------------------------------------------------------- /typesafe_llm/util.py: -------------------------------------------------------------------------------- 1 | def pflush(string: str): 2 | print(string, flush=True, end="") 3 | 4 | 5 | def delete(string: str, _pflush=pflush): 6 | _pflush(len(string) * "\b" + len(string) * " " + len(string) * "\b") 7 | --------------------------------------------------------------------------------
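Illustration (not a file from the repository): the incremental, type-constrained parser that incremental_tsc.py and the tests above exercise can also be called directly from Python. The sketch below assumes that parse_ts_program accepts a plain source string, that print_failure_point can simply be set to False, and that it returns the surviving parser states, an empty (falsy) result meaning the input or prefix can no longer be completed into a well-typed TypeScript program; the toy programs and the expected outcomes in the comments are illustrative assumptions, not part of the test suite.

from typesafe_llm.parser.parser_ts import parse_ts_program

# A complete, well-typed program: parser states should survive (and one should accept).
complete = "function add(a: number, b: number): number { return a + b; }"
# A prefix of a well-typed program: states should survive, though none accept yet.
prefix = "function add(a: number, b: number): number { return a +"
# An ill-typed program (a number expression returned where string is declared):
# the type-constrained parser should leave no surviving states.
ill_typed = "function add(a: number, b: number): string { return a + b; }"

for label, program in [("complete", complete), ("prefix", prefix), ("ill-typed", ill_typed)]:
    states = parse_ts_program(program, print_failure_point=False)
    print(label, "->", "still completable" if states else "rejected")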