├── .gitattributes ├── .github └── workflows │ ├── codeql.yml │ └── dependency-review.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── dataset └── ranker_datasets.zip ├── requirements.txt └── src ├── compute_metrics.py ├── eval.sh ├── finetune.sh ├── reindent.py ├── run_seq_classification.py └── run_seq_classification_and_line_prediction.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.gzip filter=lfs diff=lfs merge=lfs -text 2 | *.zip filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "main" ] 20 | schedule: 21 | - cron: '24 12 * * 3' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | 52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 65 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 
66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v2 73 | with: 74 | category: "/language:${{matrix.language}}" 75 | -------------------------------------------------------------------------------- /.github/workflows/dependency-review.yml: -------------------------------------------------------------------------------- 1 | # Dependency Review Action 2 | # 3 | # This Action will scan dependency manifest files that change as part of a Pull Request, surfacing known-vulnerable versions of the packages declared or updated in the PR. Once installed, if the workflow run is marked as required, PRs introducing known-vulnerable packages will be blocked from merging. 4 | # 5 | # Source repository: https://github.com/actions/dependency-review-action 6 | # Public documentation: https://docs.github.com/en/code-security/supply-chain-security/understanding-your-software-supply-chain/about-dependency-review#dependency-review-enforcement 7 | name: 'Dependency Review' 8 | on: [pull_request] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | dependency-review: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: 'Checkout Repository' 18 | uses: actions/checkout@v3 19 | - name: 'Dependency Review' 20 | uses: actions/dependency-review-action@v2 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Aa][Rr][Mm]/ 27 | [Aa][Rr][Mm]64/ 28 | bld/ 29 | [Bb]in/ 30 | [Oo]bj/ 31 | [Ll]og/ 32 | [Ll]ogs/ 33 | 34 | # Visual Studio 2015/2017 cache/options directory 35 | .vs/ 36 | # Uncomment if you have tasks that create the project's static files in wwwroot 37 | #wwwroot/ 38 | 39 | # Visual Studio 2017 auto generated files 40 | Generated\ Files/ 41 | 42 | # MSTest test Results 43 | [Tt]est[Rr]esult*/ 44 | [Bb]uild[Ll]og.* 45 | 46 | # NUnit 47 | *.VisualState.xml 48 | TestResult.xml 49 | nunit-*.xml 50 | 51 | # Build Results of an ATL Project 52 | [Dd]ebugPS/ 53 | [Rr]eleasePS/ 54 | dlldata.c 55 | 56 | # Benchmark Results 57 | BenchmarkDotNet.Artifacts/ 58 | 59 | # .NET Core 60 | project.lock.json 61 | project.fragment.lock.json 62 | artifacts/ 63 | 64 | # StyleCop 65 | StyleCopReport.xml 66 | 67 | # Files built by Visual Studio 68 | *_i.c 69 | *_p.c 70 | *_h.h 71 | *.ilk 72 | *.meta 73 | *.obj 74 | *.iobj 75 | *.pch 76 | *.pdb 77 | *.ipdb 78 | *.pgc 79 | *.pgd 80 | *.rsp 81 | *.sbr 82 | *.tlb 83 | *.tli 84 | *.tlh 85 | *.tmp 86 | *.tmp_proj 87 | *_wpftmp.csproj 88 | *.log 89 | *.vspscc 90 | *.vssscc 91 | .builds 92 | *.pidb 93 | *.svclog 94 | *.scc 95 | 96 | # Chutzpah Test files 97 | _Chutzpah* 98 | 99 | # Visual C++ cache files 100 | ipch/ 101 | *.aps 102 | *.ncb 103 | *.opendb 104 | *.opensdf 105 | *.sdf 106 | *.cachefile 107 | *.VC.db 108 | *.VC.VC.opendb 
109 | 110 | # Visual Studio profiler 111 | *.psess 112 | *.vsp 113 | *.vspx 114 | *.sap 115 | 116 | # Visual Studio Trace Files 117 | *.e2e 118 | 119 | # TFS 2012 Local Workspace 120 | $tf/ 121 | 122 | # Guidance Automation Toolkit 123 | *.gpState 124 | 125 | # ReSharper is a .NET coding add-in 126 | _ReSharper*/ 127 | *.[Rr]e[Ss]harper 128 | *.DotSettings.user 129 | 130 | # TeamCity is a build add-in 131 | _TeamCity* 132 | 133 | # DotCover is a Code Coverage Tool 134 | *.dotCover 135 | 136 | # AxoCover is a Code Coverage Tool 137 | .axoCover/* 138 | !.axoCover/settings.json 139 | 140 | # Visual Studio code coverage results 141 | *.coverage 142 | *.coveragexml 143 | 144 | # NCrunch 145 | _NCrunch_* 146 | .*crunch*.local.xml 147 | nCrunchTemp_* 148 | 149 | # MightyMoose 150 | *.mm.* 151 | AutoTest.Net/ 152 | 153 | # Web workbench (sass) 154 | .sass-cache/ 155 | 156 | # Installshield output folder 157 | [Ee]xpress/ 158 | 159 | # DocProject is a documentation generator add-in 160 | DocProject/buildhelp/ 161 | DocProject/Help/*.HxT 162 | DocProject/Help/*.HxC 163 | DocProject/Help/*.hhc 164 | DocProject/Help/*.hhk 165 | DocProject/Help/*.hhp 166 | DocProject/Help/Html2 167 | DocProject/Help/html 168 | 169 | # Click-Once directory 170 | publish/ 171 | 172 | # Publish Web Output 173 | *.[Pp]ublish.xml 174 | *.azurePubxml 175 | # Note: Comment the next line if you want to checkin your web deploy settings, 176 | # but database connection strings (with potential passwords) will be unencrypted 177 | *.pubxml 178 | *.publishproj 179 | 180 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 181 | # checkin your Azure Web App publish settings, but sensitive information contained 182 | # in these scripts will be unencrypted 183 | PublishScripts/ 184 | 185 | # NuGet Packages 186 | *.nupkg 187 | # NuGet Symbol Packages 188 | *.snupkg 189 | # The packages folder can be ignored because of Package Restore 190 | **/[Pp]ackages/* 191 | # except build/, which is used as an MSBuild target. 192 | !**/[Pp]ackages/build/ 193 | # Uncomment if necessary however generally it will be regenerated when needed 194 | #!**/[Pp]ackages/repositories.config 195 | # NuGet v3's project.json files produces more ignorable files 196 | *.nuget.props 197 | *.nuget.targets 198 | 199 | # Microsoft Azure Build Output 200 | csx/ 201 | *.build.csdef 202 | 203 | # Microsoft Azure Emulator 204 | ecf/ 205 | rcf/ 206 | 207 | # Windows Store app package directories and files 208 | AppPackages/ 209 | BundleArtifacts/ 210 | Package.StoreAssociation.xml 211 | _pkginfo.txt 212 | *.appx 213 | *.appxbundle 214 | *.appxupload 215 | 216 | # Visual Studio cache files 217 | # files ending in .cache can be ignored 218 | *.[Cc]ache 219 | # but keep track of directories ending in .cache 220 | !?*.[Cc]ache/ 221 | 222 | # Others 223 | ClientBin/ 224 | ~$* 225 | *~ 226 | *.dbmdl 227 | *.dbproj.schemaview 228 | *.jfm 229 | *.pfx 230 | *.publishsettings 231 | orleans.codegen.cs 232 | 233 | # Including strong name files can present a security risk 234 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 235 | #*.snk 236 | 237 | # Since there are multiple workflows, uncomment next line to ignore bower_components 238 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 239 | #bower_components/ 240 | 241 | # RIA/Silverlight projects 242 | Generated_Code/ 243 | 244 | # Backup & report files from converting an old project file 245 | # to a newer Visual Studio version. 
Backup files are not needed, 246 | # because we have git ;-) 247 | _UpgradeReport_Files/ 248 | Backup*/ 249 | UpgradeLog*.XML 250 | UpgradeLog*.htm 251 | ServiceFabricBackup/ 252 | *.rptproj.bak 253 | 254 | # SQL Server files 255 | *.mdf 256 | *.ldf 257 | *.ndf 258 | 259 | # Business Intelligence projects 260 | *.rdl.data 261 | *.bim.layout 262 | *.bim_*.settings 263 | *.rptproj.rsuser 264 | *- [Bb]ackup.rdl 265 | *- [Bb]ackup ([0-9]).rdl 266 | *- [Bb]ackup ([0-9][0-9]).rdl 267 | 268 | # Microsoft Fakes 269 | FakesAssemblies/ 270 | 271 | # GhostDoc plugin setting file 272 | *.GhostDoc.xml 273 | 274 | # Node.js Tools for Visual Studio 275 | .ntvs_analysis.dat 276 | node_modules/ 277 | 278 | # Visual Studio 6 build log 279 | *.plg 280 | 281 | # Visual Studio 6 workspace options file 282 | *.opt 283 | 284 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 285 | *.vbw 286 | 287 | # Visual Studio LightSwitch build output 288 | **/*.HTMLClient/GeneratedArtifacts 289 | **/*.DesktopClient/GeneratedArtifacts 290 | **/*.DesktopClient/ModelManifest.xml 291 | **/*.Server/GeneratedArtifacts 292 | **/*.Server/ModelManifest.xml 293 | _Pvt_Extensions 294 | 295 | # Paket dependency manager 296 | .paket/paket.exe 297 | paket-files/ 298 | 299 | # FAKE - F# Make 300 | .fake/ 301 | 302 | # CodeRush personal settings 303 | .cr/personal 304 | 305 | # Python Tools for Visual Studio (PTVS) 306 | __pycache__/ 307 | *.pyc 308 | 309 | # Cake - Uncomment if you are using it 310 | # tools/** 311 | # !tools/packages.config 312 | 313 | # Tabs Studio 314 | *.tss 315 | 316 | # Telerik's JustMock configuration file 317 | *.jmconfig 318 | 319 | # BizTalk build output 320 | *.btp.cs 321 | *.btm.cs 322 | *.odx.cs 323 | *.xsd.cs 324 | 325 | # OpenCover UI analysis results 326 | OpenCover/ 327 | 328 | # Azure Stream Analytics local run output 329 | ASALocalRun/ 330 | 331 | # MSBuild Binary and Structured Log 332 | *.binlog 333 | 334 | # NVidia Nsight GPU debugger configuration file 335 | *.nvuser 336 | 337 | # MFractors (Xamarin productivity tool) working folder 338 | .mfractor/ 339 | 340 | # Local History for Visual Studio 341 | .localhistory/ 342 | 343 | # BeatPulse healthcheck temp database 344 | healthchecksdb 345 | 346 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 347 | MigrationBackup/ 348 | 349 | # Ionide (cross platform F# VS Code tools) working folder 350 | .ionide/ 351 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fault-aware Neural Code Rankers 2 | 3 | > This repo contains the code accompanying the CodeRanker NeurIPS'22 [paper](https://arxiv.org/pdf/2206.03865.pdf). 4 | 5 | ## Installation 6 | ``` 7 | pip install -r requirements.txt 8 | ``` 9 | 10 | ## Dataset 11 | The ranker datasets are available via Git LFS in the `dataset/` directory. 12 | 13 | 14 | ## Usage 15 | First, `cd` into the `src` directory. 16 | 17 | 1. Training rankers 18 | ``` 19 | bash finetune.sh DATADIR MODELDIR CACHEDIR TASK 20 | ``` 21 | where DATADIR is the location of the directory containing the desired ranker dataset, 22 | MODELDIR is the location of the output directory where the trained model will be stored, 23 | CACHEDIR is the location of the cache directory used both for caching the model and for caching the dataset, 24 | and TASK is one of binary, ternary, intent_error, execution_error, execution_error_with_line. 25 | 26 | 2. Inference with rankers 27 | ``` 28 | bash eval.sh DATADIR MODELDIR CACHEDIR TASK {val|test}.json PREDICT_FILENAME 29 | ``` 30 | where PREDICT_FILENAME is the name of the file inside MODELDIR where the predicted logits will be stored. 31 | 32 | 3. Computing the ranked metrics 33 | ``` 34 | python3 compute_metrics.py --data_file=DATADIR/{val|test}.json --logits_prediction_file=MODELDIR/PREDICT_FILENAME --labels_file=DATADIR/labels_TASK.txt --task=TASK 35 | ``` 36 | 37 | 38 | ## Contributing 39 | 40 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 41 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 42 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 43 | 44 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 45 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 46 | provided by the bot. You will only need to do this once across all repos using our CLA. 47 | 48 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
49 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 50 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 51 | 52 | ## Trademarks 53 | 54 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 55 | trademarks or logos is subject to and must follow 56 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 57 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 58 | Any use of third-party trademarks or logos are subject to those third-party's policies. 59 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 
32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /dataset/ranker_datasets.zip: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a50ed0835b29d036b971187de7d29921ceca8e401565e35e083875876433d9da 3 | size 2024872843 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch>=1.7 3 | transformers>=4 4 | datasets 5 | scikit-learn 6 | scipy 7 | -------------------------------------------------------------------------------- /src/compute_metrics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | from xmlrpc.client import TRANSPORT_ERROR 4 | from tqdm import tqdm 5 | import scipy 6 | from scipy import special 7 | import re 8 | import numpy as np 9 | import pdb 10 | from sklearn.metrics import precision_recall_fscore_support 11 | import os 12 | 13 | 14 | def get_data(logits_prediction_file, data_file, labels_file, prompt_key='prompt'): 15 | # outfile only contains logits 16 | # datafile contains all data and they should be in the same order 17 | all_logits = [] 18 | with open(logits_prediction_file, 'r') as f: 19 | for line in tqdm(f): 20 | # skip header 21 | if line.startswith('index'): 22 | continue 23 | line = line.strip() 24 | logits = line.split('\t')[1] 25 | logits = eval(logits) 26 | all_logits.append(logits) 27 | print(len(all_logits)) 28 | 29 | with open(data_file, 'r') as f: 30 | data_dict = {} 31 | for i, line in tqdm(enumerate(f)): 32 | data = json.loads(line) 33 | data["logits"] = all_logits[i] 34 | 35 | key = data[prompt_key] 36 | if 
key not in data_dict: 37 | data_dict[key] = [] 38 | data_dict[key].append(data) 39 | assert(i == len(all_logits) -1 ) 40 | print(len(data_dict)) 41 | 42 | with open(labels_file, 'r') as f: 43 | labels = [] 44 | for line in tqdm(f): 45 | line = line.strip() 46 | labels.append(line) 47 | 48 | return data_dict, labels 49 | 50 | def process(data_dict, labels, label_key, pass_label="Correct"): 51 | # convert data_dict to list of data 52 | data = list(data_dict.values()) 53 | pass_idx = labels.index(pass_label) 54 | 55 | probs = [[scipy.special.softmax(d['logits'], axis=0)[pass_idx] for d in data[i]] for i in range(len(data))] 56 | 57 | grouped_labels = [[d['ternary_label'] for d in data[i]] for i in range(len(data))] 58 | 59 | # make all rows the same length 60 | max_len = max([len(d) for d in data]) 61 | probs = np.array([d + [-1] * (max_len - len(d)) for d in probs] ) # num_prompts x max_num_suggestions 62 | grouped_labels = np.array([d + ["error"] * (max_len - len(d)) for d in grouped_labels]) # num_prompts x max_num_suggestions 63 | 64 | results = {} 65 | # compute vanillar metrics 66 | res = compute_vanilla_metrics(grouped_labels) 67 | 68 | # compute ranked accuracies 69 | print("Renked metrics") 70 | res = compute_metrics(probs, grouped_labels) 71 | 72 | def pass_at_k(n, c, k): 73 | """ 74 | :param n: total number of samples 75 | :param c: number of correct samples 76 | :param k: k in pass@$k$ 77 | """ 78 | if n - c < k: return 1.0 79 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) 80 | 81 | def compute_vanilla_metrics(grouped_labels): 82 | pass_1 = [] 83 | exec_1 = [] 84 | pass_5 = [] 85 | exec_5 = [] 86 | pass_100 = [] 87 | exec_100 = [] 88 | for i in range(len(grouped_labels)): 89 | labels = grouped_labels[i] 90 | num_correct = 0 91 | num_error_free = 0 92 | for j in range(len(labels)): 93 | if labels[j] == "Correct": 94 | num_correct += 1 95 | if labels[j] != "Execution error": 96 | num_error_free += 1 97 | pass_1.append(pass_at_k(len(labels), num_correct, 1)) 98 | exec_1.append(pass_at_k(len(labels), num_error_free, 1)) 99 | pass_5.append(pass_at_k(len(labels), num_correct, 5)) 100 | exec_5.append(pass_at_k(len(labels), num_error_free, 5)) 101 | 102 | pass_100.append(pass_at_k(len(labels), num_correct, len(labels))) 103 | exec_100.append(pass_at_k(len(labels), num_error_free, len(labels))) 104 | 105 | print("Vanilla model") 106 | res = { 107 | "pass_1": np.mean(pass_1), 108 | "pass_5": np.mean(pass_5), 109 | "exec_1": np.mean(exec_1), 110 | "exec_5": np.mean(exec_5), 111 | "pass_100": np.mean(pass_100), 112 | "exec_100": np.mean(exec_100), 113 | } 114 | print(res) 115 | 116 | return res 117 | 118 | 119 | def compute_metrics(probs, grouped_labels): 120 | # top-1 accuracy 121 | best = np.argmax(probs, axis=1) 122 | predictions = grouped_labels[np.arange(len(grouped_labels)), best] 123 | correct = predictions == "Correct" 124 | error_free = predictions != "Execution error" 125 | 126 | ranker_accuracy = np.mean(correct) 127 | ranker_accuracy_error_free = np.mean(error_free) 128 | 129 | # top-5 accuracy 130 | best = np.argsort(probs, axis=1)[:, -5:] 131 | predictions_top_5 = [] 132 | for i in range(len(best)): 133 | predictions_top_5.append([grouped_labels[i, j] for j in best[i]]) 134 | predictions_top_5 = np.array(predictions_top_5) 135 | correct_top_5 = predictions_top_5 == "Correct" 136 | error_free_top_5 = predictions_top_5 != "Execution error" 137 | 138 | ranker_accuracy_top_5 = np.mean(correct_top_5) 139 | ranker_accuracy_error_free_top_5 = np.mean(error_free_top_5) 
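# --- Editor's aside: illustrative sketch only; not part of the original compute_metrics.py. ---
# The surrounding block picks, for each prompt, the completions the ranker scores highest and
# then looks up their labels. A minimal self-contained example of the same argmax/argsort
# pattern (the probabilities and the "Wrong answer" label string below are made up):
import numpy as np
toy_probs = np.array([[0.2, 0.9, 0.1],                # ranker P("Correct") for 3 completions of prompt 0
                      [0.6, 0.3, 0.8]])               # ... and for 3 completions of prompt 1
toy_labels = np.array([["Execution error", "Correct", "Wrong answer"],
                       ["Correct", "Wrong answer", "Execution error"]])
best = np.argmax(toy_probs, axis=1)                   # top-1 index per prompt -> [1, 2]
top1 = toy_labels[np.arange(len(best)), best]         # -> ["Correct", "Execution error"]
print("ranked pass@1:", (top1 == "Correct").mean())   # 0.5
top2 = np.argsort(toy_probs, axis=1)[:, -2:]          # indices of the 2 highest-scored completions per prompt
print("ranked pass@2:", np.mean([(toy_labels[i, top2[i]] == "Correct").any() for i in range(len(top2))]))  # 1.0
# (For reference, the unbiased pass_at_k estimator defined earlier reduces to c/n when k == 1,
#  e.g. pass_at_k(3, 1, 1) gives 1/3.)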
140 | 141 | # top-5 best accuracy 142 | best_accuracy_top_5 = np.max(correct_top_5, axis=1).mean() 143 | best_accuracy_error_free_top_5 = np.max(error_free_top_5, axis=1).mean() 144 | 145 | res = { 146 | "pass_1": round(ranker_accuracy, 3), 147 | "pass_5": round(best_accuracy_top_5, 3), 148 | "exec_1": round(ranker_accuracy_error_free, 3), 149 | "exec_5": round(best_accuracy_error_free_top_5, 3), 150 | } 151 | 152 | print(res) 153 | return res 154 | 155 | 156 | if __name__ == "__main__": 157 | import argparse 158 | 159 | parser = argparse.ArgumentParser(description="Compute metrics with ranker") 160 | parser.add_argument("--data_file", type=str, default="../ranker_datasets/gpt_neo_125m/val.json") 161 | parser.add_argument("--logits_prediction_file", type=str, default="~/ranker_model_for_gpt_neo_125m/predict_results_test.txt") 162 | parser.add_argument("--labels_file", type = str, default = "../ranker_datasets/gpt_neo_125m/labels_binary.txt") 163 | parser.add_argument("--task", type = str, default = "binary") 164 | 165 | args = parser.parse_args() 166 | 167 | label_key = args.task + "_label" 168 | 169 | data_dict, labels = get_data(args.logits_prediction_file, args.data_file, args.labels_file) 170 | process(data_dict, labels, label_key) 171 | -------------------------------------------------------------------------------- /src/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | MODEL=microsoft/codebert-base 3 | 4 | if [ "$#" -eq 2 ]; then 5 | DATASET=$1 6 | TASK=$2 7 | # default args to run locally 8 | DATA_DIR=../ranker_datasets/$DATASET/ 9 | MODEL_CACHE_DIR=~/ranker_model_for_${DATASET}_cache 10 | MODEL_DIR=~/ranker_model_for_$DATASET 11 | TEST_FILE_SUFFIX=val.json 12 | LABELS_SUFFIX=labels_$TASK.txt 13 | LABEL_KEY=${TASK}_label 14 | PREDICT_FILE_SUFFIX=test 15 | else 16 | # arguments 17 | DATA_DIR=$1 18 | MODEL_DIR=$2 19 | MODEL_CACHE_DIR=$3 20 | TASK=$4 21 | TEST_FILE=$5 22 | PREDICT_FILE_SUFFIX=$6 23 | LABELS_SUFFIX=labels_$TASK.txt 24 | LABEL_KEY=${TASK}_label 25 | fi 26 | 27 | if [ $TASK != "execution_error_with_line" ]; then 28 | python run_seq_classification.py \ 29 | --output_dir $MODEL_DIR \ 30 | --cache_dir $MODEL_CACHE_DIR \ 31 | --model_name_or_path $MODEL \ 32 | --test_file $DATA_DIR/$TEST_FILE \ 33 | --sentence1_key prompt \ 34 | --sentence2_key completion \ 35 | --label_key $LABEL_KEY \ 36 | --labels_file $DATA_DIR/$LABELS_SUFFIX \ 37 | --max_seq_length 512 \ 38 | --do_predict \ 39 | --per_device_eval_batch_size 32 \ 40 | --predict_suffix $PREDICT_FILE_SUFFIX \ 41 | --overwrite_cache \ 42 | 43 | else 44 | LABELS_SUFFIX=labels_execution_error.txt 45 | LABEL_KEY=execution_error_label 46 | 47 | python3 run_seq_classification_and_line_prediction.py \ 48 | --output_dir $MODEL_DIR \ 49 | --cache_dir $MODEL_CACHE_DIR \ 50 | --model_name_or_path $MODEL \ 51 | --test_file $DATA_DIR/$TEST_FILE \ 52 | --sentence1_key prompt \ 53 | --sentence2_key completion \ 54 | --label_key $LABEL_KEY \ 55 | --labels_file $DATA_DIR/$LABELS_SUFFIX \ 56 | --max_seq_length 512 \ 57 | --do_predict \ 58 | --per_device_eval_batch_size 32 \ 59 | --predict_suffix $PREDICT_FILE_SUFFIX \ 60 | --overwrite_cache \ 61 | 62 | fi -------------------------------------------------------------------------------- /src/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | MODEL=microsoft/codebert-base 3 | 4 | if [ "$#" -eq 2 ]; then 5 | DATASET=$1 6 | TASK=$2 7 | # default args to run locally 8 | 
DATA_DIR=ranker_datasets/$DATASET/ 9 | MODEL_CACHE_DIR=~/ranker_model_for_${DATASET}_cache 10 | MODEL_DIR=~/ranker_model_for_$DATASET 11 | TRAIN_FILE_SUFFIX=train.json 12 | VAL_FILE_SUFFIX=val.json 13 | LABELS_SUFFIX=labels_$TASK.txt 14 | WEIGHTS_SUFFIX=weights_$TASK.txt 15 | LABEL_KEY=${TASK}_label 16 | else 17 | # arguments 18 | DATA_DIR=$1 19 | MODEL_DIR=$2 20 | MODEL_CACHE_DIR=$3 21 | TASK=$4 22 | TRAIN_FILE_SUFFIX=train.json 23 | VAL_FILE_SUFFIX=val.json 24 | LABELS_SUFFIX=labels_$TASK.txt 25 | WEIGHTS_SUFFIX=weights_$TASK.txt 26 | LABEL_KEY=${TASK}_label 27 | 28 | fi 29 | 30 | if [ $TASK != "execution_error_with_line" ]; then 31 | python3 run_seq_classification.py \ 32 | --output_dir $MODEL_DIR \ 33 | --cache_dir $MODEL_CACHE_DIR \ 34 | --model_name_or_path $MODEL \ 35 | --train_file $DATA_DIR/$TRAIN_FILE_SUFFIX \ 36 | --validation_file $DATA_DIR/$VAL_FILE_SUFFIX \ 37 | --sentence1_key prompt \ 38 | --sentence2_key completion \ 39 | --label_key $LABEL_KEY \ 40 | --labels_file $DATA_DIR/$LABELS_SUFFIX \ 41 | --weights_file $DATA_DIR/$WEIGHTS_SUFFIX \ 42 | --grouped_indices_file $DATA_DIR/val_grouped_indices.npy \ 43 | --grouped_labels_file $DATA_DIR/val_grouped_labels.npy \ 44 | --max_seq_length 512 \ 45 | --do_train \ 46 | --do_eval \ 47 | --per_device_train_batch_size 16 \ 48 | --per_device_eval_batch_size 32 \ 49 | --learning_rate 1e-4 \ 50 | --warmup_steps 1000 \ 51 | --weight_decay 0.01 \ 52 | --gradient_accumulation_steps 32 \ 53 | --num_train_epochs 30 \ 54 | --evaluation_strategy steps \ 55 | --save_strategy steps \ 56 | --logging_steps 10 \ 57 | --load_best_model_at_end \ 58 | --metric_for_best_model top1_accuracy \ 59 | --logging_first_step \ 60 | --eval_steps 10 \ 61 | --save_steps 10 62 | 63 | else 64 | LABELS_SUFFIX=labels_execution_error.txt 65 | WEIGHTS_SUFFIX=weights_execution_error.txt 66 | LABEL_KEY=execution_error_label 67 | 68 | python3 run_seq_classification_and_line_prediction.py \ 69 | --output_dir $MODEL_DIR \ 70 | --cache_dir $MODEL_CACHE_DIR \ 71 | --model_name_or_path $MODEL \ 72 | --train_file $DATA_DIR/$TRAIN_FILE_SUFFIX \ 73 | --validation_file $DATA_DIR/$VAL_FILE_SUFFIX \ 74 | --sentence1_key prompt \ 75 | --sentence2_key completion \ 76 | --label_key $LABEL_KEY \ 77 | --labels_file $DATA_DIR/$LABELS_SUFFIX \ 78 | --weights_file $DATA_DIR/$WEIGHTS_SUFFIX \ 79 | --grouped_indices_file $DATA_DIR/val_grouped_indices.npy \ 80 | --grouped_labels_file $DATA_DIR/val_grouped_labels.npy \ 81 | --max_seq_length 512 \ 82 | --do_train \ 83 | --do_eval \ 84 | --per_device_train_batch_size 16 \ 85 | --per_device_eval_batch_size 32 \ 86 | --learning_rate 1e-4 \ 87 | --warmup_steps 1000 \ 88 | --weight_decay 0.01 \ 89 | --gradient_accumulation_steps 32 \ 90 | --num_train_epochs 30 \ 91 | --evaluation_strategy steps \ 92 | --save_strategy steps \ 93 | --logging_steps 10 \ 94 | --load_best_model_at_end \ 95 | --metric_for_best_model top1_accuracy \ 96 | --logging_first_step \ 97 | --eval_steps 10 \ 98 | --save_steps 10 \ 99 | --overwrite_output_dir \ 100 | 101 | fi 102 | -------------------------------------------------------------------------------- /src/reindent.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Reindent files. 
4 | """ 5 | 6 | from __future__ import print_function 7 | import sys 8 | import getopt 9 | import codecs 10 | import tempfile 11 | import shutil 12 | import os 13 | 14 | 15 | def _find_indentation(line, config): 16 | if len(line) and line[0] in (" ", "\t") and not line.isspace(): 17 | if line[0] == "\t": 18 | config['is-tabs'] = True 19 | # Find indentation 20 | i = 0 21 | for char in list(line): 22 | if char not in (" ", "\t"): 23 | break 24 | i += 1 25 | config["from"] = i 26 | 27 | 28 | def find_indentation(line, config): 29 | # Find indentation level used in file 30 | if config['from'] < 0: 31 | _find_indentation(line, config) 32 | 33 | if config['from'] >= 0: 34 | # Set old indent 35 | indent = " " if not config['is-tabs'] else "\t" 36 | indent = indent * config['from'] 37 | 38 | # Set new indent 39 | newindent = " " if not config['tabs'] else "\t" 40 | if not config['tabs']: 41 | newindent = newindent * config['to'] 42 | 43 | return indent, newindent 44 | 45 | # Continue to the next line, indentation not found 46 | return False 47 | 48 | 49 | def replace_inline_tabs(content, config): 50 | newcontent = "" 51 | imagined_i = 0 52 | for i in range(0, len(content)): 53 | char = content[i] 54 | if char == '\t': 55 | spaces = config['tabsize']-(imagined_i % config['tabsize']) 56 | newcontent += " " * spaces 57 | imagined_i += spaces 58 | else: 59 | newcontent += char 60 | imagined_i += 1 61 | return newcontent 62 | 63 | 64 | def run(fd_in, fd_out, config): 65 | while True: 66 | line = fd_in.readline() 67 | if not line: 68 | break 69 | line = line.rstrip('\r\n') 70 | 71 | # Find indentation style used in file if not set 72 | if config['from'] < 0: 73 | indent = find_indentation(line, config) 74 | if not indent: 75 | print(line, file=fd_out) 76 | continue 77 | indent, newindent = indent 78 | 79 | # Find current indentation level 80 | level = 0 81 | while True: 82 | whitespace = line[:len(indent) * (level + 1)] 83 | if whitespace == indent * (level + 1): 84 | level += 1 85 | else: 86 | break 87 | 88 | content = line[len(indent) * level:] 89 | if config['all-tabs']: 90 | content = replace_inline_tabs(content, config) 91 | 92 | line = (newindent * level) + content 93 | print(line, file=fd_out) 94 | 95 | -------------------------------------------------------------------------------- /src/run_seq_classification.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | from datasets import load_dataset 3 | 4 | import transformers 5 | from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, RobertaForSequenceClassification 6 | 7 | import numpy as np 8 | 9 | from transformers import TrainingArguments, Trainer, HfArgumentParser, default_data_collator, set_seed 10 | from transformers.trainer_utils import get_last_checkpoint 11 | 12 | from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, roc_auc_score 13 | from transformers import EvalPrediction, WEIGHTS_NAME 14 | import torch 15 | import torch.nn as nn 16 | 17 | import argparse 18 | import os 19 | import sys 20 | import logging 21 | 22 | from dataclasses import dataclass, field 23 | from typing import Optional 24 | import random 25 | import glob 26 | import json 27 | from reindent import run as run_reindent 28 | import io 29 | import pdb 30 | import scipy 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | @dataclass 35 | class DataTrainingArguments: 36 | """ 37 | Arguments pertaining to what data we are going to input our model for training and eval. 
38 | Using `HfArgumentParser` we can turn this class 39 | into argparse arguments to be able to specify them on 40 | the command line. 41 | """ 42 | 43 | max_seq_length: int = field( 44 | default=128, 45 | metadata={ 46 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 47 | "than this will be truncated, sequences shorter will be padded." 48 | }, 49 | ) 50 | overwrite_cache: bool = field( 51 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 52 | ) 53 | pad_to_max_length: bool = field( 54 | default=True, 55 | metadata={ 56 | "help": "Whether to pad all samples to `max_seq_length`. " 57 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 58 | }, 59 | ) 60 | truncate_texta_from_first: bool = field( 61 | default=False, 62 | metadata={ 63 | "help": "Truncate texta from the first when truncating. " 64 | "If False, will truncate texta from the last" 65 | }, 66 | ) 67 | max_train_samples: Optional[int] = field( 68 | default=None, 69 | metadata={ 70 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 71 | "value if set." 72 | }, 73 | ) 74 | max_eval_samples: Optional[int] = field( 75 | default=None, 76 | metadata={ 77 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 78 | "value if set." 79 | }, 80 | ) 81 | max_predict_samples: Optional[int] = field( 82 | default=None, 83 | metadata={ 84 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 85 | "value if set." 86 | }, 87 | ) 88 | train_file: Optional[str] = field( 89 | default=None, metadata={"help": "A csv or a json file containing the training data."} 90 | ) 91 | validation_file: Optional[str] = field( 92 | default=None, metadata={"help": "A csv or a json file containing the validation data."} 93 | ) 94 | test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) 95 | labels_file: Optional[str] = field(default=None, metadata={"help": "A txt file containing the labels."}) 96 | weights_file: Optional[str] = field(default=None, metadata={"help": "A txt file containing the weights."}) 97 | grouped_indices_file: Optional[str] = field(default=None, metadata={"help": "A txt file containing the grouped indices."}) 98 | grouped_labels_file: Optional[str] = field(default=None, metadata={"help": "A txt file containing the grouped labels."}) 99 | predict_suffix: Optional[str] = field(default="", metadata={"help": "Suffix for predict file."}) 100 | sentence1_key: Optional[str] = field( 101 | default="prompt", 102 | metadata={"help": "Name of the key for sentence1 in the dataset" }, 103 | ) 104 | sentence2_key: Optional[str] = field( 105 | default="completion", 106 | metadata={"help": "Name of the key for sentence2 in the dataset" }, 107 | ) 108 | label_key: Optional[str] = field( 109 | default="binary_label", 110 | metadata={"help": "Name of the key for label in the dataset" }, 111 | ) 112 | 113 | def __post_init__(self): 114 | if (self.train_file is None and self.validation_file is None) and self.test_file is None: 115 | raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") 116 | 117 | 118 | 119 | @dataclass 120 | class ModelArguments: 121 | """ 122 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
123 | """ 124 | 125 | model_name_or_path: str = field( 126 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 127 | ) 128 | config_name: Optional[str] = field( 129 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 130 | ) 131 | tokenizer_name: Optional[str] = field( 132 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 133 | ) 134 | cache_dir: Optional[str] = field( 135 | default=None, 136 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 137 | ) 138 | use_fast_tokenizer: bool = field( 139 | default=True, 140 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 141 | ) 142 | model_revision: str = field( 143 | default="main", 144 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 145 | ) 146 | use_auth_token: bool = field( 147 | default=False, 148 | metadata={ 149 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 150 | "with private models)." 151 | }, 152 | ) 153 | 154 | 155 | def compute_metrics(p: EvalPrediction, compute_ranker_accuracy=False, grouped_indices=None, grouped_labels=None, pass_idx=1): 156 | # grouped_indices is a two-dimensional array where each row represents the indices of various datapoints in p that have 157 | # the same prompt 158 | # grouped_labels is the actual ternary labels of datapoints at the indices provided by grouped_indices 159 | 160 | pred_raw, labels = p 161 | if compute_ranker_accuracy: 162 | # we will also compute the actual accuracy of the ranker 163 | pred_softmax = scipy.special.softmax(pred_raw, axis=1) 164 | pred_prob = pred_softmax[:,pass_idx] # prob of is Correct as predicted by the ranker 165 | prob_fun = lambda x: pred_prob[x] if x >=0 else 0 166 | prob_fun = np.vectorize(prob_fun) 167 | grouped_prob = prob_fun(grouped_indices) # group prob predicted by task 168 | 169 | # top-1 accuracy 170 | best = np.argmax(grouped_prob, axis=1) # index of the best completion for each task as predicted by the ranker 171 | top1_label = grouped_labels[np.arange(len(grouped_labels)), best] # get the actual label of this best completion 172 | top1_accuracy = np.mean(top1_label == "Correct") 173 | top1_execution = 1.0 - np.mean(top1_label == "Execution error") 174 | 175 | 176 | pred = np.argmax(pred_raw, axis=1) 177 | pred_binary = pred == pass_idx 178 | labels_binary = labels == pass_idx 179 | accuracy = accuracy_score(y_true=labels_binary, y_pred=pred_binary) 180 | recall = recall_score(y_true=labels_binary, y_pred=pred_binary, average='micro') 181 | precision = precision_score(y_true=labels_binary, y_pred=pred_binary, average='micro') 182 | f1 = f1_score(y_true=labels_binary, y_pred=pred_binary, average='micro') 183 | 184 | # multi class predictions 185 | accuracy_mc = accuracy_score(y_true=labels, y_pred=pred) 186 | recall_mc = recall_score(y_true=labels, y_pred=pred, average='micro') 187 | precision_mc = precision_score(y_true=labels, y_pred=pred, average='micro') 188 | f1_mc = f1_score(y_true=labels, y_pred=pred, average='micro') 189 | 190 | metrics = {'f1': f1, 191 | 'precision': precision, 192 | 'recall': recall, 193 | 'accuracy': accuracy, 194 | 'f1_mc': f1_mc, 195 | 'precision_mc': precision_mc, 196 | 'recall_mc': recall_mc, 197 | 'accuracy_mc': accuracy_mc,} 198 | 199 | if compute_ranker_accuracy: 
200 | metrics['top1_accuracy'] = top1_accuracy 201 | metrics['top1_execution'] = top1_execution 202 | return metrics 203 | 204 | def reindent_code(codestr, replace_set=[]): 205 | """ 206 | Given code string, reindent it in the same way that the 207 | Github dataset was indented 208 | """ 209 | codestr = io.StringIO(codestr) 210 | ret = io.StringIO() 211 | 212 | run_reindent( 213 | codestr, 214 | ret, 215 | config = { 216 | "dry-run": False, 217 | "help": False, 218 | "to": 10, 219 | "from": -1, 220 | "tabs": True, 221 | "encoding": "utf-8", 222 | "is-tabs": False, 223 | "tabsize": 10, 224 | "all-tabs": False 225 | } 226 | ) 227 | 228 | out = ret.getvalue() 229 | for s in replace_set: 230 | out = out.replace(s, "") 231 | 232 | return out 233 | 234 | # A trainer for balancing dataset with class weights 235 | class CustomTrainer(Trainer): 236 | def set_weights(self, class_weights): 237 | self.class_weights = class_weights 238 | def compute_loss(self, model, inputs, return_outputs=False): 239 | labels = inputs.get("labels") 240 | # forward pass 241 | outputs = model(**inputs) 242 | logits = outputs.get("logits") 243 | # compute custom loss (labels with different weights) 244 | loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(self.class_weights).to(logits.device)) 245 | loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1)) 246 | return (loss, outputs) if return_outputs else loss 247 | 248 | def main(): 249 | print("Helllllllooooooooooooo") 250 | # See all possible arguments in src/transformers/training_args.py 251 | # or by passing the --help flag to this script. 252 | # We now keep distinct sets of args, for a cleaner separation of concerns. 253 | 254 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 255 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 256 | 257 | # Setup logging 258 | logging.basicConfig( 259 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 260 | datefmt="%m/%d/%Y %H:%M:%S", 261 | handlers=[logging.StreamHandler(sys.stdout)], 262 | ) 263 | 264 | log_level = training_args.get_process_log_level() 265 | logger.setLevel(log_level) 266 | #datasets.utils.logging.set_verbosity(log_level) 267 | #transformers.utils.logging.set_verbosity(log_level) 268 | transformers.utils.logging.enable_default_handler() 269 | transformers.utils.logging.enable_explicit_format() 270 | 271 | # Log on each process the small summary: 272 | logger.warning( 273 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 274 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 275 | ) 276 | logger.info(f"Training/evaluation parameters {training_args}") 277 | 278 | # Set seed before initializing model. 279 | set_seed(training_args.seed) 280 | 281 | # Get the datasets: you can provide your own CSV/JSON training and evaluation files (see below) 282 | # 283 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 284 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named 285 | # label if at least two columns are provided. 286 | # 287 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 288 | # single column. 
You can easily tweak this behavior (see below) 289 | # 290 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 291 | # download the dataset. 292 | 293 | # Loading a dataset from your local files. 294 | # CSV/JSON training and evaluation files are needed. 295 | data_files = {} 296 | if training_args.do_train and data_args.train_file != None: 297 | data_files["train"] = data_args.train_file 298 | if training_args.do_eval and data_args.validation_file != None: 299 | data_files["validation"] = data_args.validation_file 300 | if training_args.do_predict and data_args.test_file != None: 301 | data_files["test"] = data_args.test_file 302 | 303 | 304 | for key in data_files.keys(): 305 | logger.info(f"load a local file for {key}: {data_files[key]}") 306 | 307 | raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir) 308 | 309 | print(raw_datasets) 310 | # Labels 311 | if data_args.labels_file != None: 312 | # read labels from file 313 | with open(data_args.labels_file, 'r') as f: 314 | label_list = [line.strip() for line in f] 315 | num_labels = len(label_list) 316 | else: 317 | num_labels = 2 318 | label_list = [False, True] 319 | 320 | # grouped indices/labels for measuring actual accuracy of ranking models 321 | if data_args.grouped_indices_file != None: 322 | grouped_indices = np.load(data_args.grouped_indices_file) 323 | grouped_labels = np.load(data_args.grouped_labels_file) 324 | else: 325 | grouped_indices = None 326 | grouped_labels = None 327 | # get the index of "Correct" in the labels list 328 | pass_idx = label_list.index("Correct") 329 | 330 | print("label_list:", label_list) 331 | print("pass_idx:", pass_idx) 332 | 333 | 334 | if data_args.weights_file != None: 335 | # read weights from file 336 | with open(data_args.weights_file, 'r') as f: 337 | class_weights = [float(line.strip()) for line in f] 338 | else: 339 | class_weights = None 340 | 341 | # Load pretrained model and tokenizer 342 | tokenizer = RobertaTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer,) 343 | model = RobertaForSequenceClassification.from_pretrained(model_args.model_name_or_path, num_labels=num_labels, cache_dir=model_args.cache_dir) 344 | 345 | # Preprocessing the raw datasets 346 | sentence1_key = data_args.sentence1_key #"model_prompt" 347 | sentence2_key = data_args.sentence2_key #"model_completion" 348 | label_key = data_args.label_key #"binary_label" 349 | 350 | # Padding strategy 351 | if data_args.pad_to_max_length: 352 | padding = "max_length" 353 | else: 354 | padding = False 355 | 356 | label_to_id = {v: i for i, v in enumerate(label_list)} 357 | model.config.label2id = label_to_id 358 | 359 | if data_args.max_seq_length > tokenizer.model_max_length: 360 | logger.warning( 361 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 362 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 
363 | ) 364 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 365 | mask_padding_with_zero = True 366 | pad_token = 0 367 | pad_token_segment_id = 0 368 | pad_on_left = False 369 | def preprocess_function(examples): 370 | 371 | # Tokenize the texts 372 | text1 = [reindent_code(s, replace_set=[]) for s in examples[sentence1_key]] 373 | text2 = [reindent_code(s, replace_set=[]) for s in examples[sentence2_key]] 374 | 375 | def trunc(tokens_a, tokens_b, max_length, truncate_texta_from_first=False): 376 | """Truncates a sequence pair in place to the maximum length.""" 377 | # This is a simple heuristic which will always truncate the longer sequence 378 | # one token at a time. This makes more sense than truncating an equal percent 379 | # of tokens from each, since if one sequence is very short then each token 380 | # that's truncated likely contains more information than a longer sequence. 381 | while True: 382 | total_length = len(tokens_a) + len(tokens_b) 383 | if total_length <= max_length: 384 | break 385 | if len(tokens_a) > len(tokens_b): 386 | if truncate_texta_from_first: 387 | tokens_a.pop(0) 388 | else: 389 | tokens_a.pop() 390 | else: 391 | tokens_b.pop() 392 | def custom_tokenize(text1, text2): 393 | all_input_ids = [] 394 | all_attention_mask = [] 395 | all_token_type_ids = [] 396 | for i in range(len(text1)): 397 | tok_seq1 = tokenizer.tokenize(text1[i]) 398 | tok_seq2 = tokenizer.tokenize(text2[i]) 399 | 400 | trunc(tok_seq1, tok_seq2, max_seq_length - 3, truncate_texta_from_first=data_args.truncate_texta_from_first) # 3 is number of special tokens for bert sequence pair 401 | 402 | input_ids = [tokenizer.cls_token_id] 403 | input_ids += tokenizer.convert_tokens_to_ids(tok_seq1) 404 | input_ids += [tokenizer.sep_token_id] 405 | 406 | token_type_ids = [0]*len(input_ids) 407 | 408 | input_ids += tokenizer.convert_tokens_to_ids(tok_seq2) 409 | input_ids += [tokenizer.sep_token_id] 410 | token_type_ids += [1]*(len(tok_seq2)+1) 411 | 412 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 413 | # tokens are attended to. 414 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 415 | 416 | # Zero-pad up to the sequence length. 
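# --- Editor's aside: illustrative sketch only; not part of the original script. ---
# Each example is encoded as  [CLS] prompt-tokens [SEP] completion-tokens [SEP]  and then
# zero-padded (below) to max_seq_length; the "- 3" passed to trunc() reserves room for those
# three special tokens. trunc() drops one token at a time from the longer of the two sequences
# until the pair fits. A standalone toy check of that heuristic, using dummy "tokens":
def _trunc_demo(max_length=6):
    prompt, completion = list("AAAAAAAA"), list("BBB")     # 8 + 3 dummy tokens
    while len(prompt) + len(completion) > max_length:
        (prompt if len(prompt) > len(completion) else completion).pop()
    return prompt, completion                              # -> (['A', 'A', 'A'], ['B', 'B', 'B'])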
417 | padding_length = max_seq_length - len(input_ids) 418 | if pad_on_left: 419 | input_ids = ([tokenizer.pad_token] * padding_length) + input_ids 420 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask 421 | token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids 422 | else: 423 | input_ids = input_ids + ([pad_token] * padding_length) 424 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 425 | token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) 426 | 427 | all_input_ids.append(input_ids) 428 | all_attention_mask.append(attention_mask) 429 | all_token_type_ids.append(token_type_ids) 430 | 431 | result = {"input_ids": all_input_ids, "attention_mask": all_attention_mask} 432 | return result 433 | 434 | result = custom_tokenize(text1, text2) 435 | 436 | # Map labels to IDs (not necessary for GLUE wc -s) 437 | if label_to_id is not None and label_key in examples: 438 | result["label"] = [(label_to_id[l]) for l in examples[label_key]] 439 | return result 440 | 441 | 442 | if training_args.do_train: 443 | if "train" not in raw_datasets: 444 | raise ValueError("--do_train requires a train dataset") 445 | train_dataset = raw_datasets["train"] 446 | if data_args.max_train_samples is not None: 447 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 448 | train_dataset = train_dataset.map( 449 | preprocess_function, 450 | batched=True, 451 | load_from_cache_file=not data_args.overwrite_cache, 452 | desc="Running tokenizer on train dataset", 453 | num_proc = 20, 454 | ) 455 | if training_args.do_eval: 456 | if "validation" not in raw_datasets : 457 | raise ValueError("--do_eval requires a validation dataset") 458 | eval_dataset = raw_datasets["validation"] 459 | if data_args.max_eval_samples is not None: 460 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 461 | eval_dataset = eval_dataset.map( 462 | preprocess_function, 463 | batched=True, 464 | load_from_cache_file=not data_args.overwrite_cache, 465 | desc="Running tokenizer on eval dataset", 466 | num_proc = 20, 467 | ) 468 | 469 | 470 | if training_args.do_predict or data_args.test_file is not None: 471 | if "test" not in raw_datasets: 472 | raise ValueError("--do_predict requires a test dataset") 473 | predict_dataset = raw_datasets["test"] 474 | if data_args.max_predict_samples is not None: 475 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 476 | predict_dataset = predict_dataset.map( 477 | preprocess_function, 478 | batched=True, 479 | load_from_cache_file=not data_args.overwrite_cache, 480 | desc="Running tokenizer on dataset", 481 | num_proc = 20, 482 | ) 483 | 484 | # Log a few random samples from the training set: 485 | if training_args.do_train: 486 | for index in random.sample(range(len(train_dataset)), 3): 487 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 488 | input_ids = train_dataset[index]["input_ids"] 489 | text = tokenizer.decode(input_ids) 490 | print(text) 491 | 492 | # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. 
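# Editor's note: the fp16 branch below constructs DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8).
# That class is not among this script's imports; the branch taken when pad_to_max_length is disabled and
# fp16 is enabled would therefore need `from transformers import DataCollatorWithPadding` added at the top.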
493 | if data_args.pad_to_max_length: 494 | data_collator = default_data_collator 495 | elif training_args.fp16: 496 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 497 | else: 498 | data_collator = None 499 | 500 | # Initialize our Trainer 501 | if class_weights == None: 502 | trainer = Trainer( 503 | model=model, 504 | args=training_args, 505 | train_dataset=train_dataset if training_args.do_train else None, 506 | eval_dataset=eval_dataset if training_args.do_eval else None, 507 | compute_metrics=lambda x: compute_metrics(x, True, grouped_indices, grouped_labels, pass_idx), 508 | tokenizer=tokenizer, 509 | data_collator=data_collator, 510 | ) 511 | else: 512 | trainer = CustomTrainer( 513 | model=model, 514 | args=training_args, 515 | train_dataset=train_dataset if training_args.do_train else None, 516 | eval_dataset=eval_dataset if training_args.do_eval else None, 517 | compute_metrics=lambda x: compute_metrics(x, True, grouped_indices, grouped_labels, pass_idx), 518 | tokenizer=tokenizer, 519 | data_collator=data_collator, 520 | ) 521 | trainer.set_weights(class_weights) 522 | 523 | # Training 524 | if training_args.do_train: 525 | # Detecting last checkpoint. 526 | last_checkpoint = None 527 | if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir: 528 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 529 | if training_args.do_train and last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 530 | raise ValueError( 531 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 532 | "Use --overwrite_output_dir to overcome." 533 | ) 534 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 535 | logger.info( 536 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 537 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 538 | ) 539 | 540 | train_result = trainer.train(resume_from_checkpoint = last_checkpoint) 541 | metrics = train_result.metrics 542 | max_train_samples = ( 543 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 544 | ) 545 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 546 | 547 | trainer.save_model() # Saves the tokenizer too for easy upload 548 | 549 | trainer.log_metrics("train", metrics) 550 | trainer.save_metrics("train", metrics) 551 | trainer.save_state() 552 | 553 | # load model for eval and test 554 | # if output dir has a model, then load it 555 | if os.path.exists(os.path.join(training_args.output_dir, "pytorch_model.bin")): 556 | logger.info(f"Loading model from {os.path.join(training_args.output_dir, 'pytorch_model.bin')}") 557 | model = RobertaForSequenceClassification.from_pretrained(training_args.output_dir, num_labels=num_labels) 558 | else: 559 | # if last checkpoint exists and the output dir does not have a model, 560 | # then we can load the best model using the trainer state in last checkpoint 561 | # Detecting last checkpoint. 
562 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 563 | if last_checkpoint is not None: 564 | with open(os.path.join(last_checkpoint, "trainer_state.json"), "r") as f: 565 | trainer_state = json.load(f) 566 | if "best_model_checkpoint" in trainer_state: 567 | best_checkpoint = trainer_state['best_model_checkpoint'] 568 | # match prefix before /checkpoint of the checkpoint name with the output_dir 569 | prefix = best_checkpoint.split("/checkpoint")[0] 570 | substitute_prefix = training_args.output_dir.split("/checkpoint")[0] 571 | best_checkpoint = best_checkpoint.replace(prefix, substitute_prefix) 572 | logger.info(f"Loading model from {best_checkpoint}") 573 | model = RobertaForSequenceClassification.from_pretrained(best_checkpoint, num_labels=num_labels) 574 | else: 575 | logger.info("No model found. Using CodeBERT model") 576 | 577 | trainer = Trainer( 578 | model=model, 579 | args=training_args, 580 | train_dataset=train_dataset if training_args.do_train else None, 581 | eval_dataset=eval_dataset if training_args.do_eval else None, 582 | compute_metrics= lambda x: compute_metrics(x, False, grouped_indices, grouped_labels, pass_idx), 583 | tokenizer=tokenizer, 584 | data_collator=data_collator, 585 | ) 586 | # Evaluation 587 | if training_args.do_eval: 588 | logger.info("*** Evaluate ***") 589 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 590 | 591 | max_eval_samples = ( 592 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 593 | ) 594 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 595 | 596 | trainer.log_metrics("eval", metrics) 597 | trainer.save_metrics("eval", metrics) 598 | 599 | if training_args.do_predict: 600 | logger.info("*** Predict ***") 601 | pred_output = trainer.predict(predict_dataset, metric_key_prefix="predict") 602 | predictions = pred_output.predictions 603 | metrics = pred_output.metrics 604 | trainer.log_metrics("predict", metrics) 605 | trainer.save_metrics(data_args.predict_suffix, metrics) 606 | 607 | output_predict_file = os.path.join(training_args.output_dir, data_args.predict_suffix) 608 | if trainer.is_world_process_zero(): 609 | with open(output_predict_file, "w") as writer: 610 | logger.info(f"***** Predict results *****") 611 | writer.write("index\tprediction\n") 612 | for index, item in enumerate(predictions): 613 | item_str = "[" + ",".join([str(r) for r in item]) + "]" 614 | writer.write(f"{index}\t{item_str}\n") 615 | 616 | if __name__ == "__main__": 617 | main() -------------------------------------------------------------------------------- /src/run_seq_classification_and_line_prediction.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | from datasets import load_dataset 3 | 4 | import transformers 5 | from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, RobertaForSequenceClassification 6 | 7 | import numpy as np 8 | 9 | from transformers import TrainingArguments, Trainer, HfArgumentParser, default_data_collator, set_seed 10 | from transformers.trainer_utils import get_last_checkpoint 11 | from transformers.file_utils import ModelOutput 12 | 13 | from transformers.models.roberta.modeling_roberta import ( 14 | RobertaClassificationHead, 15 | RobertaPreTrainedModel, 16 | ) 17 | 18 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 19 | 20 | from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, roc_auc_score 21 | from transformers 
import EvalPrediction, WEIGHTS_NAME 22 | import torch 23 | import torch.nn as nn 24 | 25 | from typing import Optional, Tuple 26 | 27 | import argparse 28 | import os 29 | import sys 30 | import logging 31 | 32 | from dataclasses import dataclass, field 33 | from typing import Optional 34 | import random 35 | import glob 36 | import json 37 | from reindent import run as run_reindent 38 | import io 39 | import pdb 40 | import scipy 41 | 42 | logger = logging.getLogger(__name__) 43 | 44 | @dataclass 45 | class DataTrainingArguments: 46 | """ 47 | Arguments pertaining to what data we are going to input our model for training and eval. 48 | Using `HfArgumentParser` we can turn this class 49 | into argparse arguments to be able to specify them on 50 | the command line. 51 | """ 52 | 53 | max_seq_length: int = field( 54 | default=128, 55 | metadata={ 56 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 57 | "than this will be truncated, sequences shorter will be padded." 58 | }, 59 | ) 60 | overwrite_cache: bool = field( 61 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 62 | ) 63 | pad_to_max_length: bool = field( 64 | default=True, 65 | metadata={ 66 | "help": "Whether to pad all samples to `max_seq_length`. " 67 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 68 | }, 69 | ) 70 | truncate_texta_from_first: bool = field( 71 | default=False, 72 | metadata={ 73 | "help": "Truncate texta from the first when truncating. " 74 | "If False, will truncate texta from the last" 75 | }, 76 | ) 77 | max_train_samples: Optional[int] = field( 78 | default=None, 79 | metadata={ 80 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 81 | "value if set." 82 | }, 83 | ) 84 | max_eval_samples: Optional[int] = field( 85 | default=None, 86 | metadata={ 87 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 88 | "value if set." 89 | }, 90 | ) 91 | max_predict_samples: Optional[int] = field( 92 | default=None, 93 | metadata={ 94 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 95 | "value if set." 
96 | }, 97 | ) 98 | train_file: Optional[str] = field( 99 | default=None, metadata={"help": "A csv or a json file containing the training data."} 100 | ) 101 | validation_file: Optional[str] = field( 102 | default=None, metadata={"help": "A csv or a json file containing the validation data."} 103 | ) 104 | test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) 105 | labels_file: Optional[str] = field(default=None, metadata={"help": "A txt file containing the labels."}) 106 | weights_file: Optional[str] = field(default=None, metadata={"help": "A txt file containing the weights."}) 107 | grouped_indices_file: Optional[str] = field(default=None, metadata={"help": "A txt file containing the grouped indices."}) 108 | grouped_labels_file: Optional[str] = field(default=None, metadata={"help": "A txt file containing the grouped labels."}) 109 | predict_suffix: Optional[str] = field(default="", metadata={"help": "Suffix for predict file."}) 110 | sentence1_key: Optional[str] = field( 111 | default="prompt", 112 | metadata={"help": "Name of the key for sentence1 in the dataset" }, 113 | ) 114 | sentence2_key: Optional[str] = field( 115 | default="completion", 116 | metadata={"help": "Name of the key for sentence2 in the dataset" }, 117 | ) 118 | label_key: Optional[str] = field( 119 | default="passed", 120 | metadata={"help": "Name of the key for label in the dataset" }, 121 | ) 122 | 123 | def __post_init__(self): 124 | if (self.train_file is None and self.validation_file is None) and self.test_file is None: 125 | raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") 126 | 127 | 128 | 129 | @dataclass 130 | class ModelArguments: 131 | """ 132 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 133 | """ 134 | 135 | model_name_or_path: str = field( 136 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 137 | ) 138 | config_name: Optional[str] = field( 139 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 140 | ) 141 | tokenizer_name: Optional[str] = field( 142 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 143 | ) 144 | cache_dir: Optional[str] = field( 145 | default=None, 146 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 147 | ) 148 | use_fast_tokenizer: bool = field( 149 | default=True, 150 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 151 | ) 152 | model_revision: str = field( 153 | default="main", 154 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 155 | ) 156 | use_auth_token: bool = field( 157 | default=False, 158 | metadata={ 159 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 160 | "with private models)." 
161 | }, 162 | ) 163 | 164 | eval_all_checkpoints: bool = field( 165 | default = False, 166 | metadata = { 167 | 'help' : "Evaluate all checkpoints" 168 | }, 169 | ) 170 | 171 | def compute_metrics(p: EvalPrediction, num_labels, compute_ranker_accuracy=False, grouped_indices=None, grouped_labels=None, pass_idx=1): 172 | # grouped_indices is a two-dimensional array where each row represents the indices of various datapoints in p that have 173 | # the same prompt 174 | # grouped_labels is the actual ternary labels of datapoints at the indices provided by grouped_indices 175 | 176 | pred_raw, labels = p 177 | assert(len(pred_raw) == 2) 178 | if pred_raw[0].shape[1] == num_labels: 179 | classifier_pred_raw = pred_raw[0] 180 | line_pred_raw = pred_raw[1] 181 | else: 182 | classifier_pred_raw = pred_raw[1] 183 | line_pred_raw = pred_raw[0] 184 | 185 | assert(len(labels) == 2) 186 | if labels[0].max() >= num_labels: 187 | line_labels = labels[0] 188 | classifier_labels = labels[1] 189 | else: 190 | line_labels = labels[1] 191 | classifier_labels = labels[0] 192 | 193 | # compute classifier accuracy 194 | if compute_ranker_accuracy: 195 | # we will also compute the actual accuracy of the ranker 196 | pred_softmax = scipy.special.softmax(classifier_pred_raw, axis=1) 197 | pred_prob = pred_softmax[:,pass_idx] 198 | prob_fun = lambda x: pred_prob[x] if x >=0 else 0 199 | prob_fun = np.vectorize(prob_fun) 200 | grouped_prob = prob_fun(grouped_indices) 201 | 202 | # top-1 accuracy 203 | best = np.argmax(grouped_prob, axis=1) 204 | top1_label = grouped_labels[np.arange(len(grouped_labels)), best] 205 | top1_accuracy = np.mean(top1_label == "Correct") 206 | top1_execution = 1.0 - np.mean(top1_label == "Execution error") 207 | 208 | pred = np.argmax(classifier_pred_raw, axis=1) 209 | pred_binary = pred == pass_idx 210 | labels_binary = classifier_labels == pass_idx 211 | accuracy = accuracy_score(y_true=labels_binary, y_pred=pred_binary) 212 | recall = recall_score(y_true=labels_binary, y_pred=pred_binary, average='micro') 213 | precision = precision_score(y_true=labels_binary, y_pred=pred_binary, average='micro') 214 | f1 = f1_score(y_true=labels_binary, y_pred=pred_binary, average='micro') 215 | 216 | # multi class predictions 217 | accuracy_mc = accuracy_score(y_true=classifier_labels, y_pred=pred) 218 | recall_mc = recall_score(y_true=classifier_labels, y_pred=pred, average='micro') 219 | precision_mc = precision_score(y_true=classifier_labels, y_pred=pred, average='micro') 220 | f1_mc = f1_score(y_true=classifier_labels, y_pred=pred, average='micro') 221 | 222 | metrics = {'f1': f1, 223 | 'precision': precision, 224 | 'recall': recall, 225 | 'accuracy': accuracy, 226 | 'f1_mc': f1_mc, 227 | 'precision_mc': precision_mc, 228 | 'recall_mc': recall_mc, 229 | 'accuracy_mc': accuracy_mc,} 230 | if compute_ranker_accuracy: 231 | metrics['top1_accuracy'] = top1_accuracy 232 | metrics['top1_execution'] = top1_execution 233 | 234 | # compute the metrics for the line prediction 235 | pred = np.argmax(line_pred_raw, axis=1) 236 | lines_accuracy = accuracy_score(y_true=line_labels, y_pred=pred) 237 | lines_recall = recall_score(y_true=line_labels, y_pred=pred, average='micro') 238 | lines_precision = precision_score(y_true=line_labels, y_pred=pred, average='micro') 239 | lines_f1 = f1_score(y_true=line_labels, y_pred=pred, average='micro') 240 | metrics['lines_f1'] = lines_f1 241 | metrics['lines_precision'] = lines_precision 242 | metrics['lines_recall'] = lines_recall 243 | metrics['lines_accuracy'] = 
lines_accuracy 244 | 245 | return metrics 246 | 247 | def reindent_code(codestr): 248 | """ 249 | Given code string, reindent it in the same way that the 250 | Github dataset was indented 251 | """ 252 | codestr = io.StringIO(codestr) 253 | ret = io.StringIO() 254 | 255 | run_reindent( 256 | codestr, 257 | ret, 258 | config = { 259 | "dry-run": False, 260 | "help": False, 261 | "to": 10, 262 | "from": -1, 263 | "tabs": True, 264 | "encoding": "utf-8", 265 | "is-tabs": False, 266 | "tabsize": 10, 267 | "all-tabs": False 268 | } 269 | ) 270 | 271 | return ret.getvalue() 272 | 273 | @dataclass 274 | class FaultAwareModelOutput(ModelOutput): 275 | loss: Optional[torch.FloatTensor] = None 276 | line_loss: Optional[torch.FloatTensor] = None 277 | classifier_loss: Optional[torch.FloatTensor] = None 278 | line_logits: torch.FloatTensor = None 279 | classifier_logits: torch.FloatTensor = None 280 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 281 | attentions: Optional[Tuple[torch.FloatTensor]] = None 282 | 283 | _CHECKPOINT_FOR_DOC = "roberta-base" 284 | _CONFIG_FOR_DOC = "RobertaConfig" 285 | _TOKENIZER_FOR_DOC = "RobertaTokenizer" 286 | 287 | 288 | class RobertaForFaultAwareRanker(RobertaPreTrainedModel): 289 | _keys_to_ignore_on_load_missing = [r"position_ids"] 290 | 291 | def __init__(self, config): 292 | super().__init__(config) 293 | self.num_labels = config.num_labels 294 | self.config = config 295 | setattr(self.config, 'keys_to_ignore_at_inference', ['line_loss', 'classifier_loss']) 296 | 297 | self.roberta = RobertaModel(config, add_pooling_layer=False) 298 | self.line_outputs = nn.Linear(config.hidden_size, 1) 299 | self.classifier = RobertaClassificationHead(config) 300 | self.class_weights = None 301 | 302 | # Initialize weights and apply final processing 303 | #self.post_init() 304 | 305 | def set_class_weights(self, class_weights): 306 | self.class_weights = class_weights 307 | 308 | def forward( 309 | self, 310 | input_ids=None, 311 | attention_mask=None, 312 | token_type_ids=None, 313 | position_ids=None, 314 | head_mask=None, 315 | inputs_embeds=None, 316 | labels=None, 317 | line_masks = None, 318 | line_positions=None, 319 | output_attentions=None, 320 | output_hidden_states=None, 321 | return_dict=None, 322 | ): 323 | r""" 324 | line_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 325 | Labels for error lines (index) for computing the token classification loss. 326 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 327 | are not taken into account for computing the loss. 328 | line_masks (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 329 | line masks for which token to consider for computing the token classification loss. 330 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 331 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 332 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 333 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
334 | """ 335 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 336 | outputs = self.roberta( 337 | input_ids, 338 | attention_mask=attention_mask, 339 | token_type_ids=token_type_ids, 340 | position_ids=position_ids, 341 | head_mask=head_mask, 342 | inputs_embeds=inputs_embeds, 343 | output_attentions=output_attentions, 344 | output_hidden_states=output_hidden_states, 345 | return_dict=return_dict, 346 | ) 347 | 348 | sequence_output = outputs[0] 349 | 350 | classifier_logits = self.classifier(sequence_output) 351 | 352 | line_logits = self.line_outputs(sequence_output) 353 | line_logits = line_logits.squeeze(-1).contiguous() 354 | 355 | loss = 0 356 | 357 | line_loss = None 358 | if line_positions is not None: 359 | # If we are on multi-GPU, split add a dimension 360 | if len(line_positions.size()) > 1: 361 | line_positions = line_positions.squeeze(-1) 362 | # sometimes the line positions are outside our model inputs, we ignore these terms 363 | ignored_index = line_logits.size(1) 364 | line_positions = line_positions.clamp(0, ignored_index) 365 | 366 | # mask out logits at line_masks == 0 positions 367 | line_logits = line_logits.masked_fill(line_masks == 0, -1e12) 368 | 369 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 370 | line_loss = loss_fct(line_logits, line_positions) 371 | loss += line_loss 372 | 373 | if labels is not None: 374 | if self.config.problem_type is None: 375 | if self.num_labels == 1: 376 | self.config.problem_type = "regression" 377 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 378 | self.config.problem_type = "single_label_classification" 379 | else: 380 | self.config.problem_type = "multi_label_classification" 381 | 382 | if self.config.problem_type == "regression": 383 | loss_fct = MSELoss() 384 | if self.num_labels == 1: 385 | classifier_loss = loss_fct(classifier_logits.squeeze(), labels.squeeze()) 386 | else: 387 | classifier_loss = loss_fct(classifier_logits, labels) 388 | elif self.config.problem_type == "single_label_classification": 389 | if self.class_weights != None: 390 | loss_fct = CrossEntropyLoss(weight=torch.tensor(self.class_weights).to(classifier_logits.device)) 391 | else: 392 | loss_fct = CrossEntropyLoss() 393 | classifier_loss = loss_fct(classifier_logits.view(-1, self.num_labels), labels.view(-1)) 394 | elif self.config.problem_type == "multi_label_classification": 395 | loss_fct = BCEWithLogitsLoss() 396 | classifier_loss = loss_fct(classifier_logits, labels) 397 | loss += classifier_loss 398 | 399 | if not return_dict: 400 | output = (classifier_logits,) + outputs[2:] 401 | return ((loss,) + output) if loss is not None else output 402 | #output = (classifier_logits, line_logits) + outputs[2:] 403 | #return ((loss, classifier_loss, line_loss) + output) if loss is not None else output 404 | 405 | return FaultAwareModelOutput( 406 | loss=loss, 407 | line_loss=line_loss, 408 | #classifier_loss=classifier_loss, 409 | #classifier_logits=classifier_logits, 410 | line_logits=line_logits, 411 | classifier_logits = classifier_logits, 412 | hidden_states=outputs.hidden_states, 413 | attentions=outputs.attentions, 414 | ) 415 | 416 | 417 | 418 | def main(): 419 | # See all possible arguments in src/transformers/training_args.py 420 | # or by passing the --help flag to this script. 421 | # We now keep distinct sets of args, for a cleaner separation of concerns. 
422 | 423 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 424 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 425 | 426 | # Setup logging 427 | logging.basicConfig( 428 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 429 | datefmt="%m/%d/%Y %H:%M:%S", 430 | handlers=[logging.StreamHandler(sys.stdout)], 431 | ) 432 | 433 | log_level = training_args.get_process_log_level() 434 | logger.setLevel(log_level) 435 | #datasets.utils.logging.set_verbosity(log_level) 436 | #transformers.utils.logging.set_verbosity(log_level) 437 | transformers.utils.logging.enable_default_handler() 438 | transformers.utils.logging.enable_explicit_format() 439 | 440 | # Log on each process the small summary: 441 | logger.warning( 442 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 443 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 444 | ) 445 | logger.info(f"Training/evaluation parameters {training_args}") 446 | 447 | # Set seed before initializing model. 448 | set_seed(training_args.seed) 449 | 450 | # Loading a dataset from your local files. 451 | # CSV/JSON training and evaluation files are needed. 452 | data_files = {} 453 | if training_args.do_train and data_args.train_file != None: 454 | data_files["train"] = data_args.train_file 455 | if training_args.do_eval and data_args.validation_file != None: 456 | data_files["validation"] = data_args.validation_file 457 | if training_args.do_predict and data_args.test_file != None: 458 | data_files["test"] = data_args.test_file 459 | 460 | 461 | for key in data_files.keys(): 462 | logger.info(f"load a local file for {key}: {data_files[key]}") 463 | 464 | raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir) 465 | 466 | print(raw_datasets) 467 | # Labels 468 | if data_args.labels_file != None: 469 | # read labels from file 470 | with open(data_args.labels_file, 'r') as f: 471 | label_list = [line.strip() for line in f] 472 | num_labels = len(label_list) 473 | else: 474 | num_labels = 2 475 | label_list = [False, True] 476 | 477 | # grouped indices/labels for measuring actual accuracy of ranking models 478 | if data_args.grouped_indices_file != None: 479 | grouped_indices = np.load(data_args.grouped_indices_file) 480 | grouped_labels = np.load(data_args.grouped_labels_file) 481 | else: 482 | grouped_indices = None 483 | grouped_labels = None 484 | # get the index of "Correct" in the labels list 485 | pass_idx = label_list.index("Correct") 486 | 487 | print("label_list:", label_list) 488 | print("pass_idx:", pass_idx) 489 | 490 | 491 | if data_args.weights_file != None: 492 | # read weights from file 493 | with open(data_args.weights_file, 'r') as f: 494 | class_weights = [float(line.strip()) for line in f] 495 | else: 496 | class_weights = None 497 | 498 | # Load pretrained model and tokenizer 499 | tokenizer = RobertaTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer,) 500 | model = RobertaForFaultAwareRanker.from_pretrained(model_args.model_name_or_path, num_labels=num_labels, cache_dir=model_args.cache_dir) 501 | model.set_class_weights(class_weights) 502 | # Preprocessing the raw datasets 503 | sentence1_key = data_args.sentence1_key #"model_prompt" 504 | sentence2_key = data_args.sentence2_key #"model_completion" 505 | label_key = data_args.label_key 
#"passed" 506 | 507 | # Padding strategy 508 | if data_args.pad_to_max_length: 509 | padding = "max_length" 510 | else: 511 | padding = False 512 | 513 | label_to_id = {v: i for i, v in enumerate(label_list)} 514 | model.config.label2id = label_to_id 515 | 516 | if data_args.max_seq_length > tokenizer.model_max_length: 517 | logger.warning( 518 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 519 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 520 | ) 521 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 522 | mask_padding_with_zero = True 523 | pad_token = 0 524 | pad_token_segment_id = 0 525 | pad_on_left = False 526 | def preprocess_function(examples): 527 | # Tokenize the texts 528 | text1 = [reindent_code(s) for s in examples[sentence1_key]] 529 | text2 = [reindent_code(s) for s in examples[sentence2_key]] 530 | 531 | def trunc(tokens_a, tokens_b, max_length, truncate_texta_from_first=False): 532 | """Truncates a sequence pair in place to the maximum length.""" 533 | # This is a simple heuristic which will always truncate the longer sequence 534 | # one token at a time. This makes more sense than truncating an equal percent 535 | # of tokens from each, since if one sequence is very short then each token 536 | # that's truncated likely contains more information than a longer sequence. 537 | while True: 538 | total_length = len(tokens_a) + len(tokens_b) 539 | if total_length <= max_length: 540 | break 541 | if len(tokens_a) > len(tokens_b): 542 | if truncate_texta_from_first: 543 | tokens_a.pop(0) 544 | else: 545 | tokens_a.pop() 546 | else: 547 | tokens_b.pop() 548 | def custom_tokenize(text1, text2, error_line): 549 | all_input_ids = [] 550 | all_attention_mask = [] 551 | all_token_type_ids = [] 552 | all_line_masks = [] 553 | all_line_positions = [] 554 | for i in range(len(text1)): 555 | tok_seq1 = tokenizer.tokenize(text1[i]) 556 | tok_seq2 = tokenizer.tokenize(text2[i]) 557 | 558 | trunc(tok_seq1, tok_seq2, max_seq_length - 3, truncate_texta_from_first=data_args.truncate_texta_from_first) # 3 is number of special tokens for bert sequence pair 559 | new_line_token = tokenizer.tokenize('\n')[0] 560 | # find all indices of new_line_token in tok_seq2 561 | new_line_indices = [i for i, x in enumerate(tok_seq2) if x == new_line_token] 562 | input_ids = [tokenizer.cls_token_id] 563 | input_ids += tokenizer.convert_tokens_to_ids(tok_seq1) 564 | input_ids += [tokenizer.sep_token_id] 565 | line_masks = [0]*len(input_ids) 566 | # unmask the first token -> to handle cases where error line does not exist 567 | line_masks[0] = 1 568 | if error_line[i] == -1: 569 | line_position = 0 570 | token_type_ids = [0]*len(input_ids) 571 | 572 | input_ids += tokenizer.convert_tokens_to_ids(tok_seq2) 573 | line_masks_from_text2 = [0]*len(tok_seq2) 574 | for j in range(len(new_line_indices)): 575 | line_masks_from_text2[new_line_indices[j]] = 1 576 | if j == error_line[i]: 577 | line_position = len(line_masks) + new_line_indices[j] 578 | 579 | line_masks += line_masks_from_text2 580 | 581 | # unmask the next token -> to handle cases where error line is truncated 582 | line_masks += [1] 583 | if error_line[i] >= len(new_line_indices): 584 | line_position = len(line_masks) - 1 585 | 586 | input_ids += [tokenizer.sep_token_id] 587 | token_type_ids += [1]*(len(tok_seq2)+1) 588 | 589 | assert(len(line_masks) == len(input_ids) == len(token_type_ids)) 590 | assert(line_position < 
len(line_masks)) 591 | 592 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 593 | # tokens are attended to. 594 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 595 | # Zero-pad up to the sequence length. 596 | padding_length = max_seq_length - len(input_ids) 597 | if pad_on_left: 598 | input_ids = ([tokenizer.pad_token] * padding_length) + input_ids 599 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask 600 | token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids 601 | line_masks = ([0] * padding_length) + line_masks 602 | line_position += padding_length 603 | else: 604 | input_ids = input_ids + ([pad_token] * padding_length) 605 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 606 | token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) 607 | line_masks = line_masks + ([0] * padding_length) 608 | 609 | assert(len(line_masks) == len(input_ids) == len(token_type_ids)) 610 | assert(line_position < len(line_masks)) 611 | assert(line_masks[line_position] == 1) 612 | 613 | all_input_ids.append(input_ids) 614 | all_attention_mask.append(attention_mask) 615 | all_token_type_ids.append(token_type_ids) 616 | all_line_masks.append(line_masks) 617 | all_line_positions.append(line_position) 618 | 619 | result = {"input_ids": all_input_ids, "attention_mask": all_attention_mask, 'line_masks': all_line_masks, 'line_positions': all_line_positions} 620 | return result 621 | 622 | result = custom_tokenize(text1, text2, examples['error_line_number']) 623 | input_ids = result["input_ids"] 624 | #text = tokenizer.decode(input_ids[0]) 625 | 626 | # Map labels to IDs (not necessary for GLUE wc -s) 627 | if label_to_id is not None and label_key in examples: 628 | result["label"] = [(label_to_id[l]) for l in examples[label_key]] 629 | return result 630 | 631 | 632 | if training_args.do_train: 633 | if "train" not in raw_datasets: 634 | raise ValueError("--do_train requires a train dataset") 635 | train_dataset = raw_datasets["train"] 636 | if data_args.max_train_samples is not None: 637 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 638 | train_dataset = train_dataset.map( 639 | preprocess_function, 640 | batched=True, 641 | load_from_cache_file=not data_args.overwrite_cache, 642 | desc="Running tokenizer on train dataset", 643 | num_proc = 20, 644 | ) 645 | if training_args.do_eval: 646 | if "validation" not in raw_datasets : 647 | raise ValueError("--do_eval requires a validation dataset") 648 | eval_dataset = raw_datasets["validation"] 649 | if data_args.max_eval_samples is not None: 650 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 651 | eval_dataset = eval_dataset.map( 652 | preprocess_function, 653 | batched=True, 654 | load_from_cache_file=not data_args.overwrite_cache, 655 | desc="Running tokenizer on eval dataset", 656 | num_proc = 20, 657 | ) 658 | 659 | 660 | if training_args.do_predict or data_args.test_file is not None: 661 | if "test" not in raw_datasets: 662 | raise ValueError("--do_predict requires a test dataset") 663 | predict_dataset = raw_datasets["test"] 664 | if data_args.max_predict_samples is not None: 665 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 666 | predict_dataset = predict_dataset.map( 667 | preprocess_function, 668 | batched=True, 669 | load_from_cache_file=not data_args.overwrite_cache, 670 | desc="Running tokenizer on 
dataset", 671 | num_proc = 20, 672 | ) 673 | 674 | # Log a few random samples from the training set: 675 | if training_args.do_train: 676 | for index in random.sample(range(len(train_dataset)), 3): 677 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 678 | input_ids = train_dataset[index]["input_ids"] 679 | text = tokenizer.decode(input_ids) 680 | print(text) 681 | 682 | # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. 683 | if data_args.pad_to_max_length: 684 | data_collator = default_data_collator 685 | elif training_args.fp16: 686 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 687 | else: 688 | data_collator = None 689 | 690 | # Initialize our Trainer 691 | trainer = Trainer( 692 | model=model, 693 | args=training_args, 694 | train_dataset=train_dataset if training_args.do_train else None, 695 | eval_dataset=eval_dataset if training_args.do_eval else None, 696 | compute_metrics=lambda x: compute_metrics(x, num_labels, True, grouped_indices, grouped_labels, pass_idx), 697 | tokenizer=tokenizer, 698 | data_collator=data_collator, 699 | ) 700 | trainer.label_names = ['labels', 'line_positions'] 701 | 702 | # Training 703 | if training_args.do_train: 704 | # Detecting last checkpoint. 705 | last_checkpoint = None 706 | if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir: 707 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 708 | if training_args.do_train and last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 709 | raise ValueError( 710 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 711 | "Use --overwrite_output_dir to overcome." 712 | ) 713 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 714 | logger.info( 715 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 716 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 717 | ) 718 | 719 | train_result = trainer.train(resume_from_checkpoint = last_checkpoint) 720 | metrics = train_result.metrics 721 | max_train_samples = ( 722 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 723 | ) 724 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 725 | 726 | trainer.save_model() # Saves the tokenizer too for easy upload 727 | 728 | trainer.log_metrics("train", metrics) 729 | trainer.save_metrics("train", metrics) 730 | trainer.save_state() 731 | 732 | # load model for eval and test 733 | # if output dir has a model, then load it 734 | if os.path.exists(os.path.join(training_args.output_dir, "pytorch_model.bin")): 735 | logger.info(f"Loading model from {os.path.join(training_args.output_dir, 'pytorch_model.bin')}") 736 | model = RobertaForFaultAwareRanker.from_pretrained(training_args.output_dir, num_labels=num_labels) 737 | else: 738 | # if last checkpoint exists and the output dir does not have a model, 739 | # then we can load the best model using the trainer state in last checkpoint 740 | # Detecting last checkpoint. 
741 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 742 | if last_checkpoint is not None: 743 | with open(os.path.join(last_checkpoint, "trainer_state.json"), "r") as f: 744 | trainer_state = json.load(f) 745 | if "best_model_checkpoint" in trainer_state: 746 | best_checkpoint = trainer_state['best_model_checkpoint'] 747 | # match prefix before /checkpoint of the checkpoint name with the output_dir (because the model location might have changed after traning) 748 | prefix = best_checkpoint.split("/checkpoint")[0] 749 | substitute_prefix = training_args.output_dir.split("/checkpoint")[0] 750 | best_checkpoint = best_checkpoint.replace(prefix, substitute_prefix) 751 | logger.info(f"Loading model from {best_checkpoint}") 752 | model = RobertaForFaultAwareRanker.from_pretrained(best_checkpoint, num_labels=num_labels) 753 | else: 754 | logger.info("No model found. Using CodeBERT model") 755 | 756 | trainer = Trainer( 757 | model=model, 758 | args=training_args, 759 | train_dataset=train_dataset if training_args.do_train else None, 760 | eval_dataset=eval_dataset if training_args.do_eval else None, 761 | tokenizer=tokenizer, 762 | data_collator=data_collator, 763 | ) 764 | 765 | 766 | # Evaluation 767 | if training_args.do_eval: 768 | logger.info("*** Evaluate ***") 769 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 770 | max_eval_samples = ( 771 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 772 | ) 773 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 774 | 775 | trainer.log_metrics("eval", metrics) 776 | trainer.save_metrics("eval", metrics) 777 | 778 | if training_args.do_predict: 779 | logger.info("*** Predict ***") 780 | 781 | pred_output = trainer.predict(predict_dataset, metric_key_prefix="predict") 782 | predictions = pred_output.predictions 783 | metrics = pred_output.metrics 784 | trainer.log_metrics("predict", metrics) 785 | trainer.save_metrics(data_args.predict_suffix, metrics) 786 | 787 | # It appears that predictions[1] is the one that contains the classification logits 788 | output_predict_file = os.path.join(training_args.output_dir, data_args.predict_suffix) 789 | if trainer.is_world_process_zero(): 790 | with open(output_predict_file, "w") as writer: 791 | logger.info(f"***** Predict results *****") 792 | writer.write("index\tprediction\n") 793 | for index, item in enumerate(predictions[1]): 794 | item_str = "[" + ",".join([str(r) for r in item]) + "]" 795 | writer.write(f"{index}\t{item_str}\n") 796 | 797 | 798 | if __name__ == "__main__": 799 | main() --------------------------------------------------------------------------------
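Note on the grouped ranker metric: in both training scripts, compute_metrics turns per-sample classifier logits into a ranking score. grouped_indices is a two-dimensional array in which each row lists the indices of the candidate completions that share one prompt (with -1 as padding for ragged groups), and grouped_labels holds their ternary labels at the same positions. For every prompt, the completion with the highest softmax probability of the "Correct" class (column pass_idx) is selected; top1_accuracy is the fraction of prompts whose selected completion is labeled "Correct", and top1_execution is the fraction whose selection is not an "Execution error". The following standalone sketch is not part of the repository; it uses made-up logits, a hypothetical "Wrong answer" label, and tiny groups purely to illustrate the computation:

import numpy as np
import scipy.special

# Made-up classifier logits for 4 candidate completions over 3 classes.
classifier_pred_raw = np.array([
    [0.2, 1.5, -0.3],
    [1.0, -0.5, 0.1],
    [-0.2, 0.4, 0.9],
    [0.3, 2.0, -1.0],
])
pass_idx = 1  # column of the "Correct" class in the label list

# Each row groups the sample indices that share one prompt; -1 would pad ragged groups.
grouped_indices = np.array([[0, 1], [2, 3]])
grouped_labels = np.array([["Correct", "Wrong answer"],        # "Wrong answer" is a hypothetical label
                           ["Execution error", "Correct"]])

# Probability of the "Correct" class per sample; padded (-1) slots get probability 0.
pred_prob = scipy.special.softmax(classifier_pred_raw, axis=1)[:, pass_idx]
prob_fun = np.vectorize(lambda x: pred_prob[x] if x >= 0 else 0)
grouped_prob = prob_fun(grouped_indices)

# Pick the highest-ranked completion per prompt and score it against its label.
best = np.argmax(grouped_prob, axis=1)
top1_label = grouped_labels[np.arange(len(grouped_labels)), best]
top1_accuracy = np.mean(top1_label == "Correct")
top1_execution = 1.0 - np.mean(top1_label == "Execution error")
print(top1_accuracy, top1_execution)  # 1.0 1.0 for this toy data

Usage note: the repository also ships src/finetune.sh and src/eval.sh, which presumably wrap the invocations of these scripts. Both Python entry points are driven by HfArgumentParser, so every field of DataTrainingArguments and ModelArguments, plus the standard HuggingFace TrainingArguments, is exposed as a command-line flag (for example --train_file, --validation_file, --test_file, --labels_file, --grouped_indices_file, --grouped_labels_file, --max_seq_length, --truncate_texta_from_first and --predict_suffix, alongside --model_name_or_path, --output_dir, --do_train, --do_eval and --do_predict). When a labels file is supplied it should include a "Correct" entry, since run_seq_classification_and_line_prediction.py resolves pass_idx via label_list.index("Correct") (the fallback label list [False, True] would not satisfy this), and that script additionally expects an error_line_number field in every JSON example, with -1 meaning that no error line exists.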