├── .env ├── .gitattribute ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE.txt ├── Pipfile ├── Pipfile.lock ├── README.md ├── data ├── 9.11.accs_by_project.csv ├── 9.12.accs_by_project.csv ├── 9.14.accs_by_project.csv ├── TypeT5-Workflow.png ├── code │ ├── bad_code_1.py │ ├── bad_code_2.py │ ├── code_with_slash.py │ ├── dummy │ │ ├── __init__.py │ │ ├── dummy_1.py │ │ └── dummy_2.py │ ├── env_code_1.py │ ├── env_code_2.py │ └── good_code_1.py ├── ex_repo │ ├── ex_code_1.py │ └── ex_code_2.py ├── mypy-dependents-by-stars.json ├── repos_split.pkl └── useful_repos.pkl ├── requirements.txt ├── scripts ├── analyze_dataset.ipynb ├── analyze_decoding_results.ipynb ├── archive │ ├── analyze_dagger.ipynb │ ├── code-t5-workflow.ipynb │ ├── debug_mypy.ipynb │ ├── fine_tune_t5.ipynb │ ├── inference_spot.ipynb │ ├── kill_dmypy.sh │ ├── test_inf_env.ipynb │ ├── train_dagger.py │ └── train_spot.ipynb ├── collect_dataset.ipynb ├── experiments │ ├── eval_file_model.ipynb │ ├── eval_func_model.py │ ├── run_hityper.ipynb │ ├── run_type4py.ipynb │ ├── run_typilus.ipynb │ └── type_check_decoding.ipynb ├── prepare_dataset.ipynb ├── run_func_decoding.ipynb ├── run_typet5.ipynb ├── scratch.ipynb └── train_model.py ├── setup.py ├── src └── typet5 │ ├── __init__.py │ ├── data.py │ ├── decode.py │ ├── experiments │ ├── __init__.py │ ├── hityper.py │ ├── type4py.py │ ├── typet5.py │ ├── typilus.py │ └── utils.py │ ├── function_dataset.py │ ├── function_decoding.py │ ├── model.py │ ├── static_analysis.py │ ├── tokenized_src.py │ ├── train.py │ ├── type_check.py │ ├── type_env.py │ ├── utils.py │ └── visualization.py └── tests ├── __init__.py ├── test_func_decoding.py ├── test_model_creation.py ├── test_static_analysis.py └── test_type_env.py /.env: -------------------------------------------------------------------------------- 1 | LIBCST_PARSER_TYPE=native 2 | -------------------------------------------------------------------------------- /.gitattribute: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-vendored 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | config 2 | *.egg-info 3 | code_output 4 | temp 5 | wandb 6 | checkpoints 7 | caches 8 | lightning_logs/ 9 | output/ 10 | 11 | build/ 12 | __pycache__ 13 | *.py[cod] 14 | *~ 15 | /build 16 | /env*/ 17 | docs/build/ 18 | docs/source/_build 19 | mypyc/doc/_build 20 | *.iml 21 | /out/ 22 | .venv 23 | venv/ 24 | mypy_temp 25 | .mypy_cache/ 26 | .incremental_checker_cache.json 27 | .cache 28 | dmypy.json 29 | .dmypy.json 30 | .coeditor_logs 31 | 32 | # Packages 33 | *.egg 34 | *.egg-info 35 | *.eggs 36 | 37 | # IDEs 38 | .idea 39 | .vscode 40 | 41 | # vim temporary files 42 | .*.sw? 43 | *.sw? 
44 | 45 | # Operating Systems 46 | .DS_Store 47 | 48 | # Coverage Files 49 | htmlcov 50 | .coverage* 51 | 52 | # pytest cache 53 | .pytest_cache/ 54 | 55 | # virtualenv 56 | .Python 57 | bin/ 58 | lib/ 59 | include/ 60 | .python-version 61 | pyvenv.cfg 62 | 63 | .tox 64 | pip-wheel-metadata 65 | 66 | 67 | test_capi 68 | *.o 69 | *.a 70 | test_capi 71 | /.mypyc-flake8-cache.json 72 | /mypyc/lib-rt/build/ 73 | /mypyc/lib-rt/*.so 74 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: 'tests/testcases' 2 | 3 | default_language_version: 4 | python: python3.11 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v3.2.0 9 | hooks: 10 | - id: trailing-whitespace 11 | - id: end-of-file-fixer 12 | - id: check-yaml 13 | - id: check-added-large-files 14 | 15 | - repo: https://github.com/pycqa/isort 16 | rev: 5.11.4 17 | hooks: 18 | - id: isort 19 | args: ["--profile", "black", "--filter-files"] 20 | 21 | - repo: https://github.com/psf/black 22 | rev: 22.12.0 23 | hooks: 24 | - id: black 25 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Jiayi Wei 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [packages] 7 | mypy = "~=0.971" 8 | typet5 = {editable = true, path = "."} 9 | tqdm = "*" 10 | types-all = "*" 11 | dateparser = "~=1.1.1" 12 | pyrsistent = "*" 13 | plotly = "~=5.10" 14 | pandas = "~=1.4" 15 | ipywidgets = "~=8.0" 16 | datasets = "~=2.4.0" 17 | transformers = "~=4.21.3" 18 | wandb = "~=0.13.2" 19 | libcst = "<=0.4.2" 20 | pytorch-lightning = "~=1.7.5" 21 | colored = "~=1.4" 22 | ipykernel = "*" 23 | termcolor = "~=1.0" 24 | prettytable = "~=3.4.1" 25 | huggingface-hub = "*" 26 | 27 | [dev-packages] 28 | pytest = "*" 29 | black = "*" 30 | exceptiongroup = "*" 31 | 32 | [requires] 33 | python_version = "3.10" 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TypeT5: Seq2seq Type Inference using Static Analysis 2 | 3 | ![TypeT5 Workflow](data/TypeT5-Workflow.png) 4 | 5 | This repo contains the source code for the paper [TypeT5: Seq2seq Type Inference using Static Analysis](https://openreview.net/forum?id=4TyNEhI2GdN&noteId=EX_-kP9xah). 6 | 7 | ``` 8 | @inproceedings{Wei2023TypeT5, 9 | title={TypeT5: Seq2seq Type Inference using Static Analysis}, 10 | author={Jiayi Wei and Greg Durrett and Isil Dillig}, 11 | booktitle={International Conference on Learning Representations}, 12 | year={2023}, 13 | url={https://openreview.net/forum?id=4TyNEhI2GdN} 14 | } 15 | ``` 16 | 17 | ## Installation 18 | 19 | This project uses [pipenv](https://pipenv.pypa.io/en/latest/) to manage the package dependencies. Pipenv tracks the exact package versions and manages the (project-specific) virtual environment for you. To install all dependencies, make sure you have pipenv and Python 3.10 installed, then, at the project root, run the following two commands: 20 | ```bash 21 | pipenv --python 3.10 # create a new environment for this project 22 | pipenv sync --dev # install all specified dependencies 23 | ``` 24 | 25 | More about pipenv: 26 | - To add new dependencies into the virtual environment, you can either add them via `pipenv install <package>` (using `pipenv`) or `pipenv run pip install <package>` (using `pip` from within the virtual environment). 27 | - If your PyTorch installation is not working properly, you might need to reinstall it via the `pipenv run pip install` approach rather than `pipenv install`. 28 | - All `.py` scripts below can be run via `pipenv run python <script-path>`. For `.ipynb` notebooks, make sure you select the pipenv environment as the kernel. You can run all unit tests by running `pipenv run pytest` at the project root. 29 | 30 | If you are not using pipenv: 31 | - Make sure to add the environment variables in the [.env](.env) file to your shell environment when you run the scripts (needed by the parsing library). 32 | - We also provide a [requirements.txt](requirements.txt) file for you to install the dependencies via `pip install -r requirements.txt`. 33 | 34 | 35 | ## Using the trained model 36 | The notebook [scripts/run_typet5.ipynb](scripts/run_typet5.ipynb) shows you how to download the TypeT5 model from Huggingface and then use it to make type predictions for a specified codebase.
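For quick reference, the sketch below illustrates only the model-loading step. It is a minimal, non-authoritative sketch: it assumes the `ModelWrapper.from_pretrained` API that appears in the archived notebooks (e.g., `scripts/archive/analyze_dagger.ipynb`), and the Hugging Face repo id is a placeholder; the notebook above shows the exact model id and the full prediction workflow.

```python
# Minimal sketch, not the notebook's exact code.
# Assumptions: ModelWrapper.from_pretrained (as used in scripts/archive/analyze_dagger.ipynb);
# the Hugging Face repo id below is a placeholder.
import torch
from huggingface_hub import snapshot_download
from typet5.model import ModelWrapper

model_dir = snapshot_download(repo_id="<user>/<typet5-model>")  # placeholder repo id
wrapper = ModelWrapper.from_pretrained(model_dir)  # load the trained TypeT5 weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wrapper.to(device)  # move the model to GPU if one is available
# From here, follow scripts/run_typet5.ipynb to run type prediction on a codebase.
```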
37 | 38 | ## Training a New Model 39 | 40 | - First, run the notebook [scripts/collect_dataset.ipynb](scripts/collect_dataset.ipynb) to download and split the BetterTypes4Py dataset used in our paper. 41 | - The exact list of repos we used for the experiments in the paper can be loaded from `data/repos_split.pkl` using `pickle.load`. They can also be downloaded via this [Google Drive link](https://drive.google.com/drive/folders/1lXKtwi7AOI-w4ESgMi7J5YAHRGP-JhG5?usp=sharing). 42 | - Then, run [scripts/train_model.py](scripts/train_model.py) to train a new TypeT5 model. Training takes about 11 hours on a single Quadro RTX 8000 GPU with 48GB memory. 43 | 44 | 45 | ## Development 46 | - Formatter: We use `black` for formatting with the default options. 47 | - Type Checker: We use Pylance to type check this codebase. It's the built-in type checker shipped with the VSCode Python extension and can be enabled by setting `Python > Analysis > Type Checking Mode` to `basic`. 48 | -------------------------------------------------------------------------------- /data/9.11.accs_by_project.csv: -------------------------------------------------------------------------------- 1 | ,project,no-neighbors,non-incr,double-traversal,label_rate,labels 2 | 0,srittau__FakeSMTPd,0.8362068965517241,0.8793103448275862,0.896551724137931,0.3483483483483483,116 3 | 1,scalableminds__webknossos-connect,0.6757188498402555,0.7507987220447284,0.7763578274760383,0.830238726790451,626 4 | 2,flopp__unicode-explorer,0.8205128205128205,0.8547008547008547,0.9230769230769231,0.5763546798029556,117 5 | 3,eirannejad__calcatime,0.46875,0.53125,0.5,0.8,32 6 | 4,jelford__webwatcher,0.7659574468085106,0.7659574468085106,0.7659574468085106,0.2186046511627907,47 7 | 5,road-master__video-archiver,0.8536585365853658,0.9105691056910569,0.9186991869918699,0.41694915254237286,123 8 | 6,flopp__GpxTrackPoster,0.5498154981549815,0.6568265682656826,0.7158671586715867,0.6545893719806763,271 9 | 7,boompig__book-classics,0.9058823529411765,0.9411764705882353,0.9411764705882353,0.5862068965517241,85 10 | 8,dropbox__sqlalchemy-stubs,0.5952380952380952,0.6666666666666666,0.6428571428571429,0.6086956521739131,42 11 | 9,reddit__baseplate.py-upgrader,0.7548387096774194,0.7612903225806451,0.8129032258064516,0.4305555555555556,155 12 | 10,typeddjango__pytest-mypy-plugins,0.7565217391304347,0.782608695652174,0.782608695652174,0.732484076433121,115 13 | 11,linw1995__data_extractor,0.6,0.676923076923077,0.6538461538461539,0.23090586145648312,130 14 | 12,albertyw__albertyw.com,0.5285714285714286,0.6285714285714286,0.6428571428571429,0.37433155080213903,70 15 | 13,marcosschroh__dataclasses-avroschema,0.5056947608200456,0.5034168564920274,0.4760820045558087,0.42787524366471735,439 16 | 14,antonagestam__collectfast,0.6896551724137931,0.7155172413793104,0.7586206896551724,0.5155555555555555,116 17 | 15,ammarnajjar__pautomate,0.8518518518518519,0.8703703703703703,0.9259259259259259,0.5046728971962616,54 18 | 16,brettkromkamp__topic-db,0.8476190476190476,0.8761904761904762,0.9238095238095239,0.3633217993079585,105 19 | 17,basilisp-lang__basilisp,0.5882352941176471,0.6233979625369701,0.6194544857048965,0.5269264069264069,3043 20 | 18,paulcwatts__drf-json-schema,0.6590909090909091,0.6439393939393939,0.6931818181818182,0.528,264 21 | 19,ShadowTemplate__beautiful-python-3,0.8991596638655462,0.957983193277311,0.9411764705882353,0.3287292817679558,119 22 | 20,AxelVoitier__lookups,0.3424657534246575,0.4155251141552511,0.3789954337899543,0.33435114503816793,219 23 | 
21,Gerschtli__teamspeak-update-notifier,0.6,0.6666666666666666,0.6857142857142857,0.5497382198952879,105 24 | 22,joshtemple__lkml,0.7317073170731707,0.7658536585365854,0.7707317073170732,0.39575289575289574,205 25 | 23,kitsuyui__bamboo-crawler,0.6331360946745562,0.6449704142011834,0.6331360946745562,0.6450381679389313,169 26 | 24,amplify-education__python-hcl2,0.6301369863013698,0.6438356164383562,0.821917808219178,0.73,73 27 | 25,nubark__instark,0.8165374677002584,0.8578811369509044,0.8578811369509044,0.3467741935483871,387 28 | 26,albertyw__git-browse,0.5944444444444444,0.7166666666666667,0.7277777777777777,0.46272493573264784,180 29 | 27,kornicameister__axion,0.4909596662030598,0.5090403337969402,0.45479833101529904,0.6301489921121823,719 30 | 28,payscale__fables,0.7338709677419355,0.7903225806451613,0.75,0.3668639053254438,124 31 | 29,nabla-c0d3__sslyze,0.7120085015940489,0.7523910733262487,0.7630180658873539,0.5689238210399032,941 32 | 30,rakitaj__daily-programmer,0.737037037037037,0.7425925925925926,0.7666666666666667,0.5009276437847866,540 33 | 31,silasary__discord_connotations,0.6153846153846154,0.6923076923076923,0.6923076923076923,0.2708333333333333,13 34 | 32,JakobGM__quelf,0.5434782608695652,0.5652173913043478,0.5543478260869565,0.5287356321839081,92 35 | 33,ocf__slackbridge,0.6935483870967742,0.7096774193548387,0.7741935483870968,0.5767441860465117,124 36 | 34,sonic182__aiosonic,0.8393782383419689,0.8393782383419689,0.7927461139896373,0.291981845688351,193 37 | 35,lucaswerkmeister__tool-quickcategories,0.7580645161290323,0.7870967741935484,0.7645161290322581,0.6581740976645435,310 38 | 36,ohjames__babies,0.8507462686567164,0.8805970149253731,0.8805970149253731,0.32524271844660196,67 39 | 37,ClearcodeHQ__mirakuru,0.6907216494845361,0.7422680412371134,0.7731958762886598,0.383399209486166,97 40 | 38,futursolo__magichttp,0.509009009009009,0.6261261261261262,0.6981981981981982,0.35406698564593303,222 41 | 39,ActivityWatch__aw-research,0.8205128205128205,0.8376068376068376,0.8205128205128205,0.3667711598746082,117 42 | 40,lebrice__blurred-GAN,0.7012987012987013,0.6233766233766234,0.6883116883116883,0.18421052631578946,77 43 | 41,webrecorder__browsertrix,0.61328125,0.6875,0.6875,0.30732292917166865,256 44 | 42,seattleflu__id3c,0.7381889763779528,0.8326771653543307,0.8523622047244095,0.4425087108013937,508 45 | 43,TomerFi__aioswitcher,0.7263157894736842,0.8105263157894737,0.8947368421052632,0.336283185840708,190 46 | 44,jfly__jfly.github.io,0.696969696969697,0.8181818181818182,0.9696969696969697,0.23741007194244604,33 47 | 45,jreese__aql,0.5487179487179488,0.5743589743589743,0.5641025641025641,0.437219730941704,195 48 | 46,yeraydiazdiaz__wait_for_it.py,1.0,1.0,1.0,0.375,6 49 | 47,cliffxuan__mew,0.7638888888888888,0.8333333333333334,0.7777777777777778,0.36923076923076925,72 50 | 48,everyclass__everyclass-server,0.7237237237237237,0.7897897897897898,0.7477477477477478,0.3729003359462486,333 51 | 49,knowark__estimark,0.8630705394190872,0.8423236514522822,0.8464730290456431,0.33565459610027853,241 52 | -------------------------------------------------------------------------------- /data/9.12.accs_by_project.csv: -------------------------------------------------------------------------------- 1 | ,project,no-neighbors,non-incr,double-traversal,label_size,label_rate,labels 2 | 17,basilisp-lang__basilisp,0.5882352941176471,0.6233979625369701,0.6194544857048965,1.2431810713112061,0.5269264069264069,3043 3 | 
29,nabla-c0d3__sslyze,0.7120085015940489,0.7523910733262487,0.7630180658873539,1.2699256110520722,0.5689238210399032,941 4 | 27,kornicameister__axion,0.4909596662030598,0.5090403337969402,0.45479833101529904,1.6244784422809457,0.6301489921121823,719 5 | 1,scalableminds__webknossos-connect,0.6757188498402555,0.7507987220447284,0.7763578274760383,1.20926517571885,0.830238726790451,626 6 | 30,rakitaj__daily-programmer,0.737037037037037,0.7425925925925926,0.7666666666666667,1.5277777777777777,0.5009276437847866,540 7 | 42,seattleflu__id3c,0.7381889763779528,0.8326771653543307,0.8523622047244095,1.2066929133858268,0.4425087108013937,508 8 | 13,marcosschroh__dataclasses-avroschema,0.5056947608200456,0.5034168564920274,0.4760820045558087,1.4738041002277904,0.42787524366471735,439 9 | 25,nubark__instark,0.8165374677002584,0.8578811369509044,0.8578811369509044,1.1808785529715762,0.3467741935483871,387 10 | 48,everyclass__everyclass-server,0.7237237237237237,0.7897897897897898,0.7477477477477478,1.2012012012012012,0.3729003359462486,333 11 | 35,lucaswerkmeister__tool-quickcategories,0.7580645161290323,0.7870967741935484,0.7645161290322581,1.4,0.6581740976645435,310 12 | 6,flopp__GpxTrackPoster,0.5498154981549815,0.6568265682656826,0.7158671586715867,1.2472324723247232,0.6545893719806763,271 13 | 18,paulcwatts__drf-json-schema,0.6590909090909091,0.6439393939393939,0.6931818181818182,1.4090909090909092,0.528,264 14 | 41,webrecorder__browsertrix,0.61328125,0.6875,0.6875,1.17578125,0.30732292917166865,256 15 | 49,knowark__estimark,0.8630705394190872,0.8423236514522822,0.8464730290456431,1.3278008298755186,0.33565459610027853,241 16 | 38,futursolo__magichttp,0.509009009009009,0.6261261261261262,0.6981981981981982,1.1486486486486487,0.35406698564593303,222 17 | 20,AxelVoitier__lookups,0.3424657534246575,0.4155251141552511,0.3789954337899543,1.5936073059360731,0.33435114503816793,219 18 | 22,joshtemple__lkml,0.7317073170731707,0.7658536585365854,0.7707317073170732,1.4341463414634146,0.39575289575289574,205 19 | 45,jreese__aql,0.5487179487179488,0.5743589743589743,0.5641025641025641,1.4564102564102563,0.437219730941704,195 20 | 34,sonic182__aiosonic,0.8393782383419689,0.8393782383419689,0.7927461139896373,1.1347150259067358,0.291981845688351,193 21 | 43,TomerFi__aioswitcher,0.7263157894736842,0.8105263157894737,0.8947368421052632,1.1421052631578947,0.336283185840708,190 22 | 26,albertyw__git-browse,0.5944444444444444,0.7166666666666667,0.7277777777777777,1.0777777777777777,0.46272493573264784,180 23 | 23,kitsuyui__bamboo-crawler,0.6331360946745562,0.6449704142011834,0.6331360946745562,1.7514792899408285,0.6450381679389313,169 24 | 9,reddit__baseplate.py-upgrader,0.7548387096774194,0.7612903225806451,0.8129032258064516,1.264516129032258,0.4305555555555556,155 25 | 11,linw1995__data_extractor,0.6,0.676923076923077,0.6538461538461539,1.4230769230769231,0.23090586145648312,130 26 | 33,ocf__slackbridge,0.6935483870967742,0.7096774193548387,0.7741935483870968,1.3709677419354838,0.5767441860465117,124 27 | 28,payscale__fables,0.7338709677419355,0.7903225806451613,0.75,1.6774193548387097,0.3668639053254438,124 28 | 5,road-master__video-archiver,0.8536585365853658,0.9105691056910569,0.9186991869918699,1.1626016260162602,0.41694915254237286,123 29 | 19,ShadowTemplate__beautiful-python-3,0.8991596638655462,0.957983193277311,0.9411764705882353,1.1428571428571428,0.3287292817679558,119 30 | 39,ActivityWatch__aw-research,0.8205128205128205,0.8376068376068376,0.8205128205128205,1.6752136752136753,0.3667711598746082,117 31 
| 2,flopp__unicode-explorer,0.8205128205128205,0.8547008547008547,0.9230769230769231,1.1452991452991452,0.5763546798029556,117 32 | 14,antonagestam__collectfast,0.6896551724137931,0.7155172413793104,0.7586206896551724,1.2586206896551724,0.5155555555555555,116 33 | 0,srittau__FakeSMTPd,0.8362068965517241,0.8793103448275862,0.896551724137931,1.2327586206896552,0.3483483483483483,116 34 | 10,typeddjango__pytest-mypy-plugins,0.7565217391304347,0.782608695652174,0.782608695652174,1.6956521739130435,0.732484076433121,115 35 | 21,Gerschtli__teamspeak-update-notifier,0.6,0.6666666666666666,0.6857142857142857,1.161904761904762,0.5497382198952879,105 36 | 16,brettkromkamp__topic-db,0.8476190476190476,0.8761904761904762,0.9238095238095239,1.0952380952380953,0.3633217993079585,105 37 | 37,ClearcodeHQ__mirakuru,0.6907216494845361,0.7422680412371134,0.7731958762886598,1.8556701030927836,0.383399209486166,97 38 | 32,JakobGM__quelf,0.5434782608695652,0.5652173913043478,0.5543478260869565,1.184782608695652,0.5287356321839081,92 39 | 7,boompig__book-classics,0.9058823529411765,0.9411764705882353,0.9411764705882353,1.2705882352941176,0.5862068965517241,85 40 | 40,lebrice__blurred-GAN,0.7012987012987013,0.6233766233766234,0.6883116883116883,1.077922077922078,0.18421052631578946,77 41 | 24,amplify-education__python-hcl2,0.6301369863013698,0.6438356164383562,0.821917808219178,1.0273972602739727,0.73,73 42 | 47,cliffxuan__mew,0.7638888888888888,0.8333333333333334,0.7777777777777778,1.0277777777777777,0.36923076923076925,72 43 | 12,albertyw__albertyw.com,0.5285714285714286,0.6285714285714286,0.6428571428571429,1.6,0.37433155080213903,70 44 | 36,ohjames__babies,0.8507462686567164,0.8805970149253731,0.8805970149253731,1.3582089552238805,0.32524271844660196,67 45 | 15,ammarnajjar__pautomate,0.8518518518518519,0.8703703703703703,0.9259259259259259,1.3888888888888888,0.5046728971962616,54 46 | 4,jelford__webwatcher,0.7659574468085106,0.7659574468085106,0.7659574468085106,1.3191489361702127,0.2186046511627907,47 47 | 8,dropbox__sqlalchemy-stubs,0.5952380952380952,0.6666666666666666,0.6428571428571429,1.2619047619047619,0.6086956521739131,42 48 | 44,jfly__jfly.github.io,0.696969696969697,0.8181818181818182,0.9696969696969697,1.1515151515151516,0.23741007194244604,33 49 | 3,eirannejad__calcatime,0.46875,0.53125,0.5,1.9375,0.8,32 50 | 31,silasary__discord_connotations,0.6153846153846154,0.6923076923076923,0.6923076923076923,1.4615384615384615,0.2708333333333333,13 51 | 46,yeraydiazdiaz__wait_for_it.py,1.0,1.0,1.0,1.3333333333333333,0.375,6 52 | -------------------------------------------------------------------------------- /data/9.14.accs_by_project.csv: -------------------------------------------------------------------------------- 1 | ,project,non-incr,random-twice,double-traversal,label_size,label_rate,labels 2 | 17,basilisp-lang__basilisp,0.6087098886705959,0.49279633267845446,0.6175507531106745,1.2426326129666012,0.5258264462809917,3054 3 | 29,nabla-c0d3__sslyze,0.7771911298838438,0.7613516367476241,0.7697993664202746,1.2692713833157339,0.5660490137477585,947 4 | 27,kornicameister__axion,0.4884726224783862,0.4654178674351585,0.4812680115273775,1.644092219020173,0.6257889990982868,694 5 | 1,scalableminds__webknossos-connect,0.7231467473524962,0.7367624810892587,0.7609682299546142,1.2420574886535551,0.8304020100502513,661 6 | 35,lucaswerkmeister__tool-quickcategories,0.7504553734061931,0.7723132969034608,0.7905282331511839,1.5100182149362478,0.6052921719955898,549 7 | 
42,seattleflu__id3c,0.7093235831809872,0.6983546617915904,0.716636197440585,1.2084095063985374,0.4494658997534922,547 8 | 30,rakitaj__daily-programmer,0.7388888888888889,0.75,0.7666666666666667,1.5277777777777777,0.5009276437847866,540 9 | 25,nubark__instark,0.8877284595300261,0.8903394255874674,0.8772845953002611,1.1723237597911227,0.3647619047619048,383 10 | 16,brettkromkamp__topic-db,0.7261904761904762,0.6934523809523809,0.7232142857142857,1.1458333333333333,0.5989304812834224,336 11 | 48,everyclass__everyclass-server,0.7777777777777778,0.8198198198198198,0.7837837837837838,1.2012012012012012,0.3729003359462486,333 12 | 6,flopp__GpxTrackPoster,0.6863468634686347,0.7011070110701108,0.6826568265682657,1.2472324723247232,0.6545893719806763,271 13 | 13,marcosschroh__dataclasses-avroschema,0.5576923076923077,0.5423076923076923,0.5346153846153846,1.353846153846154,0.29312288613303267,260 14 | 41,webrecorder__browsertrix,0.73046875,0.71484375,0.703125,1.17578125,0.30732292917166865,256 15 | 18,paulcwatts__drf-json-schema,0.6245059288537549,0.6600790513833992,0.6561264822134387,1.4031620553359683,0.5337552742616034,253 16 | 49,knowark__estimark,0.8672199170124482,0.8962655601659751,0.8630705394190872,1.3278008298755186,0.33565459610027853,241 17 | 38,futursolo__magichttp,0.581081081081081,0.6216216216216216,0.6756756756756757,1.1486486486486487,0.35406698564593303,222 18 | 20,AxelVoitier__lookups,0.4155251141552511,0.4794520547945205,0.4474885844748858,1.5936073059360731,0.337442218798151,219 19 | 22,joshtemple__lkml,0.751219512195122,0.8341463414634146,0.8146341463414634,1.4341463414634146,0.39575289575289574,205 20 | 45,jreese__aql,0.645320197044335,0.7044334975369458,0.7044334975369458,1.438423645320197,0.44713656387665196,203 21 | 34,sonic182__aiosonic,0.8031088082901554,0.8238341968911918,0.7305699481865285,1.1347150259067358,0.28550295857988167,193 22 | 43,TomerFi__aioswitcher,0.7684210526315789,0.8421052631578947,0.8789473684210526,1.1421052631578947,0.336283185840708,190 23 | 26,albertyw__git-browse,0.7555555555555555,0.7222222222222222,0.7777777777777778,1.0777777777777777,0.46272493573264784,180 24 | 23,kitsuyui__bamboo-crawler,0.6568047337278107,0.6804733727810651,0.6804733727810651,1.7514792899408285,0.6450381679389313,169 25 | 9,reddit__baseplate.py-upgrader,0.8292682926829268,0.8292682926829268,0.8475609756097561,1.25,0.4385026737967914,164 26 | 11,linw1995__data_extractor,0.6538461538461539,0.6846153846153846,0.6692307692307692,1.4230769230769231,0.2480916030534351,130 27 | 33,ocf__slackbridge,0.7419354838709677,0.75,0.7983870967741935,1.3709677419354838,0.5767441860465117,124 28 | 28,payscale__fables,0.7903225806451613,0.7661290322580645,0.7983870967741935,1.6774193548387097,0.3668639053254438,124 29 | 5,road-master__video-archiver,0.8617886178861789,0.8617886178861789,0.8699186991869918,1.1626016260162602,0.41694915254237286,123 30 | 19,ShadowTemplate__beautiful-python-3,0.9159663865546218,0.9159663865546218,0.9159663865546218,1.1428571428571428,0.3287292817679558,119 31 | 2,flopp__unicode-explorer,0.7777777777777778,0.8632478632478633,0.8632478632478633,1.1452991452991452,0.5763546798029556,117 32 | 39,ActivityWatch__aw-research,0.8290598290598291,0.8376068376068376,0.8547008547008547,1.6752136752136753,0.3667711598746082,117 33 | 14,antonagestam__collectfast,0.7327586206896551,0.7586206896551724,0.7586206896551724,1.2586206896551724,0.5155555555555555,116 34 | 
0,srittau__FakeSMTPd,0.8706896551724138,0.9224137931034483,0.9310344827586207,1.2327586206896552,0.3483483483483483,116 35 | 10,typeddjango__pytest-mypy-plugins,0.7217391304347827,0.782608695652174,0.782608695652174,1.6956521739130435,0.732484076433121,115 36 | 21,Gerschtli__teamspeak-update-notifier,0.7619047619047619,0.7714285714285715,0.8571428571428571,1.161904761904762,0.5497382198952879,105 37 | 37,ClearcodeHQ__mirakuru,0.7731958762886598,0.8041237113402062,0.7938144329896907,1.8556701030927836,0.383399209486166,97 38 | 32,JakobGM__quelf,0.45652173913043476,0.45652173913043476,0.4673913043478261,1.184782608695652,0.5287356321839081,92 39 | 7,boompig__book-classics,0.9294117647058824,0.9294117647058824,0.9294117647058824,1.2705882352941176,0.5862068965517241,85 40 | 40,lebrice__blurred-GAN,0.7380952380952381,0.6904761904761905,0.6785714285714286,1.0714285714285714,0.18876404494382024,84 41 | 24,amplify-education__python-hcl2,0.7397260273972602,0.7808219178082192,0.821917808219178,1.0273972602739727,0.73,73 42 | 47,cliffxuan__mew,0.704225352112676,0.5915492957746479,0.5915492957746479,1.028169014084507,0.36597938144329895,71 43 | 12,albertyw__albertyw.com,0.6714285714285714,0.6571428571428571,0.6571428571428571,1.6,0.37433155080213903,70 44 | 36,ohjames__babies,0.8507462686567164,0.8208955223880597,0.8208955223880597,1.3582089552238805,0.32524271844660196,67 45 | 15,ammarnajjar__pautomate,0.9074074074074074,0.9259259259259259,0.9259259259259259,1.3888888888888888,0.5046728971962616,54 46 | 4,jelford__webwatcher,0.7659574468085106,0.7446808510638298,0.7446808510638298,1.3191489361702127,0.2186046511627907,47 47 | 8,dropbox__sqlalchemy-stubs,0.6904761904761905,0.6904761904761905,0.6904761904761905,1.2619047619047619,0.6086956521739131,42 48 | 44,jfly__jfly.github.io,0.7878787878787878,0.8181818181818182,0.8181818181818182,1.1515151515151516,0.23741007194244604,33 49 | 3,eirannejad__calcatime,0.90625,0.90625,0.90625,1.9375,0.8,32 50 | 31,silasary__discord_connotations,0.6666666666666666,0.6666666666666666,0.6666666666666666,1.4,0.28846153846153844,15 51 | 46,yeraydiazdiaz__wait_for_it.py,1.0,1.0,1.0,1.3333333333333333,0.375,6 52 | -------------------------------------------------------------------------------- /data/TypeT5-Workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/utopia-group/TypeT5/d8ff8638f4d00f03042db5780a8d4fa09a72916d/data/TypeT5-Workflow.png -------------------------------------------------------------------------------- /data/code/bad_code_1.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | # A recursive fibonacci function 5 | def fib(n: str) -> list[int]: 6 | if n == 0: 7 | return 0 8 | elif n == 1: 9 | return 1 10 | else: 11 | return fib(n - 1) + fib(n - 2) 12 | 13 | 14 | def t_add(x: str, y: str) -> int: 15 | r = x + y 16 | return r 17 | 18 | 19 | x: int = fib(3) 20 | bad_y: str = 1 21 | -------------------------------------------------------------------------------- /data/code/bad_code_2.py: -------------------------------------------------------------------------------- 1 | from bad_code_1 import fib 2 | 3 | i: int = 4 4 | fib(i) 5 | -------------------------------------------------------------------------------- /data/code/code_with_slash.py: -------------------------------------------------------------------------------- 1 | class SlashClass: 2 | def __init__(self, check_interval: int, folder: Path, /) -> None: 3 | 
self._autolocked: Dict[Path, int] = {} 4 | self._lockers: Dict[Path, "DirectEdit"] = {} 5 | self._to_lock: Items = [] 6 | -------------------------------------------------------------------------------- /data/code/dummy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/utopia-group/TypeT5/d8ff8638f4d00f03042db5780a8d4fa09a72916d/data/code/dummy/__init__.py -------------------------------------------------------------------------------- /data/code/dummy/dummy_1.py: -------------------------------------------------------------------------------- 1 | def f_int(x: int) -> int: 2 | return x 3 | -------------------------------------------------------------------------------- /data/code/dummy/dummy_2.py: -------------------------------------------------------------------------------- 1 | from dummy.dummy_1 import f_int 2 | 3 | s: str = f_int(2) 4 | -------------------------------------------------------------------------------- /data/code/env_code_1.py: -------------------------------------------------------------------------------- 1 | # Env example 1: no existing annotations 2 | 3 | 4 | def fib(n): 5 | if n == 0: 6 | return 0 7 | elif n == 1: 8 | return 1 9 | else: 10 | return fib(n - 1) + fib(n - 2) 11 | 12 | 13 | def foo(bar): 14 | return fib(bar) 15 | 16 | 17 | def int_add(a, b): 18 | return a + b + "c" 19 | 20 | 21 | def int_tripple_add(a, b, c): 22 | return a + b + c 23 | -------------------------------------------------------------------------------- /data/code/env_code_2.py: -------------------------------------------------------------------------------- 1 | # Env example 2: some existing annotations 2 | 3 | from typing import * 4 | 5 | 6 | def fib(n: int): 7 | if n == 0: 8 | return 0 9 | elif n == 1: 10 | return 1 11 | else: 12 | return fib(n - 1) + fib(n - 2) 13 | 14 | 15 | def foo(bar: int): 16 | return fib(bar) 17 | 18 | 19 | class Bar: 20 | z: str = "hello" 21 | w: str 22 | 23 | def __init__(self, x: int): 24 | self.x: int = x 25 | self.y: Optional[int] = None 26 | self.reset(self.z) 27 | 28 | def reset(self, w0): 29 | self.w = w0 30 | 31 | def foo(self, z: str) -> int: 32 | return self.x + len(z) 33 | 34 | 35 | bar: Bar = Bar(3) 36 | -------------------------------------------------------------------------------- /data/code/good_code_1.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any # [added by SPOT] 3 | from typing import Optional 4 | 5 | print(math.sin(4)) 6 | 7 | x_str: str = "x" 8 | y: Any = 1 9 | z_str: str = x_str + y 10 | 11 | 12 | class Foo: 13 | def __init__(self, x: int): 14 | self.x: int = x 15 | self.y: Optional[int] = None 16 | self.z = "hello" 17 | 18 | def foo(self, z: str) -> int: 19 | return self.x + len(z) 20 | -------------------------------------------------------------------------------- /data/ex_repo/ex_code_1.py: -------------------------------------------------------------------------------- 1 | # Env example 1: no existing annotations 2 | 3 | good = 5 4 | 5 | 6 | def fib(n): 7 | if n == 0: 8 | return 0 9 | elif n == 1: 10 | return 1 11 | else: 12 | return fib(n - 1) + fib(n - 2) 13 | 14 | 15 | class Wrapper: 16 | x_elem: int 17 | y: str 18 | 19 | @staticmethod 20 | def foo(bar): 21 | return fib(bar) 22 | 23 | def inc(self): 24 | self.x_elem += 1 25 | return self.y 26 | 27 | 28 | def int_add(a, b): 29 | # this is a strange function 30 | return a + b + "c" 31 | 32 | 33 | def int_tripple_add(a, b, c): 34 
| return a + b + c 35 | -------------------------------------------------------------------------------- /data/ex_repo/ex_code_2.py: -------------------------------------------------------------------------------- 1 | # Env example 2: some existing annotations 2 | 3 | from typing import * 4 | from ex_code_1 import int_add 5 | 6 | 7 | def fib(n: int): 8 | if n == 0: 9 | return 0 10 | elif n == 1: 11 | return 1 12 | else: 13 | return fib(n - 1) + fib(n - 2) 14 | 15 | 16 | def foo(bar: int): 17 | return fib(bar) 18 | 19 | 20 | class Bar: 21 | z: str = "hello" 22 | w: str 23 | 24 | def __init__(self, x: int): 25 | self.x: int = x 26 | self.y: Optional[int] = None 27 | self.reset(self.z) 28 | 29 | def reset(self, w0): 30 | self.w = w0 31 | 32 | def foo(self, z: str) -> int: 33 | return int_add(self.x, len(z)) 34 | 35 | 36 | bar: Bar = Bar(3) 37 | -------------------------------------------------------------------------------- /data/repos_split.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/utopia-group/TypeT5/d8ff8638f4d00f03042db5780a8d4fa09a72916d/data/repos_split.pkl -------------------------------------------------------------------------------- /data/useful_repos.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/utopia-group/TypeT5/d8ff8638f4d00f03042db5780a8d4fa09a72916d/data/useful_repos.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -i https://pypi.org/simple 2 | absl-py==1.4.0 ; python_version >= '3.6' 3 | aiohttp==3.8.4 ; python_version >= '3.6' 4 | aiosignal==1.3.1 ; python_version >= '3.7' 5 | appdirs==1.4.4 6 | asttokens==2.2.1 7 | async-timeout==4.0.2 ; python_version >= '3.6' 8 | attrs==22.2.0 ; python_version >= '3.6' 9 | backcall==0.2.0 10 | cachetools==5.3.0 ; python_version ~= '3.7' 11 | certifi==2022.12.7 ; python_version >= '3.6' 12 | cffi==1.15.1 13 | charset-normalizer==3.0.1 ; python_version >= '3.6' 14 | click==8.1.3 ; python_version >= '3.7' 15 | colored==1.4.4 16 | comm==0.1.2 ; python_version >= '3.6' 17 | cryptography==39.0.1 ; python_version >= '3.6' 18 | datasets==2.4.0 19 | dateparser==1.1.7 20 | debugpy==1.6.6 ; python_version >= '3.7' 21 | decorator==5.1.1 ; python_version >= '3.5' 22 | dill==0.3.5.1 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6' 23 | docker-pycreds==0.4.0 24 | executing==1.2.0 25 | filelock==3.9.0 ; python_version >= '3.7' 26 | frozenlist==1.3.3 ; python_version >= '3.7' 27 | fsspec[http]==2023.1.0 ; python_version >= '3.7' 28 | gitdb==4.0.10 ; python_version >= '3.7' 29 | gitpython==3.1.31 ; python_version >= '3.7' 30 | google-auth==2.16.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5' 31 | google-auth-oauthlib==0.4.6 ; python_version >= '3.6' 32 | grpcio==1.51.1 ; python_version >= '3.7' 33 | huggingface-hub==0.12.0 34 | idna==3.4 ; python_version >= '3.5' 35 | ipykernel==6.21.2 36 | ipython==8.10.0 ; python_version >= '3.8' 37 | ipywidgets==8.0.4 38 | jedi==0.18.2 ; python_version >= '3.6' 39 | jupyter-client==8.0.3 ; python_version >= '3.8' 40 | jupyter-core==5.2.0 ; python_version >= '3.8' 41 | jupyterlab-widgets==3.0.5 ; python_version >= '3.7' 42 | libcst==0.4.2 43 | markdown==3.4.1 ; python_version >= '3.7' 44 | markupsafe==2.1.2 ; python_version >= '3.7' 45 | 
matplotlib-inline==0.1.6 ; python_version >= '3.5' 46 | multidict==6.0.4 ; python_version >= '3.7' 47 | multiprocess==0.70.13 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6' 48 | mypy==0.991 49 | mypy-extensions==1.0.0 ; python_version >= '3.5' 50 | nest-asyncio==1.5.6 ; python_version >= '3.5' 51 | numpy==1.24.2 ; python_version >= '3.8' 52 | nvidia-cublas-cu11==11.10.3.66 ; platform_system == 'Linux' 53 | nvidia-cuda-nvrtc-cu11==11.7.99 ; platform_system == 'Linux' 54 | nvidia-cuda-runtime-cu11==11.7.99 ; platform_system == 'Linux' 55 | nvidia-cudnn-cu11==8.5.0.96 ; platform_system == 'Linux' 56 | oauthlib==3.2.2 ; python_version >= '3.6' 57 | packaging==23.0 ; python_version >= '3.7' 58 | pandas==1.5.3 59 | parso==0.8.3 ; python_version >= '3.6' 60 | pathtools==0.1.2 61 | pexpect==4.8.0 ; sys_platform != 'win32' 62 | pickleshare==0.7.5 63 | platformdirs==3.0.0 ; python_version >= '3.7' 64 | plotly==5.13.0 65 | prettytable==3.4.1 66 | prompt-toolkit==3.0.36 ; python_full_version >= '3.6.2' 67 | protobuf==4.22.0 ; python_version >= '3.10' and sys_platform == 'linux' 68 | psutil==5.9.4 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' 69 | ptyprocess==0.7.0 70 | pure-eval==0.2.2 71 | pyarrow==11.0.0 ; python_version >= '3.7' 72 | pyasn1==0.4.8 73 | pyasn1-modules==0.2.8 74 | pycparser==2.21 75 | pydeprecate==0.3.2 ; python_version >= '3.6' 76 | pygments==2.14.0 ; python_version >= '3.6' 77 | pyrsistent==0.19.3 78 | python-dateutil==2.8.2 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' 79 | pytorch-lightning==1.7.7 80 | pytz==2022.7.1 81 | pytz-deprecation-shim==0.1.0.post0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5' 82 | pyyaml==6.0 ; python_version >= '3.6' 83 | pyzmq==25.0.0 ; python_version >= '3.6' 84 | regex==2022.10.31 ; python_version >= '3.6' 85 | requests==2.28.2 ; python_version >= '3.7' and python_version < '4' 86 | requests-oauthlib==1.3.1 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' 87 | responses==0.18.0 ; python_version >= '3.7' 88 | rsa==4.9 ; python_version >= '3.6' 89 | sentry-sdk==1.15.0 90 | setproctitle==1.3.2 ; python_version >= '3.7' 91 | setuptools==67.3.2 ; python_version >= '3.7' 92 | six==1.16.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' 93 | smmap==5.0.0 ; python_version >= '3.6' 94 | stack-data==0.6.2 95 | tenacity==8.2.1 ; python_version >= '3.6' 96 | tensorboard==2.12.0 ; python_version >= '3.8' 97 | tensorboard-data-server==0.7.0 ; python_version >= '3.7' 98 | tensorboard-plugin-wit==1.8.1 99 | termcolor==1.1.0 100 | tokenizers==0.12.1 101 | tomli==2.0.1 ; python_version < '3.11' 102 | torch==1.13.1 ; python_full_version >= '3.7.0' 103 | torchmetrics==0.11.1 ; python_version >= '3.7' 104 | tornado==6.2 ; python_version >= '3.7' 105 | tqdm==4.64.1 106 | traitlets==5.9.0 ; python_version >= '3.7' 107 | transformers==4.21.3 108 | types-aiofiles==22.1.0.8 109 | types-all==1.0.0 110 | types-annoy==1.17.8.1 111 | types-atomicwrites==1.4.5.1 112 | types-backports==0.1.3 113 | types-backports-abc==0.5.2 114 | types-bleach==6.0.0.0 115 | types-boto==2.49.18.5 116 | types-cachetools==5.3.0.0 117 | types-certifi==2021.10.8.3 118 | types-cffi==1.15.1.5 119 | types-characteristic==14.3.7 120 | types-chardet==5.0.4.1 121 | types-click==7.1.8 122 | types-click-spinner==0.1.13.2 123 | types-colorama==0.4.15.7 124 | types-contextvars==2.4.7 125 | types-croniter==1.3.2.4 126 | 
types-cryptography==3.3.23.2 127 | types-dataclasses==0.6.6 128 | types-dateparser==1.1.4.7 129 | types-datetimerange==2.0.0.1 130 | types-decorator==5.1.8.2 131 | types-deprecated==1.2.9 132 | types-docopt==0.6.11.1 133 | types-docutils==0.19.1.4 134 | types-emoji==2.1.0.1 135 | types-enum34==1.1.8 136 | types-fb303==1.0.0 137 | types-filelock==3.2.7 138 | types-first==2.0.5 139 | types-flask==1.1.6 140 | types-freezegun==1.1.10 141 | types-frozendict==2.0.9 142 | types-futures==3.3.8 143 | types-geoip2==3.0.0 144 | types-ipaddress==1.0.8 145 | types-itsdangerous==1.1.6 146 | types-jack-client==0.5.10.5 147 | types-jinja2==2.11.9 148 | types-kazoo==0.1.3 149 | types-markdown==3.4.2.4 150 | types-markupsafe==1.1.10 151 | types-maxminddb==1.5.0 152 | types-mock==5.0.0.4 153 | types-mypy-extensions==1.0.0.1 154 | types-nmap==0.1.6 155 | types-openssl-python==0.1.3 156 | types-orjson==3.6.2 157 | types-paramiko==3.0.0.3 158 | types-pathlib2==2.3.0 159 | types-pillow==9.4.0.12 160 | types-pkg-resources==0.1.3 161 | types-polib==1.1.12.1 162 | types-protobuf==4.21.0.6 163 | types-pyaudio==0.2.16.5 164 | types-pycurl==7.45.2.2 165 | types-pyfarmhash==0.3.1 166 | types-pyjwt==1.7.1 167 | types-pymssql==2.1.0 168 | types-pymysql==1.0.19.3 169 | types-pyopenssl==23.0.0.3 170 | types-pyrfc3339==1.1.1.2 171 | types-pysftp==0.2.17.1 172 | types-python-dateutil==2.8.19.7 173 | types-python-gflags==3.1.7.1 174 | types-python-slugify==8.0.0.0 175 | types-pytz==2022.7.1.0 176 | types-pyvmomi==8.0.0.0 177 | types-pyyaml==6.0.12.6 178 | types-redis==4.5.1.1 179 | types-requests==2.28.11.13 180 | types-retry==0.9.9.1 181 | types-routes==2.5.0 182 | types-scribe==2.0.0 183 | types-simplejson==3.18.0.0 184 | types-singledispatch==4.0.0.0 185 | types-six==1.16.21.4 186 | types-tabulate==0.9.0.0 187 | types-termcolor==1.1.6 188 | types-toml==0.10.8.4 189 | types-tornado==5.1.1 190 | types-typed-ast==1.5.8.3 191 | types-tzlocal==4.2.2.2 192 | types-ujson==5.7.0.0 193 | types-urllib3==1.26.25.6 194 | types-waitress==2.1.4.4 195 | types-werkzeug==1.0.9 196 | types-xxhash==3.0.5.1 197 | -e . 
198 | typing-extensions==4.5.0 ; python_version >= '3.7' 199 | typing-inspect==0.8.0 200 | tzdata==2022.7 ; python_version >= '3.6' 201 | tzlocal==4.2 ; python_version >= '3.6' 202 | urllib3==1.26.14 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5' 203 | wandb==0.13.10 204 | wcwidth==0.2.6 205 | werkzeug==2.2.3 ; python_version >= '3.7' 206 | wheel==0.38.4 ; python_version >= '3.7' 207 | widgetsnbextension==4.0.5 ; python_version >= '3.7' 208 | xxhash==3.2.0 ; python_version >= '3.6' 209 | yarl==1.8.2 ; python_version >= '3.7' 210 | -------------------------------------------------------------------------------- /scripts/archive/analyze_dagger.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "import os\n", 13 | "from typing import *\n", 14 | "\n", 15 | "from typet5.utils import proj_root, get_data_dir\n", 16 | "\n", 17 | "os.chdir(proj_root())\n", 18 | "\n", 19 | "datadir = get_data_dir()" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stderr", 29 | "output_type": "stream", 30 | "text": [ 31 | "/home/jiayi/Projects/SPOT/.venv/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory\n", 32 | " warn(f\"Failed to load image Python extension: {e}\")\n" 33 | ] 34 | }, 35 | { 36 | "name": "stdout", 37 | "output_type": "stream", 38 | "text": [ 39 | "quicktest=False\n", 40 | "Loading datasets: tk_dataset-all_labels-drop_comments\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "# experiment configurations\n", 46 | "\n", 47 | "from typet5.data import (\n", 48 | " TokenizedSrcSet,\n", 49 | " get_dataset_name,\n", 50 | " load_tokenized_srcsets,\n", 51 | " TypeCheckSettings,\n", 52 | ")\n", 53 | "from typet5.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper\n", 54 | "from typet5.train import TrainingConfig, TypeCheckArgs\n", 55 | "\n", 56 | "config = TrainingConfig(\n", 57 | " quicktest=False,\n", 58 | " all_labels=True,\n", 59 | " ctx_size=2048,\n", 60 | " left_margin=1024,\n", 61 | " right_margin=1023,\n", 62 | " modifications=\"no_type_checker\",\n", 63 | ")\n", 64 | "gpu_id = 1\n", 65 | "TypeCheckSettings.temp_path = f\"DAgger-{gpu_id}\"\n", 66 | "\n", 67 | "print(f\"quicktest={config.quicktest}\")\n", 68 | "\n", 69 | "project_name = \"test-SPOT\" if config.quicktest else \"SPOT\"\n", 70 | "train_ctx_args = config.train_ctx_args()\n", 71 | "tc_args = TypeCheckArgs(check_in_isolation=config.check_in_isolation)\n", 72 | "\n", 73 | "datasets_name = get_dataset_name(\n", 74 | " drop_comments=config.drop_comments,\n", 75 | " all_labels=config.all_labels,\n", 76 | " imports_in_preamble=config.imports_in_preamble,\n", 77 | ")\n", 78 | "\n", 79 | "model_name = \"DAgger-model--\" + config.as_name()\n", 80 | "\n", 81 | "tk_dataset = load_tokenized_srcsets(\n", 82 | " datadir,\n", 83 | " datasets_name,\n", 84 | " data_reduction=config.data_reduction,\n", 85 | " quicktest=config.quicktest,\n", 86 | ")" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 7, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "# load the model\n", 96 | "from typet5.model import load_model_spot, 
DefaultTokenizer\n", 97 | "from typet5.model import ModelWrapper\n", 98 | "from typet5.dagger import DAggerModel\n", 99 | "import torch\n", 100 | "\n", 101 | "dec_args = DecodingArgs(\n", 102 | " sampling_max_tokens=8 * config.ctx_size,\n", 103 | " ctx_args=config.dec_ctx_args(),\n", 104 | " do_sample=True,\n", 105 | " num_beams=None, # try greedy decoding\n", 106 | " top_p=0.9,\n", 107 | ")\n", 108 | "\n", 109 | "wrapper = ModelWrapper.from_pretrained(datadir / f\"checkpoints/saved/{model_name}\")\n", 110 | "device = torch.device(f\"cuda:{gpu_id}\" if torch.cuda.is_available() else \"cpu\")\n", 111 | "wrapper.to(device)\n", 112 | "dmodel = DAggerModel(wrapper)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 5, 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stderr", 122 | "output_type": "stream", 123 | "text": [ 124 | "compute_preexisting_fdbks: 100%|██████████| 50/50 [00:04<00:00, 11.99it/s]\n", 125 | "eval_on_data: 100%|██████████| 16950/16950 [46:45<00:00, 6.04it/s]\n" 126 | ] 127 | }, 128 | { 129 | "name": "stdout", 130 | "output_type": "stream", 131 | "text": [ 132 | "partial_acc (ImNone): 69.01% (count=16.9k)\n", 133 | "full_acc (ImNone): 64.86% (count=16.9k)\n", 134 | "partial_acc: 67.43% (count=16.9k)\n", 135 | "ast_acc: 56.81% (count=21.3k)\n", 136 | "full_acc: 62.03% (count=16.9k)\n", 137 | "partial_acc_by_cat:\n", 138 | " FuncArg: 62.96% (count=8.0k)\n", 139 | " FuncReturn: 78.14% (count=5.7k)\n", 140 | " ClassAtribute: 58.21% (count=2.7k)\n", 141 | " GlobalVar: 75.96% (count=104)\n", 142 | " LocalVar: 64.22% (count=531)\n", 143 | "partial_acc_by_pos:\n", 144 | " range(0, 1): 80.39% (count=933)\n", 145 | " range(1, 2): 77.13% (count=870)\n", 146 | " range(2, 4): 77.58% (count=1.5k)\n", 147 | " range(4, 8): 74.05% (count=2.4k)\n", 148 | " range(8, 16): 72.80% (count=3.1k)\n", 149 | " range(16, 32): 67.31% (count=3.2k)\n", 150 | " range(32, 64): 63.86% (count=2.3k)\n", 151 | " range(64, 128): 53.42% (count=1.1k)\n", 152 | " range(128, 256): 40.00% (count=735)\n", 153 | " range(256, 512): 32.89% (count=672)\n", 154 | " range(512, 1024): 52.83% (count=53)\n", 155 | "avg_label_size: 1.2589\n", 156 | "avg_pred_size: 1.1258\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "# evaluate (greedy)\n", 162 | "from typet5.utils import pretty_print_dict, pretty_show_dict\n", 163 | "from typet5.visualization import visualize_preds_on_code\n", 164 | "\n", 165 | "eval_r = await dmodel.eval_on_data(tk_dataset[\"test\"])\n", 166 | "pretty_print_dict(eval_r.accuracies)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 10, 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "name": "stderr", 176 | "output_type": "stream", 177 | "text": [ 178 | "compute_preexisting_fdbks: 100%|██████████| 6/6 [00:02<00:00, 2.16it/s]\n", 179 | "eval_on_data: 100%|██████████| 180/180 [04:57<00:00, 1.65s/it]\n" 180 | ] 181 | }, 182 | { 183 | "name": "stdout", 184 | "output_type": "stream", 185 | "text": [ 186 | "partial_acc (ImNone): 63.89% (count=180)\n", 187 | "full_acc (ImNone): 61.11% (count=180)\n", 188 | "partial_acc: 65.00% (count=180)\n", 189 | "ast_acc: 49.59% (count=244)\n", 190 | "full_acc: 60.56% (count=180)\n", 191 | "partial_acc_by_cat:\n", 192 | " FuncArg: 58.42% (count=101)\n", 193 | " FuncReturn: 71.23% (count=73)\n", 194 | " GlobalVar: 100.00% (count=3)\n", 195 | " LocalVar: 100.00% (count=3)\n", 196 | "partial_acc_by_pos:\n", 197 | " range(0, 1): 90.91% (count=11)\n", 198 | " range(1, 2): 100.00% (count=8)\n", 199 | " 
range(2, 4): 100.00% (count=12)\n", 200 | " range(4, 8): 69.23% (count=13)\n", 201 | " range(8, 16): 54.55% (count=11)\n", 202 | " range(16, 32): 43.75% (count=16)\n", 203 | " range(32, 64): 50.00% (count=32)\n", 204 | " range(64, 128): 60.94% (count=64)\n", 205 | " range(128, 256): 76.92% (count=13)\n", 206 | "avg_label_size: 1.3556\n", 207 | "avg_pred_size: 1.1\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "# evaluate\n", 213 | "from numpy import roll\n", 214 | "from typet5.utils import pretty_print_dict, pretty_show_dict\n", 215 | "from typet5.visualization import visualize_preds_on_code\n", 216 | "\n", 217 | "dmodel.wrapper.args = DecodingArgs(\n", 218 | " sampling_max_tokens=8 * config.ctx_size,\n", 219 | " ctx_args=config.dec_ctx_args(),\n", 220 | " do_sample=True, # use necleus sampling during training\n", 221 | " top_p=0.9,\n", 222 | ")\n", 223 | "\n", 224 | "eval_r = await dmodel.eval_on_data(tk_dataset[\"train\"][1:105:10])\n", 225 | "pretty_print_dict(eval_r.accuracies)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 12, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "name": "stderr", 235 | "output_type": "stream", 236 | "text": [ 237 | "Exporting: 100%|██████████| 11/11 [00:00<00:00, 131.04it/s]\n", 238 | "Computing accuracies: 100%|██████████| 11/11 [00:00<00:00, 9128.88it/s]\n" 239 | ] 240 | } 241 | ], 242 | "source": [ 243 | "from typet5.data import TokenizedSrcSet\n", 244 | "from typet5.visualization import export_preds_on_code\n", 245 | "\n", 246 | "viz_ds = TokenizedSrcSet(tk_dataset[\"test\"].repos_root, eval_r.final_srcs)\n", 247 | "viz_preds = eval_r.final_preds\n", 248 | "\n", 249 | "export_preds_on_code(viz_ds, viz_preds, proj_root() / \"caches/DAgger-preds-on-code\")" 250 | ] 251 | } 252 | ], 253 | "metadata": { 254 | "kernelspec": { 255 | "display_name": "Python 3.10.4 ('.venv': pipenv)", 256 | "language": "python", 257 | "name": "python3" 258 | }, 259 | "language_info": { 260 | "codemirror_mode": { 261 | "name": "ipython", 262 | "version": 3 263 | }, 264 | "file_extension": ".py", 265 | "mimetype": "text/x-python", 266 | "name": "python", 267 | "nbconvert_exporter": "python", 268 | "pygments_lexer": "ipython3", 269 | "version": "3.10.4" 270 | }, 271 | "orig_nbformat": 4, 272 | "vscode": { 273 | "interpreter": { 274 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8" 275 | } 276 | } 277 | }, 278 | "nbformat": 4, 279 | "nbformat_minor": 2 280 | } 281 | -------------------------------------------------------------------------------- /scripts/archive/code-t5-workflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%load_ext autoreload\n", 17 | "%autoreload 2\n", 18 | "\n", 19 | "import os\n", 20 | "import pickle\n", 21 | "from concurrent.futures import ProcessPoolExecutor\n", 22 | "from pathlib import Path\n", 23 | "from typing import *\n", 24 | "\n", 25 | "import pandas as pd\n", 26 | "import plotly.express as px\n", 27 | "\n", 28 | "from typet5.data import GitRepo, ModuleRemapUnpickler\n", 29 | "from typet5.type_env import (\n", 30 | " AnnotPath,\n", 31 | " MypyChecker,\n", 32 | " SelectAnnotations,\n", 33 | " TypeInfAction,\n", 34 | " TypeInfEnv,\n", 35 | " TypeInfState,\n", 36 | " 
collect_annotations,\n", 37 | " mypy_checker,\n", 38 | ")\n", 39 | "from typet5.utils import cst, proj_root, read_file, seq_flatten, tqdm, write_file\n", 40 | "\n", 41 | "os.chdir(proj_root())\n", 42 | "\n", 43 | "datadir = Path(os.getenv(\"datadir\"))\n", 44 | "repos_dir = datadir / \"SPOT-data/repos\"\n", 45 | "\n", 46 | "useful_repos_path = proj_root() / \"scripts\" / \"useful_repos.pkl\"\n", 47 | "rename_module = lambda n: \"typet5.data\" if n == \"typet5.data_prepare\" else n\n", 48 | "with useful_repos_path.open(\"rb\") as f:\n", 49 | " useful_repos: list[GitRepo] = ModuleRemapUnpickler(f, rename_module).load()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "name": "stderr", 59 | "output_type": "stream", 60 | "text": [ 61 | "/home/jiayi/Projects/SPOT/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1402: UserWarning: positional arguments and argument \"destination\" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n", 62 | " warnings.warn(\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "# loading pre-trained model and tokenizer\n", 68 | "\n", 69 | "model_dir = datadir/\"checkpoints/saved/SPOT-CodeT5-with_margin/\"\n", 70 | "\n", 71 | "import torch\n", 72 | "from transformers import (\n", 73 | " DataCollatorForSeq2Seq,\n", 74 | " RobertaTokenizer,\n", 75 | " T5ForConditionalGeneration,\n", 76 | ")\n", 77 | "from transformers.models.t5 import T5ForConditionalGeneration\n", 78 | "\n", 79 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 80 | "tokenizer: RobertaTokenizer = RobertaTokenizer.from_pretrained(model_dir)\n", 81 | "model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(\n", 82 | " model_dir\n", 83 | ").to(device)\n", 84 | "max_target_length = 128" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 25, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "from typet5.data import mask_type_annots, output_ids_as_types, tokenize_masked\n", 94 | "\n", 95 | "test_code = \"\"\"\n", 96 | "@dataclass\n", 97 | "class GitRepo:\n", 98 | " author: str\n", 99 | " name: str\n", 100 | " url: str\n", 101 | " stars: int\n", 102 | " forks: int\n", 103 | "\n", 104 | " def authorname(self):\n", 105 | " return self.author + \"__\" + self.name\n", 106 | "\n", 107 | " def repo_dir(self, repos_dir: Path) -> Path:\n", 108 | " return repos_dir / \"downloaded\" / self.authorname()\n", 109 | "\n", 110 | " def download(self, repos_dir: Path, timeout=None) -> bool:\n", 111 | " pass\n", 112 | "\"\"\"\n", 113 | "\n", 114 | "\n", 115 | "def run_model(code: str, num_beams=16):\n", 116 | " tks = tokenize_masked(mask_type_annots((Path('no_source'), code)), tokenizer, device)\n", 117 | " input_ids = tks[\"input_ids\"]\n", 118 | " with torch.no_grad():\n", 119 | " loss = model.forward(**tks).loss\n", 120 | " dec = model.generate(\n", 121 | " input_ids,\n", 122 | " max_length=max_target_length,\n", 123 | " num_beams=num_beams,\n", 124 | " # do_sample=True,\n", 125 | " )[0]\n", 126 | " return {\n", 127 | " \"loss\": loss,\n", 128 | " \"predicted_types\": output_ids_as_types(dec, tokenizer),\n", 129 | " \"labels\": output_ids_as_types(tks[\"labels\"][0], tokenizer),\n", 130 | " \"generation\": tokenizer.decode(dec),\n", 131 | " \"input_ids\": input_ids[0],\n", 132 | " \"output_ids\": dec,\n", 133 | " }\n", 
134 | "\n", 135 | "\n", 136 | "result = run_model(test_code, num_beams=10)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 27, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "name": "stdout", 146 | "output_type": "stream", 147 | "text": [ 148 | "\n", 149 | "@dataclass\n", 150 | "class GitRepo:\n", 151 | " author:\n", 152 | " name:\n", 153 | " url:\n", 154 | " stars:\n", 155 | " forks:\n", 156 | "\n", 157 | " def authorname(self):\n", 158 | " return self.author + \"__\" + self.name\n", 159 | "\n", 160 | " def repo_dir(self, repos_dir:) ->:\n", 161 | " return repos_dir / \"downloaded\" / self.authorname()\n", 162 | "\n", 163 | " def download(self, repos_dir:, timeout=None) ->:\n", 164 | " pass\n", 165 | "\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "# Step 1: Replace all types to predict with special tokens\n", 171 | "print(tokenizer.decode(result['input_ids']))" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 28, 177 | "metadata": {}, 178 | "outputs": [ 179 | { 180 | "name": "stdout", 181 | "output_type": "stream", 182 | "text": [ 183 | "['', 'Ċ', '@', 'data', 'class', 'Ċ', 'class', 'ĠGit', 'Repo', ':', 'Ċ', 'ĠĠĠ', 'Ġauthor', ':', '', 'Ċ', 'ĠĠĠ', 'Ġname', ':', '', 'Ċ', 'ĠĠĠ', 'Ġurl', ':', '', 'Ċ', 'ĠĠĠ', 'Ġstars', ':', '', 'Ċ', 'ĠĠĠ', 'Ġfor', 'ks', ':', '', 'Ċ', 'Ċ', 'ĠĠĠ', 'Ġdef', 'Ġauthor', 'name', '(', 'self', '):', 'Ċ', 'ĠĠĠĠĠĠĠ', 'Ġreturn', 'Ġself', '.', 'author', 'Ġ+', 'Ġ\"__', '\"', 'Ġ+', 'Ġself', '.', 'name', 'Ċ', 'Ċ', 'ĠĠĠ', 'Ġdef', 'Ġrepo', '_', 'dir', '(', 'self', ',', 'Ġrepos', '_', 'dir', ':', '', ')', 'Ġ->', '', ':', 'Ċ', 'ĠĠĠĠĠĠĠ', 'Ġreturn', 'Ġrepos', '_', 'dir', 'Ġ/', 'Ġ\"', 'down', 'loaded', '\"', 'Ġ/', 'Ġself', '.', 'author', 'name', '()', 'Ċ', 'Ċ', 'ĠĠĠ', 'Ġdef', 'Ġdownload', '(', 'self', ',', 'Ġrepos', '_', 'dir', ':', '', ',', 'Ġtimeout', '=', 'None', ')', 'Ġ->', '', ':', 'Ċ', 'ĠĠĠĠĠĠĠ', 'Ġpass', 'Ċ', '']\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "# Step 2: Tokenize using Byte Pair Encoding (BPE)\n", 189 | "print(tokenizer.convert_ids_to_tokens(result['input_ids']))" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 29, 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "name": "stdout", 199 | "output_type": "stream", 200 | "text": [ 201 | "['', '', '', 'str', '', 'str', '', 'str', '', 'List', '[', 'str', ']', 'Ġ+', 'ĠList', '[', 'str', ']', '', 'List', '[', 'str', ']', 'Ġ+', 'ĠList', '[', 'str', ']', 'Ġ+', 'ĠList', '[', 'str', ']', '', 'Path', '', 'Path', 'Ġ.', 'ĠPath', '', 'Path', 'Ġ.', 'ĠPath', '', 'Path', 'Ġ.', 'ĠPath', 'Ġ[', 'Ġstr', ']', '']\n" 202 | ] 203 | } 204 | ], 205 | "source": [ 206 | "# Step 3: Let model predict a sequence of types using BPE\n", 207 | "print(tokenizer.convert_ids_to_tokens(result['output_ids']))" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 30, 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "name": "stdout", 217 | "output_type": "stream", 218 | "text": [ 219 | "[str, str, str, Any, Any, Path, Path.Path, Path.Path, Path.Path[str]]\n" 220 | ] 221 | } 222 | ], 223 | "source": [ 224 | "# Step 4: Extract the predicted types\n", 225 | "print(result['predicted_types'])" 226 | ] 227 | } 228 | ], 229 | "metadata": { 230 | "interpreter": { 231 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8" 232 | }, 233 | "kernelspec": { 234 | "display_name": "Python 3.9.7 ('.venv': pipenv)", 235 | "language": "python", 236 | "name": "python3" 237 | }, 238 | "language_info": { 239 | 
"codemirror_mode": { 240 | "name": "ipython", 241 | "version": 3 242 | }, 243 | "file_extension": ".py", 244 | "mimetype": "text/x-python", 245 | "name": "python", 246 | "nbconvert_exporter": "python", 247 | "pygments_lexer": "ipython3", 248 | "version": "3.10.4" 249 | }, 250 | "orig_nbformat": 4 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 2 254 | } 255 | -------------------------------------------------------------------------------- /scripts/archive/debug_mypy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 26, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "Daemon stopped\n", 13 | "Daemon started\n", 14 | "---checking code_1---\n", 15 | "Success: no issues found in 1 source file\n", 16 | "---checking code_2---\n", 17 | "code.py:5: error: Incompatible return value type (got \"int\", expected \"str\")\n", 18 | "code.py:7: error: Incompatible return value type (got \"int\", expected \"str\")\n", 19 | "Found 2 errors in 1 file (checked 2 source files)\n", 20 | "---checking code_3---\n", 21 | "code.py:5: error: Incompatible return value type (got \"int\", expected \"str\")\n", 22 | "code.py:7: error: Incompatible return value type (got \"int\", expected \"str\")\n", 23 | "Found 2 errors in 1 file (checked 2 source files)\n", 24 | "test finished.\n", 25 | "---wait and check code_3 again---\n", 26 | "Success: no issues found in 2 source files\n", 27 | "test finished.\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "from pathlib import Path\n", 33 | "import subprocess\n", 34 | "import time\n", 35 | "\n", 36 | "# no type error\n", 37 | "code_1 = '''\n", 38 | "from typing import Any\n", 39 | "def fib(n: int) -> int:\n", 40 | " if n == 0:\n", 41 | " return 0\n", 42 | " elif n == 1:\n", 43 | " return 1\n", 44 | " else:\n", 45 | " return fib(n-1) + fib(n-2)\n", 46 | "'''\n", 47 | "\n", 48 | "# incorrect return type\n", 49 | "code_2 = '''\n", 50 | "from typing import Any\n", 51 | "def fib(n: int) -> str:\n", 52 | " if n == 0:\n", 53 | " return 0\n", 54 | " elif n == 1:\n", 55 | " return 1\n", 56 | " else:\n", 57 | " return fib(n-1) + fib(n-2)\n", 58 | "'''\n", 59 | "\n", 60 | "# changed return type to Any, should not error\n", 61 | "code_3 = '''\n", 62 | "from typing import Any\n", 63 | "def fib(n: int) -> Any:\n", 64 | " if n == 0:\n", 65 | " return 0\n", 66 | " elif n == 1:\n", 67 | " return 1\n", 68 | " else:\n", 69 | " return fib(n-1) + fib(n-2)\n", 70 | "'''\n", 71 | "\n", 72 | "# this should be the dmypy path in the current virtual env\n", 73 | "dmypy_path = '/home/jiayi/Projects/SPOT/.venv/bin/dmypy'\n", 74 | "\n", 75 | "check_dir = Path(\"temp/type_check\")\n", 76 | "check_dir.mkdir(exist_ok=True, parents=True)\n", 77 | "with open(check_dir / \"code.py\", \"w\") as f:\n", 78 | " f.write(code_1)\n", 79 | "subprocess.run(['python', dmypy_path, 'restart', '--', '--follow-imports=skip'],cwd=check_dir)\n", 80 | "\n", 81 | "print('---checking code_1---')\n", 82 | "subprocess.run(['python', dmypy_path, 'check', '.'],cwd=check_dir)\n", 83 | "\n", 84 | "with open(check_dir / \"code.py\", \"w\") as f:\n", 85 | " f.write(code_2)\n", 86 | "print('---checking code_2---')\n", 87 | "subprocess.run(['python', dmypy_path, 'recheck', \"--update\", \"code.py\"],cwd=check_dir)\n", 88 | "\n", 89 | "with open(check_dir / \"code.py\", \"w\") as f:\n", 90 | " f.write(code_3)\n", 91 | "print('---checking code_3---')\n", 92 | 
"subprocess.run(['python', dmypy_path, 'recheck', \"--update\", \"code.py\"],cwd=check_dir)\n", 93 | "print(\"test finished.\")\n", 94 | "\n", 95 | "print('---wait and check code_3 again---')\n", 96 | "time.sleep(1.0) # will not work if waiting time is shorter\n", 97 | "with open(check_dir / \"code.py\", \"w\") as f: # need this rewriting\n", 98 | " f.write(code_3)\n", 99 | "subprocess.run(['python', dmypy_path, 'recheck', \"--update\", \"code.py\"],cwd=check_dir)\n", 100 | "print(\"test finished.\")" 101 | ] 102 | } 103 | ], 104 | "metadata": { 105 | "interpreter": { 106 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8" 107 | }, 108 | "kernelspec": { 109 | "display_name": "Python 3.10.4 ('.venv': pipenv)", 110 | "language": "python", 111 | "name": "python3" 112 | }, 113 | "language_info": { 114 | "codemirror_mode": { 115 | "name": "ipython", 116 | "version": 3 117 | }, 118 | "file_extension": ".py", 119 | "mimetype": "text/x-python", 120 | "name": "python", 121 | "nbconvert_exporter": "python", 122 | "pygments_lexer": "ipython3", 123 | "version": "3.10.4" 124 | }, 125 | "orig_nbformat": 4 126 | }, 127 | "nbformat": 4, 128 | "nbformat_minor": 2 129 | } 130 | -------------------------------------------------------------------------------- /scripts/archive/kill_dmypy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | ps aux | grep -i dmypy | awk '{print $2}' | xargs kill 4 | -------------------------------------------------------------------------------- /scripts/archive/train_dagger.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 3 | import asyncio 4 | import os 5 | from typing import * 6 | 7 | from typet5.utils import get_data_dir, not_none, proj_root 8 | 9 | os.chdir(proj_root()) 10 | 11 | datadir = get_data_dir() 12 | 13 | # %% 14 | # experiment configurations 15 | 16 | from termcolor import colored 17 | 18 | from typet5.data import TypeCheckSettings, get_tk_dataset_name, load_tokenized_srcsets 19 | from typet5.model import CtxArgs, DecodingArgs, ModelType, ModelWrapper 20 | from typet5.train import TrainingConfig, TypeCheckArgs 21 | 22 | use_type_checker = False 23 | 24 | config = TrainingConfig( 25 | quicktest=False, 26 | ctx_size=2048, 27 | left_margin=1024, 28 | right_margin=1023, 29 | modifications="no_type_checker", 30 | ) 31 | gpu_id = 0 32 | TypeCheckSettings.temp_path = f"DAgger-{gpu_id}" 33 | 34 | if config.quicktest: 35 | print(colored("quicktest: True", "red")) 36 | 37 | project_name = "test-SPOT" if config.quicktest else "SPOT" 38 | train_ctx_args = config.train_ctx_args() 39 | tc_args = TypeCheckArgs(check_in_isolation=config.check_in_isolation) 40 | 41 | datasets_name = get_tk_dataset_name( 42 | drop_comments=config.drop_comments, 43 | ) 44 | 45 | model_name = "DAgger-model--" + config.as_name() 46 | 47 | tk_dataset = load_tokenized_srcsets( 48 | datadir, 49 | datasets_name, 50 | data_reduction=config.data_reduction, 51 | quicktest=config.quicktest, 52 | ) 53 | 54 | 55 | import torch 56 | 57 | from typet5.dagger import DAggerModel 58 | 59 | # %% 60 | # initialize the model 61 | from typet5.model import DefaultTokenizer, ModelWrapper, load_model_spot 62 | 63 | train_dec_args = DecodingArgs( 64 | sampling_max_tokens=8 * config.ctx_size, 65 | ctx_args=config.dec_ctx_args(), 66 | do_sample=True, # use necleus sampling during training 67 | top_p=0.9, 68 | ) 69 | 70 | model = load_model_spot("Salesforce/codet5-base") 71 | wrapper = 
ModelWrapper(model, DefaultTokenizer, train_dec_args) 72 | device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") 73 | wrapper.to(device) 74 | dmodel = DAggerModel(wrapper, use_type_checker=use_type_checker) 75 | 76 | 77 | # %% 78 | # pre-train evaluation 79 | # from typet5.utils import pretty_print_dict 80 | 81 | # eval_r = asyncio.run(dmodel.eval_on_data(tk_dataset["test"][0:50])) 82 | # pretty_print_dict(eval_r.accuracies) 83 | 84 | 85 | import shutil 86 | 87 | import wandb 88 | 89 | # %% 90 | # train the model 91 | from typet5.dagger import DAggerArgs 92 | from typet5.utils import run_long_task 93 | 94 | ckpt_dir = datadir / f"checkpoints/running/{model_name}" 95 | 96 | with run_long_task("DAgger training"): 97 | wandb.init( 98 | project=project_name, 99 | name=model_name, 100 | config=config.as_dict(), 101 | dir=str(datadir), 102 | ) 103 | 104 | dargs = DAggerArgs( 105 | ckpt_dir, 106 | grad_accum_steps=config.grad_accum_labels, 107 | replay_buffer_size=1000, 108 | ) 109 | 110 | finished = False 111 | try: 112 | asyncio.run( 113 | dmodel.train_on_data( 114 | tk_dataset["train"], 115 | dargs, 116 | log_fn=lambda t, x: wandb.log({"train/step": t, **x}), 117 | ) 118 | ) 119 | finished = True 120 | finally: 121 | save_tpye = "saved" if finished else "saved-emergent" 122 | save_path = datadir / f"checkpoints/{save_tpye}/{model_name}" 123 | print(colored(f"Saving trained model to: {save_path}", "blue")) 124 | shutil.rmtree(save_path, ignore_errors=True) 125 | wrapper.save(save_path) 126 | 127 | # %% 128 | # post-train full evaluation 129 | from typet5.utils import PickleCache, pretty_print_dict, pretty_show_dict 130 | from typet5.visualization import string_to_html 131 | 132 | test_dec_args = DecodingArgs( 133 | sampling_max_tokens=8 * config.ctx_size, 134 | ctx_args=CtxArgs( 135 | ctx_size=4096, 136 | left_margin=2048, 137 | right_margin=1023, 138 | ), 139 | do_sample=False, 140 | num_beams=8, 141 | ) 142 | dmodel.wrapper.args = test_dec_args 143 | 144 | eval_cache = PickleCache(save_path / "eval_cache") # type: ignore 145 | 146 | eval_r = eval_cache.cached( 147 | "eval_test", lambda: asyncio.run(dmodel.eval_on_data(tk_dataset["test"])) 148 | ) 149 | pretty_print_dict(eval_r.accuracies) 150 | 151 | 152 | def wandb_string(s: str): 153 | return wandb.Html(string_to_html(s)) 154 | 155 | 156 | wandb.log({"test/accuracies": wandb_string(pretty_show_dict(eval_r.accuracies))}) 157 | 158 | # %% 159 | # compute valid set performance 160 | import re 161 | 162 | from typet5.utils import not_none 163 | 164 | validset = tk_dataset["valid"][0:-1:3] 165 | # dmodel.wrapper.args = train_dec_args 166 | 167 | with run_long_task("DAgger evaluating (valid set)"): 168 | for model_path in ckpt_dir.glob("step=*"): 169 | print(colored(f"Evaluating model checkpoint: {model_path}", "blue")) 170 | m = not_none(re.match("step=(.+)", model_path.name)).groups()[0] 171 | step = int(m) 172 | wrapper = ModelWrapper.load(model_path) 173 | device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") 174 | wrapper.to(device) 175 | dmodel = DAggerModel(wrapper) 176 | eval_r = asyncio.run(dmodel.eval_on_data(validset)) 177 | wandb.log( 178 | { 179 | "valid/full_acc": eval_r.accuracies["full_acc"].acc, 180 | "train/step": step, 181 | } 182 | ) 183 | -------------------------------------------------------------------------------- /scripts/collect_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": 
"code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "import json\n", 13 | "import logging\n", 14 | "import os\n", 15 | "import shutil\n", 16 | "import subprocess\n", 17 | "import time\n", 18 | "from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed\n", 19 | "from pathlib import Path\n", 20 | "\n", 21 | "import libcst as cst\n", 22 | "from tqdm import tqdm\n", 23 | "\n", 24 | "from typet5.data import GitRepo, get_dataset_dir\n", 25 | "from typet5.type_env import collect_annots_info, mypy_checker\n", 26 | "from typet5.utils import proj_root, read_file, write_file, not_none\n", 27 | "\n", 28 | "os.chdir(proj_root())" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 4, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "name": "stdout", 38 | "output_type": "stream", 39 | "text": [ 40 | "Repos already downloaded.\n", 41 | "Reading last updates...\n" 42 | ] 43 | }, 44 | { 45 | "name": "stderr", 46 | "output_type": "stream", 47 | "text": [ 48 | "100%|██████████| 4890/4890 [00:27<00:00, 175.00it/s]" 49 | ] 50 | }, 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "Downloaded 4890/5996 repos.\n" 56 | ] 57 | }, 58 | { 59 | "name": "stderr", 60 | "output_type": "stream", 61 | "text": [ 62 | "\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "# download all candidate repos\n", 68 | "\n", 69 | "all_repos = json.loads(read_file(\"data/mypy-dependents-by-stars.json\"))\n", 70 | "all_repos = [GitRepo.from_json(r) for r in all_repos]\n", 71 | "# all_repos=all_repos[:10] # for testing\n", 72 | "\n", 73 | "repos_dir = get_dataset_dir(\"ManyTypes4Py\") / \"repos\"\n", 74 | "\n", 75 | "def clear_downloaded_repos(repos_dir):\n", 76 | " shutil.rmtree(repos_dir)\n", 77 | "\n", 78 | "\n", 79 | "def download_repos(\n", 80 | " to_download: list[GitRepo], repos_dir, download_timeout=10.0, max_workers=10\n", 81 | ") -> list[GitRepo]:\n", 82 | " def download_single(repo: GitRepo):\n", 83 | " try:\n", 84 | " if repo.download(repos_dir, timeout=download_timeout):\n", 85 | " repo.read_last_update(repos_dir)\n", 86 | " return repo\n", 87 | " else:\n", 88 | " return None\n", 89 | " except subprocess.TimeoutExpired:\n", 90 | " return None\n", 91 | " except Exception as e:\n", 92 | " logging.warning(f\"Failed to download {repo.name}. 
Exception: {e}\")\n", 93 | " return None\n", 94 | "\n", 95 | " print(\"Downloading repos from Github...\")\n", 96 | " t_start = time.time()\n", 97 | " with ThreadPoolExecutor(max_workers=max_workers) as executor:\n", 98 | " fs = [executor.submit(download_single, repo) for repo in to_download]\n", 99 | " rs = [f.result() for f in tqdm(as_completed(fs), total=len(fs))]\n", 100 | " print(f\"Downloading took {time.time() - t_start} seconds.\")\n", 101 | " downloaded = [r for r in rs if r is not None]\n", 102 | " return downloaded\n", 103 | "\n", 104 | "\n", 105 | "if not repos_dir.exists():\n", 106 | " (repos_dir / \"downloading\").mkdir(parents=True)\n", 107 | " (repos_dir / \"downloaded\").mkdir(parents=True)\n", 108 | " downloaded_repos = download_repos(all_repos, repos_dir)\n", 109 | " print(\"Deleting failed repos...\")\n", 110 | " shutil.rmtree(repos_dir / \"downloading\")\n", 111 | "else:\n", 112 | " print(\"Repos already downloaded.\")\n", 113 | " downloaded_dirs = set(d.name for d in (repos_dir / \"downloaded\").iterdir())\n", 114 | " downloaded_repos = [r for r in all_repos if r.authorname() in downloaded_dirs]\n", 115 | " print(\"Reading last updates...\")\n", 116 | " for r in tqdm(downloaded_repos):\n", 117 | " r.read_last_update(repos_dir)\n", 118 | "print(f\"Downloaded {len(downloaded_repos)}/{len(all_repos)} repos.\")\n" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 33, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "name": "stdout", 128 | "output_type": "stream", 129 | "text": [ 130 | "1218 / 4890 repos are updated within a year.\n" 131 | ] 132 | }, 133 | { 134 | "name": "stderr", 135 | "output_type": "stream", 136 | "text": [ 137 | "100%|██████████| 1218/1218 [00:05<00:00, 243.41it/s]" 138 | ] 139 | }, 140 | { 141 | "name": "stdout", 142 | "output_type": "stream", 143 | "text": [ 144 | "1181/1218 repos are within the size limit.\n" 145 | ] 146 | }, 147 | { 148 | "name": "stderr", 149 | "output_type": "stream", 150 | "text": [ 151 | "\n" 152 | ] 153 | } 154 | ], 155 | "source": [ 156 | "# filter out repos that are too old or too big\n", 157 | "\n", 158 | "from datetime import datetime, timezone\n", 159 | "\n", 160 | "date_threshold = datetime(2021, 4, 20)\n", 161 | "new_repos = [r for r in downloaded_repos if not_none(r.last_update) > date_threshold]\n", 162 | "print(f\"{len(new_repos)} / {len(downloaded_repos)} repos are updated within a year.\")\n", 163 | "loc_limit = 50000\n", 164 | "\n", 165 | "small_repos = []\n", 166 | "for rep in tqdm(new_repos):\n", 167 | " try:\n", 168 | " loc = rep.count_lines_of_code(repos_dir)\n", 169 | " if loc < loc_limit:\n", 170 | " small_repos.append(rep)\n", 171 | " except UnicodeDecodeError:\n", 172 | " # nothing we can do\n", 173 | " pass\n", 174 | " except Exception as e:\n", 175 | " logging.warning(f\"Failed to count lines of code for {rep.name}. 
Exception: {e}\")\n", 176 | "\n", 177 | "print(\n", 178 | " f\"{len(small_repos)}/{len(new_repos)} repos are within the size limit ({loc_limit} LOC).\"\n", 179 | ")\n" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "# filter away repos with too few annotations\n", 189 | "\n", 190 | "def count_repo_annots(rep):\n", 191 | " try:\n", 192 | " rep.count_annotations(repos_dir)\n", 193 | " if rep.n_type_annots / rep.lines_of_code > 0.05:\n", 194 | " return rep\n", 195 | " except Exception as e:\n", 196 | " logging.warning(f\"Failed to count annotations for {rep.name}. Exception: {e}\")\n", 197 | " return None\n", 198 | "\n", 199 | "\n", 200 | "with ProcessPoolExecutor(max_workers=30) as executor:\n", 201 | " fs = [executor.submit(count_repo_annots, rep) for rep in small_repos]\n", 202 | " rs = [f.result() for f in tqdm(as_completed(fs), total=len(fs))]\n", 203 | "useful_repos: list[GitRepo] = [\n", 204 | " r for r in rs if r is not None and \"typeshed\" not in r.name\n", 205 | "]\n", 206 | "\n", 207 | "print(\n", 208 | " f\"{len(useful_repos)}/{len(small_repos)} repos are parsable and have enough portions of type annotations.\"\n", 209 | ")\n" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 35, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "name": "stdout", 219 | "output_type": "stream", 220 | "text": [ 221 | "Total number of manual annotations: 343595\n", 222 | "Total number of type places: 544497\n", 223 | "Total number of lines of code: 3342911\n" 224 | ] 225 | }, 226 | { 227 | "data": { 228 | "text/plain": [ 229 | "[GitRepo(author='skorokithakis', name='catt', url='https://github.com/skorokithakis/catt', stars=1740, forks=762, lines_of_code=2036, last_update=datetime.datetime(2022, 4, 10, 1, 30, 43), n_type_annots=140, n_type_places=433),\n", 230 | " GitRepo(author='encode', name='databases', url='https://github.com/encode/databases', stars=769, forks=48, lines_of_code=3124, last_update=datetime.datetime(2022, 3, 6, 12, 25, 10), n_type_annots=323, n_type_places=498),\n", 231 | " GitRepo(author='Curt-Park', name='rainbow-is-all-you-need', url='https://github.com/Curt-Park/rainbow-is-all-you-need', stars=490, forks=110, lines_of_code=107, last_update=datetime.datetime(2022, 1, 13, 23, 4, 48), n_type_annots=26, n_type_places=30),\n", 232 | " GitRepo(author='jreese', name='aiomultiprocess', url='https://github.com/jreese/aiomultiprocess', stars=585, forks=45, lines_of_code=1140, last_update=datetime.datetime(2022, 2, 4, 21, 28, 7), n_type_annots=138, n_type_places=213),\n", 233 | " GitRepo(author='instaloader', name='instaloader', url='https://github.com/instaloader/instaloader', stars=874, forks=134, lines_of_code=5417, last_update=datetime.datetime(2022, 4, 18, 9, 49, 34), n_type_annots=569, n_type_places=843)]" 234 | ] 235 | }, 236 | "execution_count": 35, 237 | "metadata": {}, 238 | "output_type": "execute_result" 239 | } 240 | ], 241 | "source": [ 242 | "# Some summary statistics\n", 243 | "\n", 244 | "# print total number of manual annotations\n", 245 | "n_total_annots = sum(not_none(rep.n_type_annots) for rep in useful_repos)\n", 246 | "print(\"Total number of manual annotations:\", n_total_annots)\n", 247 | "\n", 248 | "# print total number of type places\n", 249 | "n_total_places = sum(not_none(rep.n_type_places) for rep in useful_repos)\n", 250 | "print(\"Total number of type places:\", n_total_places)\n", 251 | "\n", 252 | "# print total number of 
lines of code\n", 253 | "n_total_lines = sum(not_none(rep.lines_of_code) for rep in useful_repos)\n", 254 | "print(\"Total number of lines of code:\", n_total_lines)\n", 255 | "\n", 256 | "# print average number of type annotations per line of code excluding projects with more than 1000 lines of code\n", 257 | "n_avg_annots = (\n", 258 | " sum(not_none(rep.n_type_annots) for rep in useful_repos if rep.lines_of_code < 1000)\n", 259 | " / n_total_lines\n", 260 | ")\n" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 5, 266 | "metadata": {}, 267 | "outputs": [ 268 | { 269 | "name": "stdout", 270 | "output_type": "stream", 271 | "text": [ 272 | "[GitRepo(author='typeddjango', name='pytest-mypy-plugins', url='https://github.com/typeddjango/pytest-mypy-plugins', stars=12, forks=0, lines_of_code=1039, last_update=datetime.datetime(2022, 4, 18, 23, 25, 40), n_type_annots=155, n_type_places=158), GitRepo(author='jfly', name='jfly.github.io', url='https://github.com/jfly/jfly.github.io', stars=0, forks=0, lines_of_code=650, last_update=datetime.datetime(2022, 4, 12, 8, 23, 39), n_type_annots=39, n_type_places=122), GitRepo(author='seattleflu', name='id3c', url='https://github.com/seattleflu/id3c', stars=2, forks=0, lines_of_code=8883, last_update=datetime.datetime(2022, 4, 21, 15, 38, 59), n_type_annots=675, n_type_places=1068)]\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "import pickle\n", 278 | "\n", 279 | "useful_repos_path = proj_root() / \"scripts\" / \"useful_repos.pkl\"\n", 280 | "with useful_repos_path.open(\"wb\") as f:\n", 281 | " pickle.dump(useful_repos, f)\n", 282 | "print(f\"Saved {len(useful_repos)} useful repos to {useful_repos_path}.\")\n", 283 | "with useful_repos_path.open(\"rb\") as f:\n", 284 | " print(pickle.load(f)[:3])\n" 285 | ] 286 | }, 287 | { 288 | "attachments": {}, 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "The cell below tries to split the dataset based on the original dataset split used by the paper. But since that the list of repos returned by the GitHub query above can change over time, some repos might no longer be present. You might consider perform your own splitting if the issue is serious." 
293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "from typet5.utils import pickle_load, Path, proj_root, tqdm\n", 302 | "import shutil\n", 303 | "from typet5.data import GitRepo, get_dataset_dir\n", 304 | "\n", 305 | "os.chdir(proj_root())\n", 306 | "\n", 307 | "repos_split = pickle_load(Path(\"data/repos_split.pkl\"))\n", 308 | "repos_dir = get_dataset_dir(\"ManyTypes4Py\") / \"repos\"\n", 309 | "\n", 310 | "for split, repos in repos_split.items():\n", 311 | " for r in tqdm(repos, desc=f\"Moving {split} repos.\"):\n", 312 | " r: GitRepo\n", 313 | " split: str\n", 314 | " src = repos_dir / r.authorname()\n", 315 | " (repos_dir / split).mkdir(parents=True, exist_ok=True)\n", 316 | " dest = repos_dir / split / r.authorname()\n", 317 | " if src.exists():\n", 318 | " shutil.move(src, dest)\n", 319 | " else:\n", 320 | " print(f\"Repo {r.name} not found.\")" 321 | ] 322 | } 323 | ], 324 | "metadata": { 325 | "interpreter": { 326 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8" 327 | }, 328 | "kernelspec": { 329 | "display_name": "Python 3.9.7 ('.venv': pipenv)", 330 | "language": "python", 331 | "name": "python3" 332 | }, 333 | "language_info": { 334 | "codemirror_mode": { 335 | "name": "ipython", 336 | "version": 3 337 | }, 338 | "file_extension": ".py", 339 | "mimetype": "text/x-python", 340 | "name": "python", 341 | "nbconvert_exporter": "python", 342 | "pygments_lexer": "ipython3", 343 | "version": "3.10.8" 344 | }, 345 | "orig_nbformat": 4 346 | }, 347 | "nbformat": 4, 348 | "nbformat_minor": 2 349 | } 350 | -------------------------------------------------------------------------------- /scripts/experiments/eval_file_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 7, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "The autoreload extension is already loaded. 
To reload it, use:\n", 13 | " %reload_ext autoreload\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2\n", 20 | "\n", 21 | "import asyncio\n", 22 | "import os\n", 23 | "from typing import *\n", 24 | "\n", 25 | "import torch\n", 26 | "import wandb\n", 27 | "from typet5.data import get_tk_dataset_name\n", 28 | "from typet5.function_dataset import data_project_from_dir\n", 29 | "from typet5.model import ModelWrapper\n", 30 | "from typet5.train import TrainingConfig, PreprocessArgs\n", 31 | "from typet5.type_env import AccuracyMetric\n", 32 | "from typet5.utils import (\n", 33 | " PickleCache,\n", 34 | " assert_eq,\n", 35 | " get_dataroot,\n", 36 | " get_dataset_dir,\n", 37 | " get_eval_dir,\n", 38 | " get_gpu_id,\n", 39 | " get_model_dir,\n", 40 | " pickle_dump,\n", 41 | " pmap,\n", 42 | " pretty_print_dict,\n", 43 | " pretty_show_dict,\n", 44 | " proj_root,\n", 45 | " run_long_task,\n", 46 | " write_file,\n", 47 | ")\n", 48 | "from typet5.visualization import string_to_html\n", 49 | "from termcolor import colored\n", 50 | "\n", 51 | "os.chdir(proj_root())\n", 52 | "\n", 53 | "\n", 54 | "def wandb_string(s: str):\n", 55 | " return wandb.Html(string_to_html(s))" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 8, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "GPU_ID not set, using: 1\n", 68 | "\u001b[32mUse GPU: 1\u001b[0m\n" 69 | ] 70 | } 71 | ], 72 | "source": [ 73 | "# experiment configurations\n", 74 | "quicktest = False\n", 75 | "\n", 76 | "gpu_id = get_gpu_id(1)\n", 77 | "# model_name = \"model-v6--TrainingConfig(func_only=False, left_margin=2048, preamble_size=800, right_margin=1536)\"\n", 78 | "model_name = \"model-v6--TrainingConfig(func_only=False, imports_in_preamble=False, stub_in_preamble=False, left_margin=2048, right_margin=1536)\"\n", 79 | "pre_args = PreprocessArgs(imports_in_preamble=False, stub_in_preamble=False)\n", 80 | "dataset_name = \"ManyTypes4Py\"\n", 81 | "# dataset_name = \"InferTypes4Py\"\n", 82 | "# dataset_name = \"SPOT-src\"\n", 83 | "experiment_name = dataset_name + \": \" + model_name\n", 84 | "\n", 85 | "print(colored(f\"Use GPU: {gpu_id}\", \"green\"))\n" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 9, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "Loading TokenizedSrcSets: /mnt/nas/jiayi/SPOT/TokenizedSrcSets/ManyTypes4Py-v5-PreprocessArgs(imports_in_preamble=False, stub_in_preamble=False)\n", 98 | "254M\t/mnt/nas/jiayi/SPOT/TokenizedSrcSets/ManyTypes4Py-v5-PreprocessArgs(imports_in_preamble=False, stub_in_preamble=False)\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "# load test data\n", 104 | "from typet5.data import load_tokenized_srcsets, create_tokenized_srcsets\n", 105 | "\n", 106 | "sdata_name = get_tk_dataset_name(dataset_name, pre_args, func_only=False)\n", 107 | "sdata_path = get_dataroot() / \"TokenizedSrcSets\" / sdata_name\n", 108 | "recreate=False\n", 109 | "if recreate or not sdata_path.exists():\n", 110 | " create_tokenized_srcsets(\n", 111 | " dataset_name,\n", 112 | " sdata_path,\n", 113 | " func_only=False,\n", 114 | " pre_args=pre_args,\n", 115 | " )\n", 116 | "tk_dataset = load_tokenized_srcsets(\n", 117 | " sdata_path,\n", 118 | " quicktest=quicktest,\n", 119 | " sets_to_load=[\"test\"],\n", 120 | ")\n" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 10, 126 | 
"metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "# model evaluation\n", 130 | "\n", 131 | "from typet5.function_decoding import (\n", 132 | " DecodingOrders,\n", 133 | " EvalResult,\n", 134 | " PreprocessArgs,\n", 135 | " RolloutCtx,\n", 136 | ")\n", 137 | "from typet5.function_dataset import sigmap_from_file_predictions\n", 138 | "from typet5.static_analysis import SignatureErrorAnalysis\n", 139 | "\n", 140 | "# load model\n", 141 | "model = ModelWrapper.from_pretrained(get_model_dir() / model_name)\n", 142 | "device = torch.device(f\"cuda:{gpu_id}\" if torch.cuda.is_available() else \"cpu\")\n", 143 | "model.to(device)\n", 144 | "\n", 145 | "ctx_args = model.args.ctx_args\n", 146 | "model.args.sampling_max_tokens = ctx_args.ctx_size\n", 147 | "model.args.do_sample = False\n", 148 | "model.args.num_beams = 10\n", 149 | "model.args.tokens_per_type = 16\n", 150 | "\n", 151 | "eval_cache = PickleCache(get_eval_dir(dataset_name, model_name) / f\"{pre_args}\")\n", 152 | "# eval_cache.clear()\n", 153 | "pre_r = eval_cache.cached(\n", 154 | " \"DatasetPredResult.pkl\",\n", 155 | " lambda: model.eval_on_dataset(tk_dataset[\"test\"]),\n", 156 | ")\n" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 11, 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "name": "stderr", 166 | "output_type": "stream", 167 | "text": [ 168 | "Loading test projects: 100%|██████████| 50/50 [00:27<00:00, 1.85it/s]\n" 169 | ] 170 | }, 171 | { 172 | "name": "stdout", 173 | "output_type": "stream", 174 | "text": [ 175 | "Accuracies on all types:\n", 176 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n", 177 | "67.07 & 67.47 & 72.12 & 44.05 & 73.44\n", 178 | "Accuracies on common types:\n", 179 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n", 180 | "76.74 & 78.04 & 82.43 & 53.03 & 82.44\n", 181 | "Accuracies on rare types:\n", 182 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n", 183 | "49.47 & 52.95 & 57.28 & 34.26 & 57.65\n", 184 | "full_acc:\n", 185 | " full_acc: 67.07% (count=15.7k)\n", 186 | " full_acc_by_cat:\n", 187 | " FuncArg: 62.00% (count=8.0k)\n", 188 | " FuncReturn: 77.89% (count=5.8k)\n", 189 | " ClassAtribute: 55.36% (count=1.8k)\n", 190 | " GlobalVar: 63.55% (count=107)\n", 191 | " full_acc_by_simple:\n", 192 | " complex: 41.80% (count=3.3k)\n", 193 | " simple: 73.71% (count=12.4k)\n", 194 | " full_acc_label_size: 1.4194\n", 195 | " full_acc_pred_size: 1.4107\n", 196 | " full_acc_ignored_labels: 0\n", 197 | " n_missing_types: 53\n", 198 | "full_acc_common:\n", 199 | " full_acc_common: 76.74% (count=10.1k)\n", 200 | " full_acc_common_by_cat:\n", 201 | " FuncArg: 76.04% (count=5.3k)\n", 202 | " FuncReturn: 76.74% (count=3.8k)\n", 203 | " ClassAtribute: 79.57% (count=984)\n", 204 | " GlobalVar: 88.10% (count=84)\n", 205 | " full_acc_common_by_simple:\n", 206 | " complex: 48.52% (count=1.8k)\n", 207 | " simple: 82.65% (count=8.4k)\n", 208 | " full_acc_common_label_size: 1.3693\n", 209 | " full_acc_common_pred_size: 1.3379\n", 210 | " full_acc_common_ignored_labels: 5569\n", 211 | " n_missing_types: 53\n", 212 | "full_acc_rare:\n", 213 | " full_acc_rare: 49.47% (count=5.6k)\n", 214 | " full_acc_rare_by_cat:\n", 215 | " FuncArg: 51.05% (count=3.2k)\n", 216 | " FuncReturn: 48.03% (count=2.0k)\n", 217 | " ClassAtribute: 44.10% (count=356)\n", 218 | " GlobalVar: 41.67% (count=36)\n", 219 | " full_acc_rare_by_simple:\n", 
220 | " complex: 34.04% (count=1.5k)\n", 221 | " simple: 55.24% (count=4.1k)\n", 222 | " full_acc_rare_label_size: 1.5103\n", 223 | " full_acc_rare_pred_size: 1.543\n", 224 | " full_acc_rare_ignored_labels: 10129\n", 225 | " n_missing_types: 53\n", 226 | "acc:\n", 227 | " acc: 67.47% (count=13.2k)\n", 228 | " acc_by_cat:\n", 229 | " FuncArg: 66.28% (count=6.7k)\n", 230 | " FuncReturn: 68.91% (count=4.9k)\n", 231 | " ClassAtribute: 68.28% (count=1.5k)\n", 232 | " GlobalVar: 63.64% (count=99)\n", 233 | " acc_by_simple:\n", 234 | " complex: 44.05% (count=2.2k)\n", 235 | " simple: 72.12% (count=11.0k)\n", 236 | " acc_label_size: 1.3155\n", 237 | " acc_pred_size: 1.2971\n", 238 | " acc_ignored_labels: 2521\n", 239 | " n_missing_types: 53\n", 240 | "acc_common:\n", 241 | " acc_common: 78.04% (count=7.6k)\n", 242 | " acc_common_by_cat:\n", 243 | " FuncArg: 77.28% (count=4.0k)\n", 244 | " FuncReturn: 78.69% (count=2.7k)\n", 245 | " ClassAtribute: 80.51% (count=816)\n", 246 | " GlobalVar: 64.91% (count=57)\n", 247 | " acc_common_by_simple:\n", 248 | " complex: 53.03% (count=1.1k)\n", 249 | " simple: 82.43% (count=6.5k)\n", 250 | " acc_common_label_size: 1.2978\n", 251 | " acc_common_pred_size: 1.2665\n", 252 | " acc_common_ignored_labels: 8072\n", 253 | " n_missing_types: 53\n", 254 | "acc_rare:\n", 255 | " acc_rare: 52.95% (count=5.6k)\n", 256 | " acc_rare_by_cat:\n", 257 | " FuncArg: 55.36% (count=3.2k)\n", 258 | " FuncReturn: 50.94% (count=2.0k)\n", 259 | " ClassAtribute: 43.26% (count=356)\n", 260 | " GlobalVar: 44.44% (count=36)\n", 261 | " acc_rare_by_simple:\n", 262 | " complex: 34.26% (count=1.0k)\n", 263 | " simple: 57.28% (count=4.5k)\n", 264 | " acc_rare_label_size: 1.3399\n", 265 | " acc_rare_pred_size: 1.3392\n", 266 | " acc_rare_ignored_labels: 10147\n", 267 | " n_missing_types: 53\n", 268 | "base_acc:\n", 269 | " base_acc: 73.44% (count=13.2k)\n", 270 | " base_acc_by_cat:\n", 271 | " FuncArg: 72.19% (count=6.7k)\n", 272 | " FuncReturn: 74.45% (count=4.9k)\n", 273 | " ClassAtribute: 75.77% (count=1.5k)\n", 274 | " GlobalVar: 72.73% (count=99)\n", 275 | " base_acc_ignored_labels: 2521\n", 276 | " n_missing_types: 53\n", 277 | "base_acc_common:\n", 278 | " base_acc_common: 82.44% (count=8.4k)\n", 279 | " base_acc_common_by_cat:\n", 280 | " FuncArg: 82.34% (count=4.4k)\n", 281 | " FuncReturn: 82.62% (count=3.1k)\n", 282 | " ClassAtribute: 82.96% (count=886)\n", 283 | " GlobalVar: 73.02% (count=63)\n", 284 | " base_acc_common_ignored_labels: 7305\n", 285 | " n_missing_types: 53\n", 286 | "base_acc_rare:\n", 287 | " base_acc_rare: 57.65% (count=4.8k)\n", 288 | " base_acc_rare_by_cat:\n", 289 | " FuncArg: 58.01% (count=2.8k)\n", 290 | " FuncReturn: 55.16% (count=1.6k)\n", 291 | " ClassAtribute: 65.96% (count=332)\n", 292 | " GlobalVar: 67.74% (count=31)\n", 293 | " base_acc_rare_ignored_labels: 10914\n", 294 | " n_missing_types: 53\n" 295 | ] 296 | } 297 | ], 298 | "source": [ 299 | "repos_dir = get_dataset_dir(dataset_name) / \"repos\" / \"test\"\n", 300 | "test_repo_paths = [f for f in repos_dir.iterdir() if f.is_dir()]\n", 301 | "test_projects = pmap(\n", 302 | " data_project_from_dir,\n", 303 | " test_repo_paths,\n", 304 | " desc=\"Loading test projects\",\n", 305 | ")\n", 306 | "assert len(test_projects) > 0\n", 307 | "\n", 308 | "common_names = ModelWrapper.load_common_type_names(get_model_dir() / model_name)\n", 309 | "pred_map, label_map = sigmap_from_file_predictions(pre_r, test_projects, repos_dir)\n", 310 | "accs = {\n", 311 | " m.name: SignatureErrorAnalysis(pred_map, label_map, 
m).accuracies\n", 312 | " for m in AccuracyMetric.default_metrics(common_names)\n", 313 | "}\n", 314 | "\n", 315 | "from typet5.experiments.typet5 import accs_as_table_row\n", 316 | "accs_as_table_row(accs)\n", 317 | "pretty_print_dict(accs)" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 12, 323 | "metadata": {}, 324 | "outputs": [ 325 | { 326 | "name": "stderr", 327 | "output_type": "stream", 328 | "text": [ 329 | "Exporting: 100%|██████████| 1851/1851 [00:18<00:00, 100.04it/s]\n", 330 | "Computing accuracies: 100%|██████████| 1851/1851 [00:00<00:00, 9748.66it/s]\n" 331 | ] 332 | } 333 | ], 334 | "source": [ 335 | "from typet5.utils import decode_tokens, Path\n", 336 | "from typet5.visualization import export_preds_on_code\n", 337 | "\n", 338 | "export_to = Path(f\"caches/model_predictions/eval_file_model/{dataset_name}\")\n", 339 | "export_preds_on_code(pre_r.chunks, pre_r.predictions, export_to, AccuracyMetric(common_names))" 340 | ] 341 | } 342 | ], 343 | "metadata": { 344 | "kernelspec": { 345 | "display_name": "Python 3.10.4 ('.venv': pipenv)", 346 | "language": "python", 347 | "name": "python3" 348 | }, 349 | "language_info": { 350 | "codemirror_mode": { 351 | "name": "ipython", 352 | "version": 3 353 | }, 354 | "file_extension": ".py", 355 | "mimetype": "text/x-python", 356 | "name": "python", 357 | "nbconvert_exporter": "python", 358 | "pygments_lexer": "ipython3", 359 | "version": "3.10.4" 360 | }, 361 | "orig_nbformat": 4, 362 | "vscode": { 363 | "interpreter": { 364 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8" 365 | } 366 | } 367 | }, 368 | "nbformat": 4, 369 | "nbformat_minor": 2 370 | } 371 | -------------------------------------------------------------------------------- /scripts/experiments/eval_func_model.py: -------------------------------------------------------------------------------- 1 | # evaluate the model trained on fucntional dataset using 2 | # incremental decoding. 
3 | 4 | # %% 5 | import asyncio 6 | import copy 7 | import os 8 | from typing import * 9 | 10 | import torch 11 | import wandb 12 | from termcolor import colored 13 | 14 | from typet5.experiments.typet5 import TypeT5Configs 15 | from typet5.function_dataset import data_project_from_dir 16 | from typet5.model import ModelWrapper 17 | from typet5.train import PreprocessArgs, TrainingConfig 18 | from typet5.type_env import AccuracyMetric 19 | from typet5.utils import ( 20 | assert_eq, 21 | get_dataset_dir, 22 | get_eval_dir, 23 | get_gpu_id, 24 | get_model_dir, 25 | get_modified_args, 26 | pickle_dump, 27 | pickle_load, 28 | pmap, 29 | pretty_show_dict, 30 | proj_root, 31 | run_long_task, 32 | write_file, 33 | ) 34 | from typet5.visualization import string_to_html 35 | 36 | os.chdir(proj_root()) 37 | 38 | 39 | def wandb_string(s: str): 40 | return wandb.Html(string_to_html(s)) 41 | 42 | 43 | # %% 44 | 45 | # experiment configurations 46 | 47 | load_results = True 48 | use_oracle = True 49 | gpu_id = get_gpu_id(1) 50 | train_config = TypeT5Configs.Default 51 | 52 | model_name = train_config.get_model_name() 53 | # model_name = ( 54 | # "model-v7--TrainingConfig(drop_env_types=False, add_implicit_rel_imports=True)" 55 | # ) 56 | dataset_name = "ManyTypes4Py" 57 | # dataset_name = "InferTypes4Py" 58 | # dataset_name = "TinyEval" 59 | 60 | test_pre_args = train_config.pre_args 61 | oracle_tag = "(use-oracle) " if use_oracle else "" 62 | # group_tag = "(implicit_imports, new) " 63 | # group_tag = "(ablation) " 64 | group_tag = "" 65 | experiment_name = oracle_tag + group_tag + model_name 66 | 67 | print(colored(f"Use GPU: {gpu_id}", "green")) 68 | 69 | # %% 70 | 71 | # load model 72 | model = ModelWrapper.load(get_model_dir() / model_name) 73 | device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") 74 | model.to(device) 75 | print(f"Model loaded:", model_name) 76 | 77 | # load test projects 78 | repos_dir = get_dataset_dir(dataset_name) / "repos" / "test" 79 | test_repo_paths = [f for f in repos_dir.iterdir() if f.is_dir()] 80 | if not load_results: 81 | test_projects = pmap( 82 | data_project_from_dir, 83 | test_repo_paths, 84 | desc="Loading test projects", 85 | ) 86 | assert len(test_projects) > 0 87 | 88 | # %% 89 | 90 | from typet5.experiments.typet5 import accs_as_table_row 91 | from typet5.function_decoding import DecodingOrders, EvalResult, RolloutCtx 92 | 93 | ctx_args = model.args.ctx_args 94 | ctx_args.max_labels = 16 95 | model.args.sampling_max_tokens = ctx_args.ctx_size 96 | model.args.do_sample = False 97 | model.args.num_beams = 16 98 | model.args.tokens_per_type = 16 99 | 100 | rctx = RolloutCtx(model=model) 101 | 102 | decode_orders = { 103 | # "double-traversal": DecodingOrders.DoubleTraversal(), 104 | # "reverse-double-traversal": DecodingOrders.Reversed( 105 | # DecodingOrders.DoubleTraversal() 106 | # ), 107 | # "non-incr": DecodingOrders.IndependentOrder(), 108 | # "random": DecodingOrders.RandomOrder(), 109 | # "no-neighbors": DecodingOrders.IndependentOrder(), 110 | "callee2caller": DecodingOrders.Callee2Caller(), 111 | # "caller2callee": DecodingOrders.Caller2Callee(), 112 | # "random-twice": DecodingOrders.RandomTwice(), 113 | } 114 | 115 | metrics = AccuracyMetric.default_metrics(model.common_type_names) 116 | with run_long_task("Evaluating different decoding strategy", notify=not load_results): 117 | results_dir = get_eval_dir(dataset_name, experiment_name) 118 | results_dir.mkdir(exist_ok=True, parents=True) 119 | print(colored(f"Results will 
be saved to: {str(results_dir)}", "green")) 120 | 121 | if not load_results: 122 | wandb.init( 123 | project="SPOT-eval", 124 | name=dataset_name + ": " + experiment_name, 125 | dir=str(results_dir), 126 | config=get_modified_args(model.args), 127 | ) 128 | 129 | evals = dict[str, EvalResult]() 130 | for oname, order in decode_orders.items(): 131 | result_path = results_dir / f"{oname}-EvalResult.pkl" 132 | if not load_results: 133 | print(f"Evaluating decoding strategy: {oname}") 134 | pre_args = copy.deepcopy(test_pre_args) 135 | if oname == "no-neighbors": 136 | pre_args.max_callers = 0 137 | pre_args.max_callees = 0 138 | evalr = asyncio.run( 139 | rctx.evaluate_on_projects( 140 | test_projects, # type: ignore 141 | pre_args, 142 | order, 143 | use_oracle=use_oracle, 144 | ) 145 | ) 146 | pickle_dump(result_path, evalr) 147 | else: 148 | if not result_path.exists(): 149 | print(f"Result file not found, skip: {result_path}") 150 | continue 151 | evalr = pickle_load(result_path) 152 | evals[oname] = evalr 153 | accs = {m.name: evalr.error_analysis(None, m).accuracies for m in metrics} 154 | accs_str = pretty_show_dict(accs) 155 | write_file(results_dir / f"{oname}-accuracy.txt", accs_str) 156 | if not load_results: 157 | wandb.log({f"test/{oname}": wandb_string(accs_str)}) 158 | print(f"========== {oname} ===========") 159 | print(accs_str) 160 | accs_as_table_row(accs) 161 | 162 | # %% 163 | if False: 164 | # print predictions 165 | for oname, evalr in evals.items(): 166 | print(f"========== {oname} ===========") 167 | evalr.print_predictions() 168 | 169 | import prettytable as pt 170 | 171 | # %% 172 | from prettytable import PrettyTable 173 | 174 | common_type_names = ModelWrapper.load_common_type_names(get_model_dir() / model_name) 175 | results_table = PrettyTable() 176 | results_table.field_names = ["order", *(m.name for m in metrics)] 177 | results_table.align = "r" 178 | results_table.set_style(pt.SINGLE_BORDER) 179 | results_table.float_format = ".4" 180 | 181 | for oname in evals: 182 | accs = [ 183 | evals[oname].error_analysis(None, metric).accuracies[metric.name].acc 184 | for metric in metrics 185 | ] 186 | results_table.add_row([oname, *accs]) 187 | 188 | print(results_table) 189 | write_file(results_dir / "comparison.txt", results_table.get_string()) 190 | -------------------------------------------------------------------------------- /scripts/experiments/run_hityper.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "from typet5.utils import *\n", 13 | "os.chdir(proj_root())\n" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 15, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stderr", 23 | "output_type": "stream", 24 | "text": [ 25 | "Removing newer syntax: 100%|██████████| 1594/1594 [00:01<00:00, 873.90it/s] \n" 26 | ] 27 | }, 28 | { 29 | "name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "1594 / 1594 files have been rewritten.\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "from typet5.experiments.utils import remove_newer_syntax_for_repo, Path\n", 38 | "from typet5.experiments.type4py import Type4PySupportedSyntax\n", 39 | "from typet5.utils import get_dataset_dir\n", 40 | "import shutil\n", 41 | "\n", 42 | "# dataset_name = \"InferTypes4Py\"\n", 43 | "dataset_name = 
\"ManyTypes4Py\"\n", 44 | "repos_dir = get_dataset_dir(dataset_name) / \"repos\"\n", 45 | "shutil.rmtree(repos_dir / \"test-hityper\", ignore_errors=True)\n", 46 | "shutil.copytree(repos_dir / \"test\", repos_dir / \"test-hityper\")\n", 47 | "remove_newer_syntax_for_repo(repos_dir / \"test-hityper\", Type4PySupportedSyntax)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 16, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "from typet5.experiments.hityper import run_hityper, eval_hityper_on_repos\n", 57 | "\n", 58 | "test_repos = [\n", 59 | " p\n", 60 | " for p in (repos_dir / \"test-hityper\").iterdir()\n", 61 | " if p.is_dir()\n", 62 | "]\n", 63 | "\n", 64 | "hityper_path = Path(\"/home/jiayi/anaconda3/envs/hityper/bin/python\")\n", 65 | "out_dir = Path(\"output/hityper\")\n", 66 | "cache = PickleCache(Path(f\"caches/run_hityper\"))\n", 67 | "\n", 68 | "eval_r = cache.cached(f\"{dataset_name}.pkl\", lambda: eval_hityper_on_repos(test_repos, hityper_path, out_dir, max_workers=4))" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 17, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "Accuracies on all types:\n", 81 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n", 82 | "42.53 & 41.95 & 44.86 & 19.02 & 47.88\n", 83 | "Accuracies on common types:\n", 84 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n", 85 | "59.20 & 54.28 & 57.70 & 26.44 & 59.01\n", 86 | "Accuracies on rare types:\n", 87 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n", 88 | "10.30 & 25.51 & 27.59 & 9.79 & 29.33\n", 89 | "----------------------------------------\n", 90 | "full_acc:\n", 91 | " full_acc: 42.53% (count=10.5k)\n", 92 | " full_acc_by_cat:\n", 93 | " FuncArg: 26.80% (count=5.9k)\n", 94 | " FuncReturn: 63.41% (count=4.5k)\n", 95 | " GlobalVar: 18.84% (count=69)\n", 96 | " full_acc_by_simple:\n", 97 | " complex: 14.59% (count=1.7k)\n", 98 | " simple: 47.85% (count=8.8k)\n", 99 | " full_acc_label_size: 1.3993\n", 100 | " full_acc_pred_size: 1.3325\n", 101 | " full_acc_ignored_labels: 0\n", 102 | " n_skipped_types: 3293\n", 103 | " n_missing_types: 1922\n", 104 | "full_acc_common:\n", 105 | " full_acc_common: 59.20% (count=6.9k)\n", 106 | " full_acc_common_by_cat:\n", 107 | " FuncArg: 58.18% (count=4.0k)\n", 108 | " FuncReturn: 60.82% (count=2.9k)\n", 109 | " GlobalVar: 50.00% (count=52)\n", 110 | " full_acc_common_by_simple:\n", 111 | " complex: 20.43% (count=940)\n", 112 | " simple: 65.27% (count=6.0k)\n", 113 | " full_acc_common_label_size: 1.3505\n", 114 | " full_acc_common_pred_size: 1.2762\n", 115 | " full_acc_common_ignored_labels: 3592\n", 116 | " n_skipped_types: 3293\n", 117 | " n_missing_types: 1922\n", 118 | "full_acc_rare:\n", 119 | " full_acc_rare: 10.30% (count=3.6k)\n", 120 | " full_acc_rare_by_cat:\n", 121 | " FuncArg: 10.04% (count=2.2k)\n", 122 | " FuncReturn: 10.80% (count=1.4k)\n", 123 | " GlobalVar: 0.00% (count=9)\n", 124 | " full_acc_rare_by_simple:\n", 125 | " complex: 7.24% (count=746)\n", 126 | " simple: 11.10% (count=2.8k)\n", 127 | " full_acc_rare_label_size: 1.4936\n", 128 | " full_acc_rare_pred_size: 1.4413\n", 129 | " full_acc_rare_ignored_labels: 6944\n", 130 | " n_skipped_types: 3293\n", 131 | " n_missing_types: 1922\n", 132 | "acc:\n", 133 | " acc: 41.95% (count=8.4k)\n", 134 | " acc_by_cat:\n", 135 | " 
FuncArg: 41.69% (count=4.8k)\n", 136 | " FuncReturn: 42.24% (count=3.5k)\n", 137 | " GlobalVar: 45.61% (count=57)\n", 138 | " acc_by_simple:\n", 139 | " complex: 19.02% (count=941)\n", 140 | " simple: 44.86% (count=7.4k)\n", 141 | " acc_label_size: 1.3238\n", 142 | " acc_pred_size: 1.2303\n", 143 | " acc_ignored_labels: 2180\n", 144 | " n_skipped_types: 3293\n", 145 | " n_missing_types: 1922\n", 146 | "acc_common:\n", 147 | " acc_common: 54.28% (count=4.8k)\n", 148 | " acc_common_by_cat:\n", 149 | " FuncArg: 53.86% (count=2.9k)\n", 150 | " FuncReturn: 54.78% (count=1.8k)\n", 151 | " GlobalVar: 68.00% (count=25)\n", 152 | " acc_common_by_simple:\n", 153 | " complex: 26.44% (count=522)\n", 154 | " simple: 57.70% (count=4.3k)\n", 155 | " acc_common_label_size: 1.3241\n", 156 | " acc_common_pred_size: 1.2162\n", 157 | " acc_common_ignored_labels: 5763\n", 158 | " n_skipped_types: 3293\n", 159 | " n_missing_types: 1922\n", 160 | "acc_rare:\n", 161 | " acc_rare: 25.51% (count=3.6k)\n", 162 | " acc_rare_by_cat:\n", 163 | " FuncArg: 23.96% (count=2.2k)\n", 164 | " FuncReturn: 28.09% (count=1.3k)\n", 165 | " GlobalVar: 22.22% (count=9)\n", 166 | " acc_rare_by_simple:\n", 167 | " complex: 9.79% (count=419)\n", 168 | " simple: 27.59% (count=3.2k)\n", 169 | " acc_rare_label_size: 1.3235\n", 170 | " acc_rare_pred_size: 1.249\n", 171 | " acc_rare_ignored_labels: 6953\n", 172 | " n_skipped_types: 3293\n", 173 | " n_missing_types: 1922\n", 174 | "base_acc:\n", 175 | " base_acc: 47.88% (count=8.4k)\n", 176 | " base_acc_by_cat:\n", 177 | " FuncArg: 47.48% (count=4.8k)\n", 178 | " FuncReturn: 48.34% (count=3.5k)\n", 179 | " GlobalVar: 54.39% (count=57)\n", 180 | " base_acc_ignored_labels: 2180\n", 181 | " n_skipped_types: 3293\n", 182 | " n_missing_types: 1922\n", 183 | "base_acc_common:\n", 184 | " base_acc_common: 59.01% (count=5.2k)\n", 185 | " base_acc_common_by_cat:\n", 186 | " FuncArg: 58.99% (count=3.1k)\n", 187 | " FuncReturn: 59.19% (count=2.1k)\n", 188 | " GlobalVar: 48.28% (count=29)\n", 189 | " base_acc_common_ignored_labels: 5313\n", 190 | " n_skipped_types: 3293\n", 191 | " n_missing_types: 1922\n", 192 | "base_acc_rare:\n", 193 | " base_acc_rare: 29.33% (count=3.1k)\n", 194 | " base_acc_rare_by_cat:\n", 195 | " FuncArg: 29.17% (count=2.0k)\n", 196 | " FuncReturn: 29.44% (count=1.2k)\n", 197 | " GlobalVar: 57.14% (count=7)\n", 198 | " base_acc_rare_ignored_labels: 7403\n", 199 | " n_skipped_types: 3293\n", 200 | " n_missing_types: 1922\n" 201 | ] 202 | } 203 | ], 204 | "source": [ 205 | "from typet5.static_analysis import SignatureErrorAnalysis, AccuracyMetric\n", 206 | "from typet5.experiments.typet5 import accs_as_table_row, ModelWrapper\n", 207 | "\n", 208 | "\n", 209 | "common_names = ModelWrapper.load_common_type_names(\n", 210 | " get_model_dir() / \"model-v7--TrainingConfig(drop_env_types=False)\"\n", 211 | ")\n", 212 | "metrics = AccuracyMetric.default_metrics(common_type_names=common_names)\n", 213 | "# acc_metric = AccuracyMetric(common_type_names=ubiq_names)\n", 214 | "\n", 215 | "accs = {\n", 216 | " m.name: SignatureErrorAnalysis(\n", 217 | " eval_r.pred_maps,\n", 218 | " eval_r.label_maps,\n", 219 | " m,\n", 220 | " error_on_mismatched_signature=False,\n", 221 | " ).accuracies\n", 222 | " for m in metrics\n", 223 | "}\n", 224 | "\n", 225 | "accs_as_table_row(accs)\n", 226 | "print(\"-\" * 40)\n", 227 | "pretty_print_dict(accs)" 228 | ] 229 | } 230 | ], 231 | "metadata": { 232 | "kernelspec": { 233 | "display_name": "Python 3.10.4 ('.venv': pipenv)", 234 | "language": "python", 235 
| "name": "python3" 236 | }, 237 | "language_info": { 238 | "codemirror_mode": { 239 | "name": "ipython", 240 | "version": 3 241 | }, 242 | "file_extension": ".py", 243 | "mimetype": "text/x-python", 244 | "name": "python", 245 | "nbconvert_exporter": "python", 246 | "pygments_lexer": "ipython3", 247 | "version": "3.10.4" 248 | }, 249 | "orig_nbformat": 4, 250 | "vscode": { 251 | "interpreter": { 252 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8" 253 | } 254 | } 255 | }, 256 | "nbformat": 4, 257 | "nbformat_minor": 2 258 | } 259 | -------------------------------------------------------------------------------- /scripts/experiments/run_typilus.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "from typet5.utils import proj_root, os\n", 13 | "os.chdir(proj_root())" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stderr", 23 | "output_type": "stream", 24 | "text": [ 25 | "Removing newer syntax: 100%|██████████| 1594/1594 [00:01<00:00, 799.31it/s]" 26 | ] 27 | }, 28 | { 29 | "name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "1594 / 1594 files have been rewritten.\n" 33 | ] 34 | }, 35 | { 36 | "name": "stderr", 37 | "output_type": "stream", 38 | "text": [ 39 | "\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "from typet5.experiments.utils import remove_newer_syntax_for_repo, Path\n", 45 | "from typet5.experiments.typilus import eval_typilus_on_repos, TypilusSupportedSyntax\n", 46 | "from typet5.utils import get_dataset_dir\n", 47 | "import shutil\n", 48 | "\n", 49 | "# dataset_name = \"InferTypes4Py\"\n", 50 | "dataset_name = \"ManyTypes4Py\"\n", 51 | "repos_dir = get_dataset_dir(dataset_name) / \"repos\"\n", 52 | "shutil.rmtree(repos_dir / \"test-typilus\", ignore_errors=True)\n", 53 | "shutil.copytree(repos_dir / \"test\", repos_dir / \"test-typilus\")\n", 54 | "remove_newer_syntax_for_repo(repos_dir / \"test-typilus\", TypilusSupportedSyntax)\n" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 5, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stderr", 64 | "output_type": "stream", 65 | "text": [ 66 | "Running Typilus: 100%|██████████| 50/50 [00:46<00:00, 1.09it/s]\n", 67 | "Collecting labels: 100%|██████████| 50/50 [00:15<00:00, 3.17it/s]\n", 68 | "WARNING:root:Missing 27 predictions for module: test_fakesmtpd.syntax\n", 69 | "WARNING:root:Missing 1 predictions for module: archive\n", 70 | "WARNING:root:Missing 1 predictions for module: tests.test_utils\n", 71 | "WARNING:root:Missing 1 predictions for module: tests.test_init_import\n", 72 | "WARNING:root:Missing 1 predictions for module: tests.test_axion_plugins\n", 73 | "WARNING:root:Missing 1 predictions for module: typesafety.conftest\n", 74 | "WARNING:root:Missing 1 predictions for module: tests.test_base\n" 75 | ] 76 | }, 77 | { 78 | "name": "stdout", 79 | "output_type": "stream", 80 | "text": [ 81 | "Accuracies on all types:\n", 82 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n", 83 | "47.10 & 54.05 & 55.12 & 33.23 & 60.37\n", 84 | "Accuracies on common types:\n", 85 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n", 86 | "47.10 & 
54.05 & 55.12 & 33.23 & 60.37\n", 87 | "Accuracies on rare types:\n", 88 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n", 89 | "nan & nan & N/A & N/A & nan\n", 90 | "full_acc:\n", 91 | " full_acc: 47.10% (count=6.9k)\n", 92 | " full_acc_by_cat:\n", 93 | " FuncArg: 47.86% (count=4.4k)\n", 94 | " FuncReturn: 46.79% (count=1.3k)\n", 95 | " ClassAtribute: 45.40% (count=1.1k)\n", 96 | " GlobalVar: 20.45% (count=44)\n", 97 | " full_acc_by_simple:\n", 98 | " complex: 20.42% (count=529)\n", 99 | " simple: 49.31% (count=6.4k)\n", 100 | " full_acc_label_size: 1.4992\n", 101 | " full_acc_pred_size: 1.1339\n", 102 | " full_acc_ignored_labels: 0\n", 103 | " n_missing: 3958\n", 104 | " n_skipped_rare: 4892\n", 105 | "full_acc_common:\n", 106 | " full_acc_common: 47.10% (count=6.9k)\n", 107 | " full_acc_common_by_cat:\n", 108 | " FuncArg: 47.86% (count=4.4k)\n", 109 | " FuncReturn: 46.79% (count=1.3k)\n", 110 | " ClassAtribute: 45.40% (count=1.1k)\n", 111 | " GlobalVar: 20.45% (count=44)\n", 112 | " full_acc_common_by_simple:\n", 113 | " complex: 20.42% (count=529)\n", 114 | " simple: 49.31% (count=6.4k)\n", 115 | " full_acc_common_label_size: 1.4992\n", 116 | " full_acc_common_pred_size: 1.1339\n", 117 | " full_acc_common_ignored_labels: 0\n", 118 | " n_missing: 3958\n", 119 | " n_skipped_rare: 4892\n", 120 | "full_acc_rare:\n", 121 | " full_acc_rare: nan% (count=0)\n", 122 | " full_acc_rare_by_cat:\n", 123 | " full_acc_rare_by_simple:\n", 124 | " full_acc_rare_label_size: nan\n", 125 | " full_acc_rare_pred_size: nan\n", 126 | " full_acc_rare_ignored_labels: 6903\n", 127 | " n_missing: 3958\n", 128 | " n_skipped_rare: 4892\n", 129 | "acc:\n", 130 | " acc: 54.05% (count=6.9k)\n", 131 | " acc_by_cat:\n", 132 | " FuncArg: 55.27% (count=4.4k)\n", 133 | " FuncReturn: 50.30% (count=1.3k)\n", 134 | " ClassAtribute: 54.41% (count=1.1k)\n", 135 | " GlobalVar: 36.36% (count=44)\n", 136 | " acc_by_simple:\n", 137 | " complex: 33.23% (count=337)\n", 138 | " simple: 55.12% (count=6.6k)\n", 139 | " acc_label_size: 1.2988\n", 140 | " acc_pred_size: 1.0755\n", 141 | " acc_ignored_labels: 15\n", 142 | " n_missing: 3958\n", 143 | " n_skipped_rare: 4892\n", 144 | "acc_common:\n", 145 | " acc_common: 54.05% (count=6.9k)\n", 146 | " acc_common_by_cat:\n", 147 | " FuncArg: 55.27% (count=4.4k)\n", 148 | " FuncReturn: 50.30% (count=1.3k)\n", 149 | " ClassAtribute: 54.41% (count=1.1k)\n", 150 | " GlobalVar: 36.36% (count=44)\n", 151 | " acc_common_by_simple:\n", 152 | " complex: 33.23% (count=337)\n", 153 | " simple: 55.12% (count=6.6k)\n", 154 | " acc_common_label_size: 1.2988\n", 155 | " acc_common_pred_size: 1.0755\n", 156 | " acc_common_ignored_labels: 15\n", 157 | " n_missing: 3958\n", 158 | " n_skipped_rare: 4892\n", 159 | "acc_rare:\n", 160 | " acc_rare: nan% (count=0)\n", 161 | " acc_rare_by_cat:\n", 162 | " acc_rare_by_simple:\n", 163 | " acc_rare_label_size: nan\n", 164 | " acc_rare_pred_size: nan\n", 165 | " acc_rare_ignored_labels: 6903\n", 166 | " n_missing: 3958\n", 167 | " n_skipped_rare: 4892\n", 168 | "base_acc:\n", 169 | " base_acc: 60.37% (count=6.9k)\n", 170 | " base_acc_by_cat:\n", 171 | " FuncArg: 60.62% (count=4.4k)\n", 172 | " FuncReturn: 61.90% (count=1.3k)\n", 173 | " ClassAtribute: 57.94% (count=1.1k)\n", 174 | " GlobalVar: 47.73% (count=44)\n", 175 | " base_acc_ignored_labels: 15\n", 176 | " n_missing: 3958\n", 177 | " n_skipped_rare: 4892\n", 178 | "base_acc_common:\n", 179 | " base_acc_common: 60.37% (count=6.9k)\n", 180 | " 
base_acc_common_by_cat:\n", 181 | " FuncArg: 60.62% (count=4.4k)\n", 182 | " FuncReturn: 61.90% (count=1.3k)\n", 183 | " ClassAtribute: 57.94% (count=1.1k)\n", 184 | " GlobalVar: 47.73% (count=44)\n", 185 | " base_acc_common_ignored_labels: 15\n", 186 | " n_missing: 3958\n", 187 | " n_skipped_rare: 4892\n", 188 | "base_acc_rare:\n", 189 | " base_acc_rare: nan% (count=0)\n", 190 | " base_acc_rare_by_cat:\n", 191 | " base_acc_rare_ignored_labels: 6903\n", 192 | " n_missing: 3958\n", 193 | " n_skipped_rare: 4892\n" 194 | ] 195 | } 196 | ], 197 | "source": [ 198 | "from typet5.model import ModelWrapper\n", 199 | "from typet5.static_analysis import AccuracyMetric\n", 200 | "from typet5.utils import *\n", 201 | "from typet5.experiments.typet5 import accs_as_table_row\n", 202 | "\n", 203 | "test_repos = [p for p in (repos_dir / \"test-typilus\").iterdir() if p.is_dir()]\n", 204 | "\n", 205 | "common_names = ModelWrapper.load_common_type_names(\n", 206 | " get_model_dir() / \"model-v7--TrainingConfig(drop_env_types=False)\"\n", 207 | ")\n", 208 | "metrics = AccuracyMetric.default_metrics(common_names)\n", 209 | "typilus_path = Path(\"~/Projects/typilus-action/\").expanduser()\n", 210 | "work_dir = Path(\"~/Projects/typilus-action/data_out\").expanduser()\n", 211 | "\n", 212 | "cache = PickleCache(Path(f\"caches/run_typilus\"))\n", 213 | "cache.remove(f\"{dataset_name}.pkl\")\n", 214 | "accs = cache.cached(\n", 215 | " f\"{dataset_name}.pkl\",\n", 216 | " lambda: eval_typilus_on_repos(\n", 217 | " test_repos, metrics, typilus_path, work_dir, max_workers=4\n", 218 | " ),\n", 219 | ")\n", 220 | "\n", 221 | "accs_as_table_row(accs)\n", 222 | "pretty_print_dict(accs)\n" 223 | ] 224 | } 225 | ], 226 | "metadata": { 227 | "kernelspec": { 228 | "display_name": "Python 3.10.4 ('.venv': pipenv)", 229 | "language": "python", 230 | "name": "python3" 231 | }, 232 | "language_info": { 233 | "codemirror_mode": { 234 | "name": "ipython", 235 | "version": 3 236 | }, 237 | "file_extension": ".py", 238 | "mimetype": "text/x-python", 239 | "name": "python", 240 | "nbconvert_exporter": "python", 241 | "pygments_lexer": "ipython3", 242 | "version": "3.10.4" 243 | }, 244 | "orig_nbformat": 4, 245 | "vscode": { 246 | "interpreter": { 247 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8" 248 | } 249 | } 250 | }, 251 | "nbformat": 4, 252 | "nbformat_minor": 2 253 | } 254 | -------------------------------------------------------------------------------- /scripts/run_typet5.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "import os\n", 13 | "from typing import *\n", 14 | "\n", 15 | "import torch\n", 16 | "\n", 17 | "from typet5.model import ModelWrapper\n", 18 | "from typet5.train import PreprocessArgs\n", 19 | "from typet5.utils import *\n", 20 | "from typet5.function_decoding import (\n", 21 | " RolloutCtx,\n", 22 | " PreprocessArgs,\n", 23 | " DecodingOrders,\n", 24 | " AccuracyMetric,\n", 25 | ")\n", 26 | "from typet5.static_analysis import PythonProject\n", 27 | "\n", 28 | "os.chdir(proj_root())" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "application/vnd.jupyter.widget-view+json": { 39 | "model_id": "3852e189479e408a801e24d3e050cff7", 40 | "version_major": 
2, 41 | "version_minor": 0 42 | }, 43 | "text/plain": [ 44 | "Fetching 9 files: 0%| | 0/9 [00:00 int\n", 91 | "ex_code_1/Wrapper.foo: (bar: int) -> int\n", 92 | "ex_code_1/Wrapper.inc: () -> str\n", 93 | "ex_code_1/int_add: (a: int, b: int) -> str\n", 94 | "ex_code_1/int_tripple_add: (a: int, b: int, c: int) -> int\n", 95 | "ex_code_2/fib: (n: int) -> int\n", 96 | "ex_code_2/foo: (bar: int) -> int\n", 97 | "ex_code_2/Bar.x: int\n", 98 | "ex_code_2/Bar.y: int\n", 99 | "ex_code_2/Bar.reset: (w0: str) -> None\n", 100 | "ex_code_2/Bar.__init__: (x: int) -> None\n", 101 | "(updated) ex_code_1/int_add: (a: int, b: int) -> int\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "# Use case 1: Run TypeT5 on a given project, taking advantage of existing user \n", 107 | "# annotations and only make predictions for missing types.\n", 108 | "\n", 109 | "project = PythonProject.parse_from_root(proj_root() / \"data/ex_repo\")\n", 110 | "rollout = await rctx.run_on_project(project, pre_args, decode_order)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 5, 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "name": "stderr", 120 | "output_type": "stream", 121 | "text": [ 122 | "evaluate_on_projects: 100%|██████████| 35/35 [00:04<00:00, 7.20it/s]" 123 | ] 124 | }, 125 | { 126 | "name": "stdout", 127 | "output_type": "stream", 128 | "text": [ 129 | "==================== /home/jiayi/Projects/TypeT5/data/ex_repo ====================\n", 130 | "\tex_code_1/fib: (n: int) -> int\n", 131 | "\tex_code_1/Wrapper.foo: (bar: int) -> int\n", 132 | "\tex_code_1/Wrapper.inc: () -> int\n", 133 | "\tex_code_1/int_add: (a: int, b: int) -> str\n", 134 | "\tex_code_1/int_tripple_add: (a: int, b: int, c: int) -> int\n", 135 | "\tex_code_2/fib: (n: int) -> int\n", 136 | "\tex_code_2/foo: (bar: int) -> int\n", 137 | "\tex_code_2/Bar.__init__: (x: int) -> None\n", 138 | "\tex_code_2/Bar.reset: (w0: int_add) -> None\n", 139 | "\tex_code_2/Bar.foo: (z: str) -> str\n", 140 | "\tex_code_1/good: int\n", 141 | "\tex_code_1/Wrapper.x_elem: int\n", 142 | "\tex_code_1/Wrapper.y: int\n", 143 | "\tex_code_2/Bar.z: str\n", 144 | "\tex_code_2/Bar.w: int_add\n", 145 | "\tex_code_2/Bar.x: int\n", 146 | "\tex_code_2/Bar.y: int\n", 147 | "\tex_code_2/bar: Bar\n" 148 | ] 149 | }, 150 | { 151 | "name": "stderr", 152 | "output_type": "stream", 153 | "text": [ 154 | "\n" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "# Use case 2: Run TypeT5 on a test project where all user annotations will be treated as\n", 160 | "# labels and removed before running the model.\n", 161 | "\n", 162 | "eval_r = await rctx.evaluate_on_projects([project], pre_args, decode_order)\n", 163 | "eval_r.print_predictions()" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 6, 169 | "metadata": {}, 170 | "outputs": [ 171 | { 172 | "name": "stdout", 173 | "output_type": "stream", 174 | "text": [ 175 | "full_acc:\n", 176 | " full_acc: 70.00% (count=10)\n", 177 | " full_acc_by_cat:\n", 178 | " FuncArg: 100.00% (count=4)\n", 179 | " FuncReturn: 0.00% (count=1)\n", 180 | " ClassAtribute: 50.00% (count=4)\n", 181 | " GlobalVar: 100.00% (count=1)\n", 182 | " full_acc_by_simple:\n", 183 | " simple: 70.00% (count=10)\n", 184 | " full_acc_label_size: 1\n", 185 | " full_acc_pred_size: 1\n", 186 | " full_acc_ignored_labels: 0\n", 187 | "full_acc_common:\n", 188 | " full_acc_common: 66.67% (count=9)\n", 189 | " full_acc_common_by_cat:\n", 190 | " FuncArg: 100.00% (count=4)\n", 191 | " FuncReturn: 0.00% (count=1)\n", 192 | " 
ClassAtribute: 50.00% (count=4)\n", 193 | " full_acc_common_by_simple:\n", 194 | " simple: 66.67% (count=9)\n", 195 | " full_acc_common_label_size: 1\n", 196 | " full_acc_common_pred_size: 1\n", 197 | " full_acc_common_ignored_labels: 1\n", 198 | "full_acc_rare:\n", 199 | " full_acc_rare: 100.00% (count=1)\n", 200 | " full_acc_rare_by_cat:\n", 201 | " ClassAtribute: 100.00% (count=1)\n", 202 | " full_acc_rare_by_simple:\n", 203 | " simple: 100.00% (count=1)\n", 204 | " full_acc_rare_label_size: 1\n", 205 | " full_acc_rare_pred_size: 1\n", 206 | " full_acc_rare_ignored_labels: 9\n", 207 | "acc:\n", 208 | " acc: 70.00% (count=10)\n", 209 | " acc_by_cat:\n", 210 | " FuncArg: 100.00% (count=4)\n", 211 | " FuncReturn: 0.00% (count=1)\n", 212 | " ClassAtribute: 50.00% (count=4)\n", 213 | " GlobalVar: 100.00% (count=1)\n", 214 | " acc_by_simple:\n", 215 | " simple: 70.00% (count=10)\n", 216 | " acc_label_size: 1\n", 217 | " acc_pred_size: 1\n", 218 | " acc_ignored_labels: 0\n", 219 | "acc_common:\n", 220 | " acc_common: 66.67% (count=9)\n", 221 | " acc_common_by_cat:\n", 222 | " FuncArg: 100.00% (count=4)\n", 223 | " FuncReturn: 0.00% (count=1)\n", 224 | " ClassAtribute: 50.00% (count=4)\n", 225 | " acc_common_by_simple:\n", 226 | " simple: 66.67% (count=9)\n", 227 | " acc_common_label_size: 1\n", 228 | " acc_common_pred_size: 1\n", 229 | " acc_common_ignored_labels: 1\n", 230 | "acc_rare:\n", 231 | " acc_rare: 100.00% (count=1)\n", 232 | " acc_rare_by_cat:\n", 233 | " ClassAtribute: 100.00% (count=1)\n", 234 | " acc_rare_by_simple:\n", 235 | " simple: 100.00% (count=1)\n", 236 | " acc_rare_label_size: 1\n", 237 | " acc_rare_pred_size: 1\n", 238 | " acc_rare_ignored_labels: 9\n", 239 | "base_acc:\n", 240 | " base_acc: 70.00% (count=10)\n", 241 | " base_acc_by_cat:\n", 242 | " FuncArg: 100.00% (count=4)\n", 243 | " FuncReturn: 0.00% (count=1)\n", 244 | " ClassAtribute: 50.00% (count=4)\n", 245 | " GlobalVar: 100.00% (count=1)\n", 246 | " base_acc_ignored_labels: 0\n", 247 | "base_acc_common:\n", 248 | " base_acc_common: 66.67% (count=9)\n", 249 | " base_acc_common_by_cat:\n", 250 | " FuncArg: 100.00% (count=4)\n", 251 | " FuncReturn: 0.00% (count=1)\n", 252 | " ClassAtribute: 50.00% (count=4)\n", 253 | " base_acc_common_ignored_labels: 1\n", 254 | "base_acc_rare:\n", 255 | " base_acc_rare: 100.00% (count=1)\n", 256 | " base_acc_rare_by_cat:\n", 257 | " ClassAtribute: 100.00% (count=1)\n", 258 | " base_acc_rare_ignored_labels: 9\n" 259 | ] 260 | } 261 | ], 262 | "source": [ 263 | "metrics = AccuracyMetric.default_metrics(wrapper.common_type_names)\n", 264 | "for metric in metrics:\n", 265 | " accs = eval_r.error_analysis(None, metric).accuracies\n", 266 | " pretty_print_dict({metric.name: accs})\n", 267 | " " 268 | ] 269 | } 270 | ], 271 | "metadata": { 272 | "kernelspec": { 273 | "display_name": ".venv", 274 | "language": "python", 275 | "name": "python3" 276 | }, 277 | "language_info": { 278 | "codemirror_mode": { 279 | "name": "ipython", 280 | "version": 3 281 | }, 282 | "file_extension": ".py", 283 | "mimetype": "text/x-python", 284 | "name": "python", 285 | "nbconvert_exporter": "python", 286 | "pygments_lexer": "ipython3", 287 | "version": "3.10.8" 288 | }, 289 | "orig_nbformat": 4 290 | }, 291 | "nbformat": 4, 292 | "nbformat_minor": 2 293 | } 294 | -------------------------------------------------------------------------------- /scripts/train_model.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | from typing import * 4 | 5 | 
from termcolor import colored 6 | 7 | from typet5.data import ( 8 | TypeCheckSettings, 9 | create_tokenized_srcsets, 10 | get_tk_dataset_name, 11 | load_tokenized_srcsets, 12 | ) 13 | from typet5.experiments.typet5 import TypeT5Configs 14 | from typet5.model import DecodingArgs, ModelWrapper 15 | from typet5.train import TypeCheckArgs 16 | from typet5.utils import * 17 | 18 | os.chdir(proj_root()) 19 | 20 | # %% 21 | # ----------------------------------------------------------- 22 | # experiment configurations 23 | 24 | gpu_id = get_gpu_id(0) # which GPU to use 25 | eval_only = False # whether to skip training and only evaluate the model 26 | recreate_dataset = False # whether to recreate the tokenized dataset if found 27 | 28 | config = TypeT5Configs.Default # which model configuration to use 29 | 30 | # %% 31 | # ----------------------------------------------------------- 32 | 33 | 34 | TypeCheckSettings.temp_path = f"GPU-{gpu_id}" 35 | print(colored(f"Use GPU: {gpu_id}", "green")) 36 | 37 | if config.quicktest: 38 | print(colored("Quicktest mode", "red")) 39 | if eval_only: 40 | print(colored("Model Evaluating Mode", "blue")) 41 | 42 | project_name = "test-SPOT" if config.quicktest else "SPOT" 43 | train_ctx_args = config.train_ctx_args() 44 | if train_ctx_args.window_size < 100: 45 | print( 46 | colored( 47 | f"[Warning] window size is very small: {train_ctx_args.window_size}", "red" 48 | ) 49 | ) 50 | tc_args = TypeCheckArgs(check_in_isolation=config.check_in_isolation) 51 | 52 | max_tokens_per_file = config.ctx_size 53 | dec_args = DecodingArgs( 54 | sampling_max_tokens=8 * max_tokens_per_file, 55 | ctx_args=config.dec_ctx_args(), 56 | ) 57 | 58 | dataset = config.trained_on 59 | print("Model will be trained on dataset:", colored(dataset, "blue")) 60 | 61 | sdata_name = get_tk_dataset_name( 62 | dataset, config.pre_args, config.func_only, data_reduction=config.data_reduction 63 | ) 64 | sdata_path = get_dataroot() / "TokenizedSrcSets" / sdata_name 65 | if recreate_dataset or not sdata_path.exists(): 66 | create_tokenized_srcsets( 67 | dataset, 68 | sdata_path, 69 | func_only=config.func_only, 70 | pre_args=config.pre_args, 71 | data_reduction=config.data_reduction, 72 | ) 73 | 74 | tk_dataset = load_tokenized_srcsets( 75 | sdata_path, 76 | quicktest=config.quicktest, 77 | ) 78 | print("Training set stats:") 79 | tk_dataset["train"].print_stats() 80 | 81 | model_name = config.get_model_name() 82 | print(colored(f"Training model: {model_name}", "green")) 83 | 84 | import torch 85 | import wandb 86 | 87 | # %% 88 | # ----------------------------------------------------------- 89 | # train the model 90 | from typet5.train import ModelTrainingArgs, TypeCheckArgs, train_spot_model 91 | from typet5.utils import run_long_task 92 | 93 | if not eval_only: 94 | train_args = ModelTrainingArgs( 95 | train_ctx_args, 96 | dec_args, 97 | train_max_tokens=max_tokens_per_file, 98 | eval_max_tokens=2 * max_tokens_per_file, 99 | max_epochs=1, 100 | tc_args=tc_args, 101 | ) 102 | 103 | wandb.init( 104 | project=project_name, 105 | name=model_name, 106 | config=config.as_dict(), 107 | dir=str(get_dataroot()), 108 | ) 109 | 110 | with run_long_task("Training spot model"): 111 | wrapper = train_spot_model( 112 | tk_dataset, 113 | model_name, 114 | train_args=train_args, 115 | gpus=[gpu_id], 116 | quicktest=config.quicktest, 117 | use_small_model=config.use_small_model, 118 | use_early_stop=False, 119 | ) 120 | else: 121 | wrapper = ModelWrapper.load(get_model_dir() / model_name) 122 | 123 | device = 
torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") 124 | wrapper.to(device) 125 | 126 | 127 | # %% 128 | # ----------------------------------------------------------- 129 | # model evaluation 130 | 131 | from typet5.type_env import AccuracyMetric 132 | from typet5.utils import PickleCache 133 | from typet5.visualization import pretty_print_dict 134 | 135 | bs_args = DecodingArgs( 136 | sampling_max_tokens=max_tokens_per_file, 137 | ctx_args=config.dec_ctx_args(), 138 | do_sample=False, 139 | num_beams=16, 140 | ) 141 | wrapper.args = bs_args 142 | 143 | eval_cache = PickleCache(get_eval_dir(dataset, model_name) / "eval_cache") 144 | # eval_cache.clear() 145 | eval_r = eval_cache.cached( 146 | "dataset_pred.pkl", 147 | lambda: wrapper.eval_on_dataset(tk_dataset["test"]), 148 | ) 149 | common_names = wrapper.common_type_names 150 | metrics = AccuracyMetric.default_metrics(common_names) 151 | r0_accs = {m.name: eval_r.accuracies(m) for m in metrics} 152 | print("Accuracies on all user annotations:") 153 | pretty_print_dict(r0_accs) 154 | 155 | 156 | import wandb 157 | 158 | # %% 159 | # ----------------------------------------------------------- 160 | # close wandb 161 | from typet5.utils import pretty_show_dict 162 | from typet5.visualization import string_to_html 163 | 164 | 165 | def wandb_string(s: str): 166 | return wandb.Html(string_to_html(s)) 167 | 168 | 169 | if not eval_only: 170 | wandb.log({f"test/accuracies": wandb_string(pretty_show_dict(r0_accs))}) 171 | 172 | from typet5.function_dataset import data_project_from_dir, sigmap_from_file_predictions 173 | 174 | # %% 175 | # ----------------------------------------------------------- 176 | # compute accuracies on the top-level elements 177 | from typet5.static_analysis import SignatureErrorAnalysis 178 | 179 | repos_dir = get_dataset_dir(dataset) / "repos" / "test" 180 | test_repo_paths = [f for f in repos_dir.iterdir() if f.is_dir()] 181 | test_projects = pmap( 182 | data_project_from_dir, 183 | test_repo_paths, 184 | desc="Loading test projects", 185 | ) 186 | 187 | eval_r = eval_r 188 | pred_map, label_map = sigmap_from_file_predictions(eval_r, test_projects, repos_dir) 189 | api_accs = { 190 | m.name: SignatureErrorAnalysis(pred_map, label_map, m).accuracies 191 | for m in AccuracyMetric.default_metrics(common_names) 192 | } 193 | 194 | print("Accuracies on top-level elements:") 195 | pretty_print_dict(api_accs) 196 | if not eval_only: 197 | wandb.log({f"test/api_accuracies": wandb_string(pretty_show_dict(api_accs))}) 198 | 199 | # %% 200 | # ----------------------------------------------------------- 201 | # export the code with inlined predictions as HTML 202 | 203 | from typet5.visualization import export_preds_on_code, proj_root 204 | 205 | export_preds = True 206 | 207 | if export_preds: 208 | max_samples = 500 209 | sample_every = max(1, len(eval_r.chunks) // max_samples) 210 | sub_ids = range(0, len(eval_r.chunks), sample_every) 211 | export_to = proj_root() / "caches" / "model_predictions" / model_name 212 | export_preds_on_code( 213 | eval_r.chunks[sub_ids], 214 | [eval_r.predictions[i] for i in sub_ids], 215 | export_to=export_to, 216 | metric=AccuracyMetric(common_names), 217 | ) 218 | print(f"Model predictions exported to '{export_to}'") 219 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | setup( 4 | name="TypeT5", 5 | 
version="0.1", 6 | packages=["typet5"], 7 | package_dir={"": "src"}, 8 | license="BSD 3-Clause", 9 | ) 10 | -------------------------------------------------------------------------------- /src/typet5/__init__.py: -------------------------------------------------------------------------------- 1 | from .type_env import AnnotCat, PythonType 2 | from .utils import proj_root 3 | 4 | __all__ = ["AnnotCat", "PythonType", "proj_root"] 5 | -------------------------------------------------------------------------------- /src/typet5/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/utopia-group/TypeT5/d8ff8638f4d00f03042db5780a8d4fa09a72916d/src/typet5/experiments/__init__.py -------------------------------------------------------------------------------- /src/typet5/experiments/hityper.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | import requests 4 | 5 | from typet5.experiments.type4py import Type4PyEvalResult, Type4PySupportedSyntax 6 | from typet5.experiments.utils import SupportedSyntax, remove_newer_syntax 7 | from typet5.function_dataset import SignatureMap 8 | from typet5.static_analysis import ( 9 | ElemSignature, 10 | FunctionSignature, 11 | ModuleName, 12 | ProjectPath, 13 | PythonProject, 14 | VariableSignature, 15 | ) 16 | from typet5.type_check import normalize_type, parse_type_expr, parse_type_str 17 | from typet5.type_env import AccuracyMetric 18 | from typet5.utils import * 19 | 20 | PredList = list[list] # of the form [[type1, score1], [type2, score2], ...] 21 | 22 | 23 | def eval_hityper_on_repos( 24 | repo_roots: list[Path], 25 | hityper_python: Path, 26 | work_dir: Path, 27 | max_workers: int | None = None, 28 | ): 29 | out_dirs = [work_dir / r.name for r in repo_roots] 30 | for out_dir in out_dirs: 31 | out_dir.mkdir(parents=True, exist_ok=True) 32 | model_outputs = pmap( 33 | run_hityper, 34 | repo_roots, 35 | out_dirs, 36 | [hityper_python] * len(repo_roots), 37 | desc="Running HiTyper", 38 | max_workers=max_workers, 39 | ) 40 | 41 | projects = [ 42 | PythonProject.parse_from_root(r, discard_bad_files=True) for r in repo_roots 43 | ] 44 | 45 | label_signatures: dict[str, SignatureMap] = { 46 | project.name: {e.path: e.get_signature() for e in project.all_elems()} 47 | for project in projects 48 | } 49 | 50 | pred_signatures: dict[str, SignatureMap] = {n: dict() for n in label_signatures} 51 | for proj, sigmap in zip(projects, model_outputs): 52 | pred_signatures[proj.name] = sigmap 53 | 54 | return Type4PyEvalResult( 55 | pred_maps=pred_signatures, 56 | label_maps=label_signatures, 57 | ) 58 | 59 | 60 | class HiTyperResponseParser: 61 | def __init__(self, module: ModuleName): 62 | self.module = module 63 | self.assignment: SignatureMap = dict() 64 | 65 | def parse(self, res_json: dict[str, list]) -> SignatureMap: 66 | self.assignment = dict() 67 | 68 | def parse_var(x: dict) -> tuple[str, cst.Annotation | None]: 69 | return x["name"], _parse_annot(x["type"]) 70 | 71 | for e_name, e_list in res_json.items(): 72 | if e_name == "global@global": 73 | vars = [v := parse_var(x) for x in e_list if x["category"] == "local"] 74 | for name, annot in vars: 75 | path = ProjectPath(self.module, name) 76 | self.assignment[path] = VariableSignature(annot, in_class=False) 77 | else: 78 | name, parent = e_name.split("@") 79 | parent = "" if parent == "global" else parent 80 | path = ProjectPath(self.module, parent).append(name) 81 | 
params = [parse_var(x) for x in e_list if x["category"] == "arg"] 82 | returns = [parse_var(x) for x in e_list if x["category"] == "return"] 83 | rt = returns[0][1] if returns else None 84 | self.assignment[path] = FunctionSignature( 85 | {v[0]: v[1] for v in params}, 86 | rt, 87 | in_class=False, 88 | ) 89 | 90 | return self.assignment 91 | 92 | 93 | def run_hityper(repo: Path, out_dir: Path, python_path: Path) -> SignatureMap: 94 | out_dir.mkdir(parents=True, exist_ok=True) 95 | out = subprocess.run( 96 | [ 97 | python_path, 98 | "-m", 99 | "hityper", 100 | "infer", 101 | "--type4py", 102 | "-p", 103 | repo.resolve(), 104 | "-d", 105 | out_dir.resolve(), 106 | ], 107 | cwd=out_dir, 108 | # env={"PYTHONPATH": "src"}, 109 | capture_output=True, 110 | ) 111 | if out.returncode != 0: 112 | raise RuntimeError( 113 | f"HiTyper failed on {repo} with error: {out.stderr.decode()}" 114 | ) 115 | results = json.loads(read_file(out_dir / "inferred_types.json")) 116 | sigmap = SignatureMap() 117 | for fname, mres in results.items(): 118 | mname = PythonProject.rel_path_to_module_name( 119 | Path(fname).relative_to(repo.resolve()) 120 | ) 121 | parser = HiTyperResponseParser(mname) 122 | sigmap.update(parser.parse(mres)) 123 | return sigmap 124 | 125 | 126 | def _parse_annot(ts: list[str]) -> cst.Annotation | None: 127 | if not ts: 128 | return None 129 | try: 130 | if len(ts) == 1: 131 | return cst.Annotation(cst.parse_expression(ts[0])) 132 | else: 133 | return cst.Annotation(cst.parse_expression(" | ".join(ts))) 134 | except cst.ParserSyntaxError: 135 | return None 136 | -------------------------------------------------------------------------------- /src/typet5/experiments/type4py.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | from typet5.experiments.utils import SupportedSyntax, remove_newer_syntax 4 | from typet5.function_dataset import SignatureMap 5 | from typet5.static_analysis import ( 6 | ElemSignature, 7 | FunctionSignature, 8 | ModuleName, 9 | ProjectPath, 10 | PythonProject, 11 | VariableSignature, 12 | reorder_signature_map, 13 | ) 14 | from typet5.type_check import normalize_type, parse_type_expr 15 | from typet5.utils import * 16 | 17 | PredList = list[list] # of the form [[type1, score1], [type2, score2], ...] 
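# --- Added illustrative note (not part of the original Type4Py client code) ---
# A PredList ranks candidate types by score, e.g. (example_pred is a made-up value):
#     example_pred = [["List[int]", 0.83], ["list", 0.11], ["Any", 0.02]]
# Type4PyResponseParser.parse_prediction below keeps only the top-ranked
# candidate, turning this example into an annotation for "List[int]";
# an empty prediction list yields None (no annotation).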
18 | 19 | 20 | class Type4PyResponseParser: 21 | def __init__(self, module: ModuleName): 22 | self.module = module 23 | self.assignment: dict[ProjectPath, ElemSignature] = dict() 24 | 25 | def parse(self, res_json: dict) -> dict[ProjectPath, ElemSignature]: 26 | self.assignment = dict() 27 | res = res_json["response"] 28 | 29 | for name, pred in res["variables_p"].items(): 30 | annot = self.parse_prediction(pred) 31 | sig = VariableSignature(annot, in_class=False) 32 | self.assignment[ProjectPath(self.module, name)] = sig 33 | for f in res["funcs"]: 34 | self._parse_func( 35 | f, 36 | ProjectPath(self.module, ""), 37 | in_class=False, 38 | ) 39 | for c in res["classes"]: 40 | self._parse_cls(c, ProjectPath(self.module, "")) 41 | return self.assignment 42 | 43 | def _parse_cls(self, cls_json: dict, base: ProjectPath): 44 | attr_preds: dict[str, PredList] = cls_json["variables_p"] 45 | new_base = base.append(cls_json["name"]) 46 | for name, pred in attr_preds.items(): 47 | annot = self.parse_prediction(pred) 48 | sig = VariableSignature(annot, in_class=True) 49 | self.assignment[new_base.append(name)] = sig 50 | for func_json in cls_json["funcs"]: 51 | try: 52 | self._parse_func(func_json, new_base, in_class=True) 53 | except: 54 | print(f"Failed to parse function") 55 | print("JSON:") 56 | display(func_json) 57 | raise 58 | 59 | @staticmethod 60 | def parse_prediction(pred: PredList) -> cst.Annotation | None: 61 | if pred: 62 | return cst.Annotation(cst.parse_expression(pred[0][0])) 63 | else: 64 | return None 65 | 66 | def _parse_func(self, func_json: dict, base: ProjectPath, in_class: bool): 67 | preds = func_json["params_p"] 68 | params_pred = { 69 | v: self.parse_prediction(preds[v]) for v in preds if len(preds[v]) > 0 70 | } 71 | ret_pred = None 72 | if "ret_type_p" in func_json: 73 | ret_pred = self.parse_prediction(func_json["ret_type_p"]) 74 | if ret_pred is None: 75 | ret_pred = cst.Annotation(cst.parse_expression("None")) 76 | sig = FunctionSignature( 77 | params_pred, 78 | ret_pred, 79 | in_class=in_class, 80 | ) 81 | self.assignment[base.append(func_json["name"])] = sig 82 | 83 | 84 | def run_type4py_request( 85 | code: str, module: ModuleName 86 | ) -> dict[ProjectPath, ElemSignature] | str: 87 | res = requests.post("https://type4py.com/api/predict?tc=0", code.encode()).json() 88 | if res["response"] is None: 89 | return res["error"] 90 | return Type4PyResponseParser(module).parse(res) 91 | 92 | 93 | @dataclass 94 | class Type4PyEvalResult: 95 | pred_maps: dict[str, SignatureMap] 96 | label_maps: dict[str, SignatureMap] 97 | 98 | def __post_init__(self): 99 | # reorder the function args to match the labels 100 | for pname, pred_map in self.pred_maps.items(): 101 | if pname not in self.label_maps: 102 | continue 103 | label_map = self.label_maps[pname] 104 | self.pred_maps[pname] = reorder_signature_map(pred_map, label_map) 105 | 106 | 107 | Type4PySupportedSyntax = SupportedSyntax( 108 | pattern_match=False, union_types=False, basic_types=False 109 | ) 110 | 111 | 112 | def eval_type4py_on_projects( 113 | projects: list[PythonProject], 114 | max_workers: int = 4, 115 | ) -> Type4PyEvalResult: 116 | name2project = {p.name: p for p in projects} 117 | module_srcs = { 118 | (project.name, name): remove_newer_syntax(m.tree, Type4PySupportedSyntax).code 119 | for project in projects 120 | for name, m in project.modules.items() 121 | } 122 | model_outputs = pmap( 123 | run_type4py_request, 124 | list(module_srcs.values()), 125 | [mname for pname, mname in module_srcs.keys()], 126 | 
desc="Calling Type4Py", 127 | max_workers=max_workers, 128 | ) 129 | 130 | label_signatures: dict[str, SignatureMap] = { 131 | project.name: {e.path: e.get_signature() for e in project.all_elems()} 132 | for project in projects 133 | } 134 | 135 | pred_signatures: dict[str, SignatureMap] = {n: dict() for n in label_signatures} 136 | for (pname, mname), o in zip(module_srcs.keys(), model_outputs): 137 | if isinstance(o, str): 138 | if list(name2project[pname].modules[mname].all_elements()): 139 | # only warn for non-empty modules 140 | logging.warning( 141 | f"In project {pname} module {mname}, Type4Py errored: {o}" 142 | ) 143 | else: 144 | pred_signatures[pname].update(o) 145 | 146 | return Type4PyEvalResult( 147 | pred_maps=pred_signatures, 148 | label_maps=label_signatures, 149 | ) 150 | -------------------------------------------------------------------------------- /src/typet5/experiments/typet5.py: -------------------------------------------------------------------------------- 1 | from typet5.train import * 2 | 3 | 4 | def accs_as_table_row(accs_dict: dict): 5 | def retrive(path: str): 6 | segs = path.split(".") 7 | target = accs_dict 8 | for s in segs: 9 | if s not in target: 10 | return "N/A" 11 | target = target[s] 12 | assert isinstance(target, CountedAcc), f"Unexpected type: {CountedAcc}" 13 | return f"{target.acc * 100:.2f}" 14 | 15 | def print_row(name: str, postfix: str): 16 | row = { 17 | "full.all": f"full_acc{postfix}.full_acc{postfix}", 18 | "calibrated.all": f"acc{postfix}.acc{postfix}", 19 | "calibrated.simple": f"acc{postfix}.acc{postfix}_by_simple.simple", 20 | "calibrated.complex": f"acc{postfix}.acc{postfix}_by_simple.complex", 21 | "base.all": f"base_acc{postfix}.base_acc{postfix}", 22 | } 23 | 24 | nums = [retrive(path) for path in row.values()] 25 | print(f"Accuracies on {name} types:") 26 | print("header: ", list(row.keys())) 27 | print(" & ".join(nums)) 28 | 29 | print_row("all", "") 30 | print_row("common", "_common") 31 | print_row("rare", "_rare") 32 | 33 | 34 | class TypeT5Configs: 35 | Default = TrainingConfig( 36 | func_only=True, 37 | pre_args=PreprocessArgs( 38 | drop_env_types=False, 39 | add_implicit_rel_imports=True, 40 | ), 41 | left_margin=2048, 42 | right_margin=2048 - 512, 43 | preamble_size=1000, 44 | ) 45 | 46 | NoPreamble = TrainingConfig( 47 | func_only=True, 48 | pre_args=PreprocessArgs( 49 | imports_in_preamble=False, 50 | stub_in_preamble=False, 51 | drop_env_types=False, 52 | add_implicit_rel_imports=True, 53 | ), 54 | left_margin=2048, 55 | right_margin=2048 - 512, 56 | preamble_size=1000, 57 | ) 58 | 59 | NoUsees = TrainingConfig( 60 | func_only=True, 61 | pre_args=PreprocessArgs( 62 | max_callees=0, 63 | drop_env_types=False, 64 | add_implicit_rel_imports=True, 65 | ), 66 | left_margin=512, 67 | right_margin=2048 + 1024, 68 | preamble_size=511, 69 | ) 70 | 71 | NoUsers = TrainingConfig( 72 | func_only=True, 73 | pre_args=PreprocessArgs( 74 | max_callers=0, 75 | drop_env_types=False, 76 | add_implicit_rel_imports=True, 77 | ), 78 | left_margin=2048 + 1024, 79 | right_margin=512, 80 | preamble_size=1000, 81 | ) 82 | 83 | NoSequential = TrainingConfig( 84 | func_only=True, 85 | pre_args=PreprocessArgs( 86 | drop_env_types=True, 87 | add_implicit_rel_imports=True, 88 | ), 89 | left_margin=2048, 90 | right_margin=2048 - 512, 91 | preamble_size=1000, 92 | ) 93 | -------------------------------------------------------------------------------- /src/typet5/experiments/typilus.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | 4 | from typet5.experiments.utils import SupportedSyntax 5 | from typet5.function_dataset import collect_public_api_labels 6 | from typet5.static_analysis import LabelInfo, ModuleName, PythonProject 7 | from typet5.type_check import PythonType, parse_type_expr, parse_type_str 8 | from typet5.type_env import AccuracyMetric, type_accuracies 9 | from typet5.utils import * 10 | 11 | JSON = dict 12 | 13 | 14 | def eval_typilus_on_repos( 15 | repo_roots: list[Path], 16 | metrics: list[AccuracyMetric], 17 | typilus_path: Path, 18 | work_dir: Path, 19 | max_workers: int | None = None, 20 | ): 21 | out_dirs = [work_dir / r.name for r in repo_roots] 22 | for out_dir in out_dirs: 23 | out_dir.mkdir(parents=True, exist_ok=True) 24 | pmap( 25 | run_typilus, 26 | repo_roots, 27 | out_dirs, 28 | [typilus_path] * len(repo_roots), 29 | desc="Running Typilus", 30 | max_workers=max_workers, 31 | ) 32 | 33 | typilus_outputs = [] 34 | for repo in repo_roots: 35 | with open(work_dir / repo.name / "predictions.json") as f: 36 | typilus_outputs.append(json.load(f)) 37 | return analyze_typilus_predictions( 38 | typilus_outputs, 39 | repo_roots, 40 | metrics, 41 | ) 42 | 43 | 44 | def run_typilus(repo: Path, out_dir: Path, typilus_path: Path) -> None: 45 | typilus_python = "/home/jiayi/anaconda3/envs/typilus-torch/bin/python" 46 | 47 | out = subprocess.run( 48 | [ 49 | typilus_python, 50 | "-m", 51 | "run_typilus", 52 | repo.resolve(), 53 | out_dir.resolve(), 54 | ], 55 | cwd=typilus_path, 56 | env={"PYTHONPATH": "src"}, 57 | capture_output=True, 58 | ) 59 | if out.returncode != 0: 60 | raise RuntimeError( 61 | f"Typilus failed on {repo} with error: {out.stderr.decode()}" 62 | ) 63 | 64 | 65 | def analyze_typilus_predictions( 66 | typilus_outputs: list[JSON], 67 | repo_roots: list[Path], 68 | metrics: list[AccuracyMetric], 69 | common_labels_only: bool = True, 70 | ): 71 | assert_eq( 72 | len({r.name for r in repo_roots}), 73 | len(repo_roots), 74 | extra_message=lambda: "repo names must be unique", 75 | ) 76 | 77 | # first, build the label map 78 | label_maps = pmap(collect_public_api_labels, repo_roots, desc="Collecting labels") 79 | 80 | # then, collect the prediction-label pairs 81 | pred_maps = list[dict[ModuleName, dict[CodePosition, str]]]() 82 | for out in typilus_outputs: 83 | pred_map = dict[ModuleName, dict[CodePosition, str]]() 84 | for file, preds in out.items(): 85 | assert isinstance(file, str) 86 | assert isinstance(preds, list) 87 | if file.startswith("/"): 88 | file = file[1:] 89 | file_mod = PythonProject.rel_path_to_module_name(Path(file)) 90 | submap = pred_map[file_mod] = dict() 91 | for pred in preds: 92 | line, col = pred["location"] 93 | submap[CodePosition(line, col)] = pred["pred"] 94 | pred_maps.append(pred_map) 95 | 96 | pred_list, label_list, cat_list = [], [], [] 97 | n_missing = 0 98 | n_skipped_rare = 0 99 | for label_map, pred_map in zip(label_maps, pred_maps): 100 | for mod, labels in label_map.items(): 101 | preds = pred_map.get(mod) 102 | n_labels = len(labels) 103 | if preds is None: 104 | if n_labels > 0: 105 | logging.warning(f"Missing {n_labels} predictions for module: {mod}") 106 | n_missing += n_labels 107 | continue 108 | for pos, linfo in labels.items(): 109 | pred = preds.get(pos) 110 | if pred is None: 111 | n_missing += 1 112 | continue 113 | 114 | if (ltype := parse_type_expr(linfo.annot.annotation)) is None: 115 | continue 116 | try: 117 | pred 
= parse_type_str(pred) 118 | except SyntaxError: 119 | pred = PythonType.Any() 120 | 121 | if common_labels_only and not metrics[0].is_common_type(ltype): 122 | n_skipped_rare += 1 123 | continue 124 | label_list.append(ltype) 125 | cat_list.append(linfo.cat) 126 | pred_list.append(pred) 127 | 128 | accs = dict[str, dict]() 129 | for m in metrics: 130 | accs[m.name] = sub_a = type_accuracies(pred_list, label_list, cat_list, m) 131 | if n_missing > 0: 132 | sub_a["n_missing"] = n_missing 133 | if n_skipped_rare > 0: 134 | sub_a["n_skipped_rare"] = n_skipped_rare 135 | return accs 136 | 137 | 138 | TypilusSupportedSyntax = SupportedSyntax( 139 | pattern_match=False, 140 | union_types=False, 141 | basic_types=False, 142 | named_exprs=False, 143 | ) 144 | -------------------------------------------------------------------------------- /src/typet5/experiments/utils.py: -------------------------------------------------------------------------------- 1 | from typet5.function_decoding import EvalResult 2 | from typet5.static_analysis import ( 3 | ElemSignature, 4 | FunctionSignature, 5 | ModuleName, 6 | ProjectPath, 7 | PythonModule, 8 | PythonProject, 9 | SignatureMap, 10 | VariableSignature, 11 | _VisitKind, 12 | is_type_rhs, 13 | ) 14 | from typet5.type_check import MypyChecker, MypyFeedback, MypyResult 15 | from typet5.type_env import normalize_type, parse_type_expr 16 | from typet5.utils import * 17 | 18 | _DefaultImport = cst.parse_statement( 19 | "from typing import Any, List, Tuple, Dict, Set, Union, Type, Callable # SPOT" 20 | ) 21 | 22 | 23 | @dataclass 24 | class SupportedSyntax: 25 | pattern_match: bool = True 26 | union_types: bool = True 27 | basic_types: bool = True 28 | named_exprs: bool = True 29 | 30 | 31 | def remove_newer_syntax(m: cst.Module, supported: SupportedSyntax) -> cst.Module: 32 | """ 33 | Remove or rewrite any newer python features that Type4Py doesn't support. 34 | """ 35 | 36 | class PatternRewriter(cst.CSTTransformer): 37 | def leave_MatchAs(self, node, updated: cst.MatchAs): 38 | if updated.pattern: 39 | return updated.pattern 40 | elif updated.name: 41 | return updated.name 42 | else: 43 | # wild card pattern 44 | return cst.Name("_") 45 | 46 | def pattern_to_expr(pattern: cst.MatchPattern): 47 | np = cast(cst.BaseExpression, pattern.visit(PatternRewriter())) 48 | return cst.parse_expression(m.code_for_node(np)) 49 | 50 | class Rewriter(cst.CSTTransformer): 51 | def leave_Annotation(self, node, updated: "cst.Annotation"): 52 | if supported.union_types: 53 | return updated 54 | ty = parse_type_expr(updated.annotation, silent=True) 55 | if ty is None: 56 | return cst.RemoveFromParent() 57 | ty = normalize_type(ty) # this should get rid of the Union type syntax. 
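# Added illustrative note (not part of the original file): when the target tool
# does not support PEP 604 unions, an annotation written as `x: int | None` is
# parsed, normalized, and re-serialized by the return statement on the next
# line, so the emitted annotation no longer uses the `|` operator (e.g. it
# falls back to a typing-module Union/Optional spelling that Type4Py can parse).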
58 | return updated.with_changes(annotation=cst.parse_expression(str(ty))) 59 | 60 | def leave_Module(self, node, updated: "cst.Module"): 61 | new_lines = [_DefaultImport] if not supported.basic_types else [] 62 | default_import = updated.code_for_node(_DefaultImport) 63 | for stmt in updated.body: 64 | if updated.code_for_node(stmt) != default_import: 65 | new_lines.append(stmt) 66 | return updated.with_changes(body=new_lines) 67 | 68 | def leave_Match(self, node, updated: cst.Match): 69 | if supported.pattern_match: 70 | return updated 71 | subject = updated.subject 72 | if isinstance(subject, cst.Tuple): 73 | subject = subject.with_changes( 74 | lpar=[cst.LeftParen()], rpar=[cst.RightParen()] 75 | ) 76 | 77 | conditions = [ 78 | cst.Comparison( 79 | subject, 80 | [ 81 | cst.ComparisonTarget( 82 | cst.Equal(), 83 | pattern_to_expr(c.pattern), 84 | ) 85 | ], 86 | ) 87 | for c in updated.cases 88 | ] 89 | bodies = [c.body for c in updated.cases] 90 | if_clauses = None 91 | for cond, body in reversed(list(zip(conditions, bodies))): 92 | if_clauses = cst.If(cond, body, orelse=if_clauses) 93 | assert isinstance(if_clauses, cst.If) 94 | return if_clauses 95 | 96 | def leave_NamedExpr(self, node, updated: "cst.NamedExpr"): 97 | if supported.named_exprs: 98 | return updated 99 | return updated.value 100 | 101 | return m.visit(Rewriter()) 102 | 103 | 104 | def remove_newer_syntax_for_file(file: Path, rules: SupportedSyntax) -> bool: 105 | text = read_file(file) 106 | m = cst.parse_module(text) 107 | m = remove_newer_syntax(m, rules) 108 | new_text = m.code 109 | if new_text != text: 110 | write_file(file, new_text) 111 | return True 112 | return False 113 | 114 | 115 | def remove_newer_syntax_for_repo(root: Path, rules: SupportedSyntax) -> None: 116 | all_files = [p for p in root.glob("**/*.py") if p.is_file()] 117 | changed = pmap( 118 | remove_newer_syntax_for_file, 119 | all_files, 120 | [rules] * len(all_files), 121 | desc="Removing newer syntax", 122 | ) 123 | print(f"{sum(changed)} / {len(all_files)} files have been rewritten.") 124 | 125 | 126 | def apply_sigmap( 127 | m: cst.Module, 128 | sigmap: SignatureMap, 129 | module_name: ModuleName, 130 | add_default_imports=True, 131 | ) -> cst.Module: 132 | """ 133 | Apply the signature map to the module. 
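Illustrative usage (added note, not part of the original docstring): given a
sigmap mapping ProjectPath("pkg.mod", "f") to a FunctionSignature, calling
apply_sigmap(m, sigmap, "pkg.mod") returns a module in which `def f` carries
that signature's parameter and return annotations; with add_default_imports
left True, a default `from typing import ...` line is also prepended.
("pkg.mod" and "f" are hypothetical names used only for illustration.)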
134 | """ 135 | 136 | class Rewriter(cst.CSTTransformer): 137 | def __init__(self): 138 | super().__init__() 139 | self.path_stack = [ProjectPath(module_name, "")] 140 | self.visit_stack = [_VisitKind.Root] 141 | 142 | @property 143 | def current_path(self) -> ProjectPath: 144 | return self.path_stack[-1] 145 | 146 | @property 147 | def current_visit_kind(self) -> _VisitKind: 148 | return self.visit_stack[-1] 149 | 150 | def enter_(self, name: str, kind: _VisitKind): 151 | self.path_stack.append(self.current_path.append(name)) 152 | self.visit_stack.append(kind) 153 | 154 | def exit_(self): 155 | self.path_stack.pop() 156 | self.visit_stack.pop() 157 | 158 | def visit_FunctionDef(self, node: cst.FunctionDef): 159 | self.enter_(node.name.value, _VisitKind.Function) 160 | 161 | def leave_FunctionDef(self, node, updated: cst.FunctionDef): 162 | if isinstance(sig := sigmap.get(self.current_path), FunctionSignature): 163 | try: 164 | updated = sig.apply(updated) 165 | except LookupError: 166 | pass 167 | self.exit_() 168 | return updated 169 | 170 | def visit_ClassDef(self, node: "cst.ClassDef") -> Optional[bool]: 171 | self.enter_(node.name.value, _VisitKind.Class) 172 | 173 | def leave_ClassDef(self, node, updated: cst.ClassDef): 174 | self.exit_() 175 | return updated 176 | 177 | def leave_AnnAssign(self, node, updated: cst.AnnAssign): 178 | target = None 179 | match updated.target: 180 | case cst.Name(name): 181 | target = name 182 | if ( 183 | target is not None 184 | and isinstance( 185 | sig := sigmap.get(self.current_path.append(target)), 186 | VariableSignature, 187 | ) 188 | and sig.annot is not None 189 | ): 190 | updated = updated.with_changes(annotation=sig.annot) 191 | return updated 192 | 193 | def leave_Assign(self, node, updated: cst.Assign): 194 | target = None 195 | if self.current_visit_kind != _VisitKind.Function: 196 | match updated.targets: 197 | case [cst.AssignTarget(target=cst.Name(name))]: 198 | target = name 199 | if ( 200 | target is not None 201 | and isinstance( 202 | sig := sigmap.get(self.current_path.append(target)), 203 | VariableSignature, 204 | ) 205 | and sig.annot is not None 206 | and not ( 207 | self.current_visit_kind == _VisitKind.Root 208 | and is_type_rhs(updated.value) 209 | ) # skip annotating type aliases 210 | ): 211 | return cst.AnnAssign(cst.Name(target), sig.annot, updated.value) 212 | return updated 213 | 214 | def leave_Module(self, node, updated: cst.Module): 215 | if add_default_imports: 216 | return updated.with_changes( 217 | body=[_DefaultImport] + list(updated.body), 218 | ) 219 | return updated 220 | 221 | return m.visit(Rewriter()) 222 | 223 | 224 | def quote_annotations(m: cst.Module, normalize_types: bool = True) -> cst.Module: 225 | """ 226 | Quote all type annotations as strings in the module.. 
227 | """ 228 | 229 | class Rewriter(cst.CSTTransformer): 230 | def leave_Annotation(self, node, updated: "cst.Annotation"): 231 | if updated.annotation is None: 232 | return updated 233 | if normalize_types: 234 | ty = parse_type_expr(updated.annotation) 235 | if ty is not None: 236 | text = repr(str(ty.normalized())) 237 | else: 238 | text = repr(show_expr(updated.annotation, quoted=False)) 239 | else: 240 | text = repr(show_expr(updated.annotation, quoted=False)) 241 | return updated.with_changes(annotation=cst.SimpleString(text)) 242 | 243 | return m.visit(Rewriter()) 244 | 245 | 246 | def apply_sigmap_and_typecheck( 247 | project: PythonProject, 248 | sigmap: SignatureMap, 249 | workdir: Path, 250 | quote_types=True, 251 | binary_path: Optional[Path] = None, 252 | ) -> MypyResult: 253 | assert workdir.is_dir(), f"Workdir is not a directory: {workdir}" 254 | 255 | # write the type annotated source files to the workdir 256 | for name, m in project.modules.items(): 257 | # file = workdir / project.module2src_file[name] 258 | file = workdir / (name.replace(".", "/") + ".py") 259 | file.parent.mkdir(parents=True, exist_ok=True) 260 | m1 = apply_sigmap(m.tree, sigmap, name) 261 | if quote_types: 262 | m1 = quote_annotations(m1, normalize_types=True) 263 | write_file(file, m1.code) 264 | # handle __init__.py files specially 265 | files = list(workdir.glob("**/*.py")) 266 | for f in files: 267 | if (d := f.with_suffix("")).is_dir(): 268 | f.rename(d / "__init__.py") 269 | 270 | # now call the type checker 271 | r = MypyChecker.check_project(workdir, binary_path) 272 | if isinstance(r, str): 273 | raise RuntimeError(f"Type checking failed: {r}") 274 | return r 275 | 276 | 277 | def count_type_errors( 278 | result: Iterable[MypyFeedback], 279 | ) -> int: 280 | error_codes = { 281 | "name-defined", 282 | "attr-defined", 283 | "arg-type", 284 | "return-value", 285 | "assignment", 286 | } 287 | 288 | n = 0 289 | for e in result: 290 | if e.error_code in error_codes: 291 | n += 1 292 | return n 293 | 294 | 295 | def collect_project_type_errors( 296 | proj: PythonProject, 297 | sigmap: SignatureMap, 298 | workdir: Path = Path("mypy_temp"), 299 | binary_path: Path | None = None, 300 | ) -> list[MypyFeedback]: 301 | workdir = workdir / proj.name 302 | shutil.rmtree(workdir, ignore_errors=True) 303 | workdir.mkdir(exist_ok=True, parents=True) 304 | try: 305 | check_r = apply_sigmap_and_typecheck( 306 | proj, sigmap, workdir, binary_path=binary_path 307 | ) 308 | return [e for es in check_r.error_dict.values() for e in es] 309 | except RuntimeError as e: 310 | print("Warning: mypy failed for project:", proj.name) 311 | return [] 312 | -------------------------------------------------------------------------------- /src/typet5/model.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import Counter 3 | from copy import copy, deepcopy 4 | from typing import NamedTuple, overload 5 | 6 | import numpy as np 7 | from datasets.arrow_dataset import Dataset 8 | from huggingface_hub import snapshot_download 9 | from mypy_extensions import mypyc_attr 10 | from torch import Tensor 11 | from torch.utils.data import DataLoader, RandomSampler 12 | from transformers.data.data_collator import DataCollatorForSeq2Seq 13 | 14 | from .data import ( 15 | ChunkedDataset, 16 | CtxArgs, 17 | TokenizedSrcSet, 18 | output_ids_as_types, 19 | preds_to_accuracies, 20 | ) 21 | from .type_env import AccuracyMetric, PythonType 22 | from .utils import * 23 | 24 | 25 
| @dataclass 26 | class DecodingArgs: 27 | ctx_args: CtxArgs 28 | sampling_max_tokens: int 29 | max_workers: int = DefaultWorkers 30 | # the maximal prediction length = tokens_per_type * num_types + slack_tokens 31 | tokens_per_type: int = 16 32 | slack_tokens: int = 10 33 | do_sample: bool = False 34 | top_p: float = 0.9 35 | num_beams: Optional[int] = None 36 | num_beam_groups: Optional[int] = None 37 | length_penalty: float = 1.0 38 | diversity_penalty: float | None = None 39 | 40 | def scale_ctx_size(self, factor: float) -> "DecodingArgs": 41 | result = deepcopy(self) 42 | assert result.ctx_args is not None 43 | """Scale the context size of the model by the given factor, while keeping the window size the same. 44 | Also scale down the sampling batch size accordingly.""" 45 | ctx_size = round(self.ctx_args.ctx_size * factor) 46 | right_margin = round(self.ctx_args.right_margin * factor) 47 | left_margin = ctx_size - right_margin - self.ctx_args.window_size 48 | result.ctx_args.ctx_size = ctx_size 49 | result.ctx_args.left_margin = left_margin 50 | result.ctx_args.right_margin = right_margin 51 | result.sampling_max_tokens = round(self.sampling_max_tokens / factor**2) 52 | 53 | return result 54 | 55 | def __repr__(self) -> str: 56 | return repr_modified_args(self) 57 | 58 | 59 | @dataclass 60 | class DatasetPredResult(Generic[T1]): 61 | chunks: ChunkedDataset 62 | predictions: list[list[PythonType]] 63 | extra_info: list[T1] = field(default_factory=list) 64 | 65 | def accuracies(self, metric: AccuracyMetric) -> dict: 66 | return preds_to_accuracies(self.predictions, self.chunks, metric) 67 | 68 | def group_by_repo(self, repos_dir: Path) -> dict[Path, "DatasetPredResult[T1]"]: 69 | chunk2repo = list[Path]() 70 | for i, info in enumerate(self.chunks.chunks_info): 71 | file = repos_dir / info.src_file 72 | repo = self.chunks.file2repo[file] 73 | chunk2repo.append(repo) 74 | 75 | group2ids = groupby(range(len(chunk2repo)), lambda i: chunk2repo[i]) 76 | result = dict() 77 | chunk_ids = self.chunks.data["chunk_id"] 78 | for repo, ids in group2ids.items(): 79 | result[repo] = DatasetPredResult( 80 | self.chunks[(chunk_ids[i] for i in ids)], 81 | [self.predictions[i] for i in ids], 82 | [self.extra_info[i] for i in ids] if self.extra_info else [], 83 | ) 84 | return result 85 | 86 | 87 | @dataclass 88 | class ModelWrapper: 89 | model: ModelType 90 | tokenizer: TokenizerType 91 | args: DecodingArgs 92 | common_type_names: set[str] 93 | monitor: TaskMonitor = EmptyLoggingMonitor() 94 | 95 | @staticmethod 96 | def get_codet5_path(use_small_model: bool = False) -> str: 97 | return ( 98 | "Salesforce/codet5-small" if use_small_model else "Salesforce/codet5-base" 99 | ) 100 | 101 | def scale_ctx_size(self, factor) -> "ModelWrapper": 102 | r = copy(self) 103 | r.args = r.args.scale_ctx_size(factor) 104 | return r 105 | 106 | def predict_on_batch( 107 | self, 108 | batch: dict, 109 | num_return_sequences: int | None = None, 110 | ) -> tuple[list[list[PythonType]], Tensor]: 111 | """Run the model on the given batch and return the predicted types for each row.""" 112 | model = self.model 113 | args = self.args 114 | n_labels = batch["n_labels"] 115 | max_labels = max(n_labels) 116 | 117 | div_pen = args.diversity_penalty 118 | if args.num_beam_groups is not None: 119 | assert ( 120 | div_pen is not None and div_pen > 0 121 | ), "num_beam_groups requires diversity_penalty > 0" 122 | 123 | output_ids = model.generate( 124 | inputs=batch["input_ids"].to(model.device), 125 | do_sample=args.do_sample, 126 | 
top_p=args.top_p, 127 | num_beams=args.num_beams, 128 | num_return_sequences=num_return_sequences, 129 | num_beam_groups=args.num_beam_groups, 130 | max_length=args.tokens_per_type * max_labels + args.slack_tokens, 131 | diversity_penalty=div_pen, 132 | length_penalty=args.length_penalty, 133 | renormalize_logits=True, 134 | ).cpu() # type: ignore 135 | assert len(output_ids.shape) == 2 136 | 137 | def decode_row(row, n_labels) -> list[PythonType]: 138 | return output_ids_as_types(row, n_labels) 139 | 140 | n_rows = output_ids.shape[0] 141 | if num_return_sequences is not None: 142 | assert_eq(n_rows, num_return_sequences * len(n_labels)) 143 | else: 144 | num_return_sequences = 1 145 | types = [ 146 | decode_row(output_ids[i, :], n_labels[i // num_return_sequences]) 147 | for i in range(n_rows) 148 | ] 149 | return types, output_ids 150 | 151 | @overload 152 | def predict( 153 | self, dataset: Dataset, tqdm_args: dict = {}, num_return_sequences: None = None 154 | ) -> list[list[PythonType]]: 155 | ... 156 | 157 | @overload 158 | def predict( 159 | self, dataset: Dataset, tqdm_args: dict, num_return_sequences: int 160 | ) -> list[list[list[PythonType]]]: 161 | ... 162 | 163 | def predict( 164 | self, 165 | dataset: Dataset, 166 | tqdm_args: dict = {}, 167 | num_return_sequences: Optional[int] = None, 168 | ): 169 | """Run the model on the given dataset and return the predicted types 170 | (or multiple sequences of predicted types if num_return_sequences is not none) for each row.""" 171 | model = self.model 172 | collator = DataCollatorForSeq2Seq(self.tokenizer, model) 173 | loader = dynamic_dataloader( 174 | dataset, # type: ignore 175 | max_tokens=self.args.sampling_max_tokens, 176 | collate_fn=collator, 177 | shuffle=True, 178 | ) 179 | device = model.device 180 | # we use this dict to keep the order of the chunks since it may be permuted by dynamic_dataloader 181 | pred_types = dict[int, list]() 182 | with tqdm( 183 | total=len(dataset), desc="predict", smoothing=0.01, **tqdm_args 184 | ) as tqdm_bar: 185 | for batch in loader: 186 | n_chunks = batch["input_ids"].shape[0] 187 | batch["input_ids"] = batch["input_ids"].to(device) 188 | preds, _ = self.predict_on_batch(batch, num_return_sequences) 189 | for i, c_id in enumerate(batch["chunk_id"]): 190 | c_id = int(c_id) 191 | if num_return_sequences is None: 192 | pred_types[c_id] = preds[i] 193 | else: 194 | pred_types[c_id] = preds[ 195 | i * num_return_sequences : (i + 1) * num_return_sequences 196 | ] 197 | tqdm_bar.update(n_chunks) 198 | return [pred_types[int(c_id)] for c_id in dataset["chunk_id"]] 199 | 200 | def save(self, path: Path): 201 | """Save the model to the given path along with its tokenizer and args.""" 202 | self.model.save_pretrained(str(path)) 203 | self.tokenizer.save_pretrained(str(path)) 204 | pickle_dump(path / "args.pkl", self.args) 205 | pickle_dump(path / "common_names.pkl", self.common_type_names) 206 | 207 | def to(self, device) -> "ModelWrapper": 208 | self.model = self.model.to(device) 209 | return self 210 | 211 | @classmethod 212 | def load_from_hub(cls, repo_name: str): 213 | path = snapshot_download(repo_name) 214 | return cls.load(Path(path)) 215 | 216 | @classmethod 217 | def load(cls, path: Path) -> "ModelWrapper": 218 | """Load a pretrained model from the given path.""" 219 | model = cast(ModelType, ModelType.from_pretrained(str(path))) 220 | tokenizer = TokenizerType.from_pretrained(str(path)) 221 | args = pickle_load(path / "args.pkl") 222 | common_type_names = 
ModelWrapper.load_common_type_names(path) 223 | return ModelWrapper( 224 | model=model, 225 | tokenizer=tokenizer, 226 | args=args, 227 | common_type_names=common_type_names, 228 | monitor=TaskLoggingMonitor(path.name), 229 | ) 230 | 231 | @classmethod 232 | def load_common_type_names(cls, model_path: Path) -> set[str]: 233 | if (model_path / "common_names.pkl").exists(): 234 | return pickle_load(model_path / "common_names.pkl") 235 | else: 236 | return set() 237 | 238 | def eval_on_dataset( 239 | self, 240 | src_data: TokenizedSrcSet, 241 | max_labels: Optional[int] = None, 242 | tqdm_args: dict = {}, 243 | ) -> DatasetPredResult: 244 | """Convenient method to preprocess the src according to the model's ctx_args and evaluate the (R0) accuracy.""" 245 | ctx_args = self.args.ctx_args 246 | if max_labels is not None: 247 | ctx_args = copy(ctx_args) 248 | ctx_args.max_labels = max_labels 249 | 250 | chunks = src_data.to_chunks(ctx_args, tqdm_args=tqdm_args) 251 | preds = self.predict( 252 | chunks.data, num_return_sequences=None, tqdm_args=tqdm_args 253 | ) 254 | return DatasetPredResult(chunks, preds) 255 | 256 | 257 | def dynamic_dataloader( 258 | dataset: Dataset, 259 | max_tokens: int, 260 | collate_fn, 261 | shuffle: bool = False, 262 | ): 263 | ex_sizes = [len(x) for x in dataset["input_ids"]] 264 | ids = list(range(len(ex_sizes))) 265 | if shuffle: 266 | random.shuffle(ids) 267 | ids.sort(key=lambda x: ex_sizes[x], reverse=True) 268 | batches = list[list[int]]() 269 | while len(ids) > 0: 270 | w = ex_sizes[ids[0]] 271 | n = max(1, max_tokens // w) 272 | batches.append(ids[:n]) 273 | ids = ids[n:] 274 | if shuffle: 275 | random.shuffle(batches) 276 | 277 | return DataLoader( 278 | cast(Any, dataset), 279 | batch_sampler=batches, 280 | collate_fn=collate_fn, 281 | ) 282 | -------------------------------------------------------------------------------- /src/typet5/tokenized_src.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy 4 | from libcst.metadata import CodeRange, PositionProvider 5 | 6 | from .static_analysis import remove_comments, remove_imports, stub_from_module 7 | from .type_check import MypyFeedback, PythonType, parse_type_str 8 | from .type_env import ( 9 | AnnotInfo, 10 | AnnotPath, 11 | CodePathManager, 12 | apply_annotations, 13 | collect_user_annotations, 14 | ) 15 | from .utils import * 16 | 17 | TokenSeq = list[int] # might need to make this more space-efficient 18 | 19 | 20 | @dataclass 21 | class TokenizedSrc: 22 | """A src file with certain type annotations masked out.""" 23 | 24 | file: Path 25 | repo: Path 26 | types: list[PythonType] 27 | types_pos: list[int] # the position of the types in tokenized_code. 
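# Added illustrative note (not part of the original file): the per-annotation
# lists in this class are aligned by index. For a masked source such as
# `def f(x: <mask>) -> <mask>:` (hypothetical), `types` holds the two
# ground-truth PythonTypes, `types_pos` the indices of the two mask tokens
# inside `tokenized_code`, and `types_str`/`types_tks` the original annotation
# texts and their token sequences.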
28 | types_str: list[str] 29 | types_tks: list[TokenSeq] 30 | types_info: list[AnnotInfo] 31 | main_code: str 32 | tokenized_code: TokenSeq # with certain types masked out 33 | preamble_code: str 34 | tokenized_preamble: TokenSeq 35 | prev_types: dict[int, PythonType] | None = None # previously predicted types 36 | inlined_spans: dict[int, slice] | None = None # the spans of inlined previous types 37 | feedbacks: list[MypyFeedback] | None = None 38 | 39 | @staticmethod 40 | def parse( 41 | code: str, 42 | file: Path, 43 | repo: Path, 44 | args: "PreprocessArgs", 45 | ) -> "TokenizedSrc": 46 | d = preprocess_code(code, args) 47 | d["file"] = file 48 | d["repo"] = repo 49 | return tokenized_src_from_segs(**d) 50 | 51 | def __str__(self): 52 | segs = [ 53 | "========TokenizedSrc========", 54 | f"file:{self.file}", 55 | f"repo:{self.repo}", 56 | "--------Preamble--------", 57 | self.preamble_code, 58 | "--------Main Code--------", 59 | decode_tokens(self.tokenized_code), 60 | "========End of TokenizedSrc========", 61 | ] 62 | return "\n".join(segs) 63 | 64 | def inline_prev_predictions( 65 | self, as_comment: bool, prev_types: dict[int, PythonType] | None = None 66 | ) -> "TokenizedSrc": 67 | "Inline the previous predictions into the code, either directly or as comments." 68 | if len(self.types) == 0: 69 | return copy.copy(self) 70 | 71 | if prev_types is None: 72 | prev_types = self.prev_types 73 | assert isinstance(prev_types, dict), f"prev_types has type: {type(prev_types)}" 74 | assert len(prev_types) > 0 75 | 76 | types_pos = list[int]() 77 | inlined_spans = dict[int, slice]() 78 | new_tks = list[int]() 79 | tokenizer = DefaultTokenizer 80 | mask_id = tokenizer.mask_token_id 81 | comment_start = tokenizer.encode("/* ", add_special_tokens=False) 82 | comment_end = tokenizer.encode(" */", add_special_tokens=False) 83 | 84 | start = 0 85 | 86 | for t in range(len(self.types)): 87 | new_tks.extend(self.tokenized_code[start : self.types_pos[t]]) 88 | types_pos.append(len(new_tks)) 89 | type_tk = self.tokenized_code[self.types_pos[t]] 90 | if t in prev_types: 91 | assert type_tk == mask_id 92 | to_insert = tokenizer.encode( 93 | str(prev_types[t]), add_special_tokens=False 94 | ) 95 | if as_comment: 96 | to_insert = comment_start + to_insert + comment_end 97 | new_tks.extend(to_insert) 98 | inlined_spans[t] = slice(types_pos[t], len(new_tks)) 99 | else: 100 | new_tks.append(type_tk) 101 | start = self.types_pos[t] + 1 102 | new_tks.extend(self.tokenized_code[start:]) 103 | 104 | assert prev_types.keys() == inlined_spans.keys() 105 | 106 | return TokenizedSrc( 107 | file=self.file, 108 | repo=self.repo, 109 | types=self.types, 110 | types_pos=types_pos, 111 | types_str=self.types_str, 112 | types_tks=self.types_tks, 113 | types_info=self.types_info, 114 | main_code=self.main_code, 115 | tokenized_code=new_tks, 116 | preamble_code=self.preamble_code, 117 | tokenized_preamble=self.tokenized_preamble, 118 | prev_types=prev_types, 119 | inlined_spans=inlined_spans, 120 | feedbacks=self.feedbacks, 121 | ) 122 | 123 | def print_code(self, max_lines: int = 100, body_only: bool = False): 124 | "Print out the (decoded) token sequence" 125 | code = decode_tokens(self.tokenized_code) 126 | if not body_only: 127 | code = decode_tokens(self.tokenized_preamble) + code 128 | print_limited(code, max_lines) 129 | 130 | @staticmethod 131 | def inline_predictions( 132 | src: "TokenizedSrc", 133 | as_comment: bool, 134 | prev_types: dict[int, PythonType] | None = None, 135 | ): 136 | return
src.inline_prev_predictions(as_comment=as_comment, prev_types=prev_types) 137 | 138 | 139 | @dataclass 140 | class PreprocessArgs: 141 | imports_in_preamble: bool = True 142 | stub_in_preamble: bool = True 143 | drop_comments: bool = True 144 | max_callees: int = 80 # only applicable to functional dataset 145 | max_callers: int = 20 # only applicable to functional dataset 146 | drop_env_types: bool = False # only applicable to functional dataset 147 | add_override_usages: bool = False # only applicable to functional dataset 148 | add_implicit_rel_imports: bool = True # only applicable to functional dataset 149 | 150 | 151 | def preprocess_code(code: str, args: PreprocessArgs) -> dict: 152 | """Preprocess the Python code to carve out all the type annotations. The original 153 | code is split into sequences at the type annotations.""" 154 | m = cst.parse_module(code) 155 | preamble_segs = list[str]() 156 | 157 | if args.drop_comments: 158 | m = remove_comments(m) 159 | if args.imports_in_preamble: 160 | m, imports = remove_imports(m) 161 | imports_part = cst.Module([cst.SimpleStatementLine([s]) for s in imports]) 162 | preamble_segs.append(imports_part.code) 163 | if args.stub_in_preamble: 164 | stub_m = stub_from_module(m) 165 | preamble_segs.append(stub_m.code) 166 | 167 | cst_code = m.code 168 | annots_info, types = collect_user_annotations(m) 169 | types_str = [ 170 | m.code_for_node(not_none(info.annot).annotation) for info in annots_info 171 | ] 172 | mask_annot = cst.Annotation(cst.Name(SpecialNames.TypeMask)) 173 | replaces = dict() 174 | for info in annots_info: 175 | replaces[info.path] = mask_annot 176 | new_code = apply_annotations(m, replaces).code 177 | code_segs = new_code.split(SpecialNames.TypeMask) 178 | 179 | assert ( 180 | len(code_segs) == len(types) + 1 181 | ), f"{len(code_segs)} != {len(types) + 1}. 
replaces: {replaces}\ncode: {new_code}" 182 | return { 183 | "preamble": "".join(preamble_segs), 184 | "code_segs": code_segs, 185 | "types": types, 186 | "types_str": types_str, 187 | "annots_info": annots_info, 188 | "cst_code": cst_code, 189 | "prev_types": None, 190 | } 191 | 192 | 193 | def tokenized_src_from_segs( 194 | file: Path, 195 | repo: Path, 196 | cst_code: str, 197 | preamble: str, 198 | code_segs: list[str], 199 | types: list[PythonType], 200 | types_str: list[str], 201 | annots_info: list[AnnotInfo], 202 | tokenized_preamble: list[int] | None = None, 203 | prev_types: dict[int, PythonType] | None = None, 204 | is_label: list[bool] | None = None, 205 | left_extra_tks: list[int] | None = None, 206 | right_extra_tks: list[int] | None = None, 207 | ) -> TokenizedSrc: 208 | tkn = DefaultTokenizer 209 | r = TokenizedSrc( 210 | file=file, 211 | repo=repo, 212 | main_code=cst_code, 213 | tokenized_code=list[int](), 214 | preamble_code=preamble, 215 | tokenized_preamble=tkn.encode(preamble, add_special_tokens=False) 216 | if tokenized_preamble is None 217 | else tokenized_preamble, 218 | types=list[PythonType](), 219 | types_pos=list[int](), 220 | types_str=list[str](), 221 | types_info=list[AnnotInfo](), 222 | types_tks=list[list[int]](), 223 | prev_types=prev_types, 224 | ) 225 | 226 | mask_id = not_none(tkn.mask_token_id) 227 | all_tks = r.tokenized_code 228 | if left_extra_tks: 229 | all_tks.extend(left_extra_tks) 230 | for i in range(len(code_segs) - 1): 231 | all_tks.extend(tkn.encode(code_segs[i], add_special_tokens=False)) 232 | if is_label is None or is_label[i]: 233 | r.types_pos.append(len(all_tks)) 234 | r.types.append(types[i]) 235 | r.types_tks.append(tkn.encode(str(types[i]), add_special_tokens=False)) 236 | r.types_str.append(types_str[i]) 237 | r.types_info.append(annots_info[i]) 238 | all_tks.append(mask_id) 239 | else: 240 | all_tks.extend(tkn.encode(types_str[i], add_special_tokens=False)) 241 | all_tks.extend(tkn.encode(code_segs[-1], add_special_tokens=False)) 242 | if right_extra_tks: 243 | all_tks.extend(right_extra_tks) 244 | 245 | return r 246 | 247 | 248 | def feedbacks_to_tokenized_src( 249 | src: TokenizedSrc, 250 | current_code: str, 251 | feedbacks: list[MypyFeedback], 252 | patch_predictions: bool = False, 253 | ) -> TokenizedSrc: 254 | try: 255 | m = cst.parse_module(current_code) 256 | except Exception as e: 257 | raise RuntimeError( 258 | f"Failed to parse file: '{src.file}' with content:\n{current_code}" 259 | ) from e 260 | m_code = m.code 261 | assert ( 262 | m_code.rstrip() == current_code.rstrip() 263 | ), f"String differences: {show_string_diff(current_code, m_code)}" 264 | current_annots, _ = collect_user_annotations(m) 265 | preds_map = dict[CodeRange, str]() 266 | types = list[PythonType]() 267 | prev_types = dict[int, PythonType]() 268 | types_str = list[str]() 269 | annots_info = list[AnnotInfo]() 270 | path2label_id = {info.path: i for i, info in enumerate(src.types_info)} 271 | 272 | for a in current_annots: 273 | if a.path in path2label_id: 274 | assert (r := a.annot_range) is not None 275 | assert (annot := a.annot) is not None 276 | prev_type = preds_map[r] = m.code_for_node(annot.annotation) 277 | li = path2label_id[a.path] 278 | prev_types[li] = parse_type_str(prev_type) 279 | types.append(src.types[li]) 280 | types_str.append(src.types_str[li]) 281 | annots_info.append(a) 282 | pos_to_msg = {f.position: f.message for f in feedbacks} 283 | new_code = patch_code_with_extra( 284 | current_code, preds_map, pos_to_msg,
patch_predictions 285 | ) 286 | code_segs = new_code.split(SpecialNames.TypeMask) 287 | assert ( 288 | len(code_segs) == len(types) + 1 289 | ), f"{len(code_segs)} != {len(types)} + 1.\nNew Code:\n{new_code}" 290 | 291 | new_src = tokenized_src_from_segs( 292 | file=src.file, 293 | repo=src.repo, 294 | cst_code=new_code, 295 | preamble=src.preamble_code, 296 | tokenized_preamble=src.tokenized_preamble, 297 | code_segs=code_segs, 298 | types=types, 299 | types_str=types_str, 300 | annots_info=annots_info, 301 | prev_types=prev_types, 302 | ) 303 | new_src.feedbacks = feedbacks 304 | return new_src 305 | 306 | 307 | def patch_code_with_extra( 308 | code: str, 309 | predictions: dict[CodeRange, str], 310 | errors: dict[CodePosition, str], 311 | patch_predictions: bool, 312 | ) -> str: 313 | replaces = [] 314 | # When the ranges overlap, we want to use the order: new_prediction -> prev_prediction -> errors 315 | for r, t in predictions.items(): 316 | replaces.append((r, 1, SpecialNames.TypeMask)) 317 | if patch_predictions: 318 | replaces.append((CodeRange(r.start, r.start), 2, f"/* {t} */")) 319 | 320 | for p, e in errors.items(): 321 | replaces.append((CodeRange(p, p), 3, f"/* error: {e} */")) 322 | 323 | return replace_strs_by_pos(code, replaces) 324 | -------------------------------------------------------------------------------- /src/typet5/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import warnings 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import * 7 | 8 | import pytorch_lightning as pl 9 | import torch 10 | import torch.nn as nn 11 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 12 | from pytorch_lightning.loggers import WandbLogger 13 | from torch.optim import AdamW 14 | from transformers import DataCollatorForSeq2Seq 15 | from transformers.modeling_outputs import Seq2SeqLMOutput 16 | 17 | from .data import ChunkedDataset, TokenizedSrcSet 18 | from .model import ( 19 | CtxArgs, 20 | DecodingArgs, 21 | ModelType, 22 | ModelWrapper, 23 | TokenizerType, 24 | dynamic_dataloader, 25 | ) 26 | from .tokenized_src import PreprocessArgs 27 | from .type_check import TypeCheckArgs 28 | from .utils import * 29 | 30 | 31 | @dataclass 32 | class ModelTrainingArgs: 33 | train_ctx_args: CtxArgs 34 | dec_args: DecodingArgs 35 | train_max_tokens: int 36 | eval_max_tokens: int 37 | max_epochs: int 38 | tc_args: TypeCheckArgs 39 | accumulate_grad_batches: int | dict | None = None 40 | 41 | 42 | class TrainingConfig(NamedTuple): 43 | quicktest: bool = False 44 | func_only: bool = True # whether to use functional dataset format 45 | pre_args: PreprocessArgs = PreprocessArgs() 46 | trained_on: str = "ManyTypes4Py" 47 | data_reduction: int = 1 48 | check_in_isolation: bool = False # DAgger 49 | inline_prev_gold: bool = False 50 | ctx_size: int = 4096 51 | left_margin: int = 2048 52 | # at most this much of the left_margin may be allocated to the preamble 53 | preamble_size: int = 1000 54 | right_margin: int = 2048 - 512 55 | train_max_labels: int = 32 56 | dec_max_labels: int = 16 57 | use_small_model: bool = False 58 | grad_accum_labels = 32 # DAgger 59 | modifications: str = "" 60 | 61 | def as_dict(self) -> dict[str, Any]: 62 | return {attr: getattr(self, attr) for attr in self.__annotations__} 63 | 64 | def as_name(self) -> str: 65 | return self.get_model_name() 66 | 67 | def __repr__(self): 68 | return repr_modified_args(self, flatten=True) 69 | 70 | def
get_model_name(self) -> str: 71 | return "model-v7--" + repr_modified_args(self, flatten=True) 72 | 73 | def train_ctx_args(self) -> CtxArgs: 74 | return CtxArgs( 75 | ctx_size=self.ctx_size, 76 | preamble_size=self.preamble_size, 77 | left_margin=self.left_margin, 78 | right_margin=self.right_margin, 79 | max_labels=self.train_max_labels, 80 | inline_prev_gold=self.inline_prev_gold, 81 | ) 82 | 83 | def get_preprocess_args(self): 84 | return self.pre_args 85 | 86 | def dec_ctx_args(self) -> CtxArgs: 87 | r = self.train_ctx_args() 88 | r.max_labels = self.dec_max_labels 89 | return r 90 | 91 | 92 | def train_spot_model( 93 | tk_dataset: dict[str, TokenizedSrcSet], 94 | model_name: str, 95 | train_args: ModelTrainingArgs, 96 | gpus: list[int], 97 | quicktest=False, 98 | use_early_stop=False, 99 | use_small_model=False, 100 | ) -> ModelWrapper: 101 | os.chdir(proj_root()) 102 | train_ctx_args = train_args.train_ctx_args 103 | dec_args = train_args.dec_args 104 | 105 | running_dir = get_model_dir(False) / model_name 106 | if running_dir.exists(): 107 | shutil.rmtree(running_dir) 108 | running_dir.mkdir(parents=True, exist_ok=True) 109 | 110 | print("Disk space left:") 111 | subprocess.run(["df", "-h", str(running_dir)]) 112 | 113 | model_path = ModelWrapper.get_codet5_path(use_small_model) 114 | lit_model = TrainModelWrapper(model_path, model_saving_path=running_dir / "ckpts") 115 | tokenizer: TokenizerType = lit_model.tokenizer 116 | 117 | common_type_names = tk_dataset["train"].common_type_names() 118 | wrapper = ModelWrapper( 119 | lit_model.model, tokenizer, dec_args, common_type_names=common_type_names 120 | ) 121 | 122 | chunks: dict[str, ChunkedDataset] = {} 123 | with run_long_task("Preparing chunked datasets", notify=False): 124 | for n in ["valid", "train"]: 125 | src = tk_dataset[n] 126 | chunks[n] = src.to_chunks(train_ctx_args) 127 | 128 | wandb_logger = WandbLogger() # assuming a run has already been initialized 129 | 130 | collate_fn = DataCollatorForSeq2Seq(lit_model.tokenizer, lit_model.model) 131 | train_dataloader = dynamic_dataloader( 132 | cast(Any, chunks["train"].data), 133 | max_tokens=train_args.train_max_tokens, 134 | collate_fn=collate_fn, 135 | shuffle=True, 136 | ) 137 | valid_dataloader = dynamic_dataloader( 138 | cast(Any, chunks["valid"].data), 139 | max_tokens=train_args.eval_max_tokens, 140 | collate_fn=collate_fn, 141 | shuffle=True, # doesn't hurt 142 | ) 143 | 144 | ckpt_interval = max(1, len(train_dataloader) // 10) 145 | val_interval = 1 if quicktest else max(500, ckpt_interval) 146 | 147 | checkpoint_cb = ModelCheckpoint( 148 | dirpath=running_dir, 149 | save_top_k=3, 150 | monitor="valid/loss", 151 | mode="min", 152 | save_on_train_epoch_end=False, 153 | verbose=quicktest, 154 | ) 155 | 156 | trainer = pl.Trainer( 157 | default_root_dir=str(running_dir), 158 | # fast_dev_run=6 if quicktest else False, 159 | # log_every_n_steps=500, 160 | accelerator="gpu" if gpus else "cpu", 161 | devices=gpus, 162 | precision=16, 163 | max_epochs=train_args.max_epochs, 164 | logger=wandb_logger, 165 | val_check_interval=val_interval, 166 | callbacks=( 167 | [checkpoint_cb, EarlyStopping("valid/loss", mode="min", verbose=quicktest)] 168 | if use_early_stop 169 | else [] 170 | ), 171 | gradient_clip_val=1.0, 172 | gradient_clip_algorithm="norm", 173 | accumulate_grad_batches=train_args.accumulate_grad_batches, 174 | # track_grad_norm=2, 175 | ) 176 | 177 | warnings.filterwarnings("ignore", "The dataloader.*does not have many workers.*") 178 | 179 | with 
run_long_task(f"Training {model_name}", notify=False): 180 | trainer.fit( 181 | model=lit_model, 182 | train_dataloaders=train_dataloader, 183 | val_dataloaders=valid_dataloader, 184 | ) 185 | 186 | save_dir = get_model_dir(True) / model_name 187 | 188 | final_eval = trainer.validate(model=lit_model, dataloaders=valid_dataloader)[0] 189 | 190 | try: 191 | if ( 192 | use_early_stop 193 | and (best_loss := checkpoint_cb.best_model_score) is not None 194 | and best_loss < final_eval["valid/loss"] 195 | ): 196 | print( 197 | f"Loading best model with score {best_loss} from: {checkpoint_cb.best_model_path}" 198 | ) 199 | wrapper.model = TrainModelWrapper.load_from_checkpoint( 200 | checkpoint_cb.best_model_path 201 | ).model 202 | if save_dir.exists(): 203 | shutil.rmtree(save_dir) 204 | save_dir.mkdir(parents=True, exist_ok=True) 205 | 206 | wrapper.save(save_dir) 207 | shutil.rmtree(running_dir) 208 | except Exception as e: 209 | logging.error( 210 | "Error encountered after training, returning partial results... Error:\n", e 211 | ) 212 | 213 | return wrapper 214 | 215 | 216 | class TrainModelWrapper(pl.LightningModule): 217 | "A pytorch lightening module that handles training and evaluation of the SPOT model." 218 | 219 | def __init__( 220 | self, model_checkpoint: str | Path, *, model_saving_path: Path 221 | ) -> None: 222 | super().__init__() 223 | self.save_hyperparameters() 224 | self.model: ModelType = load_model_spot(model_checkpoint) 225 | self.tokenizer: TokenizerType = TokenizerType.from_pretrained(model_checkpoint) 226 | self.model_saving_path = model_saving_path 227 | self.model_saving_interval: Optional[int] = None 228 | self.avg_loss = MovingAvg(alpha=0.01) 229 | self.labels_trained = 0 230 | 231 | def on_fit_start(self): 232 | # maps chunk id to the initial predictions made for that chunk immediately 233 | # before the model was trained on it 234 | if self.model_saving_interval is not None: 235 | self.batch_ids: list[list[int]] = [] 236 | self.saving_counter = 0 237 | self.model.save_pretrained(self.model_saving_path / f"n_batches=0") 238 | 239 | def configure_optimizers(self): 240 | return _configure_optimizers(self.model) 241 | 242 | def training_step(self, batch, batch_idx): 243 | if self.model_saving_interval is not None and self.current_epoch == 0: 244 | self.batch_ids.append(batch["chunk_id"].tolist()) 245 | self.saving_counter += 1 246 | if self.saving_counter >= self.model_saving_interval: 247 | self.saving_counter = 0 248 | # model can be used for `n_batches` and onward. 
249 | self.model.save_pretrained( 250 | self.model_saving_path / f"n_batches={len(self.batch_ids)}" 251 | ) 252 | 253 | outputs = self.model.forward( 254 | input_ids=batch["input_ids"], 255 | attention_mask=batch["attention_mask"], 256 | labels=batch["labels"], 257 | ) 258 | assert isinstance(outputs, Seq2SeqLMOutput) 259 | loss = not_none(outputs.loss) 260 | n_labels = batch["n_labels"].sum().item() 261 | self.labels_trained += n_labels 262 | self.avg_loss.update(loss.item()) 263 | self.log("train/loss", self.avg_loss.value) 264 | self.log("train/lr", self.lr_schedulers().get_last_lr()[0]) # type: ignore 265 | self.log("train/labels", float(self.labels_trained)) 266 | return loss 267 | 268 | def validation_step(self, batch, batch_idx): 269 | outputs = self.model( 270 | input_ids=batch["input_ids"], 271 | attention_mask=batch["attention_mask"], 272 | labels=batch["labels"], 273 | ) 274 | loss = outputs.loss 275 | self.log("valid/loss", loss.item()) 276 | self.log("train/labels", float(self.labels_trained)) 277 | 278 | 279 | def concat_batches(batches: list[dict], keys: list[str]) -> dict: 280 | return {k: torch.concat([b[k] for b in batches]) for k in keys} 281 | 282 | 283 | def _configure_optimizers(model: nn.Module, base_lr: float = 2e-5): 284 | no_decay = ["bias", "LayerNorm.weight"] 285 | grouped_params = [ 286 | { 287 | "params": [ 288 | p 289 | for pn, p in model.named_parameters() 290 | if not any(n in pn for n in no_decay) 291 | ], 292 | "weight_decay": 0.01, 293 | }, 294 | { 295 | "params": [ 296 | p 297 | for pn, p in model.named_parameters() 298 | if any(n in pn for n in no_decay) 299 | ], 300 | "weight_decay": 0.0, 301 | }, 302 | ] 303 | optimizer = AdamW(grouped_params, lr=base_lr) 304 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.2) 305 | return [optimizer], [lr_scheduler] 306 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/utopia-group/TypeT5/d8ff8638f4d00f03042db5780a8d4fa09a72916d/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_func_decoding.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from pathlib import Path 3 | 4 | import pytest 5 | 6 | from typet5.function_dataset import FunctionSignature 7 | from typet5.static_analysis import cst 8 | from typet5.static_analysis import to_abs_import_path as to_abs 9 | from typet5.utils import * 10 | from typet5.utils import assert_eq, groupby, not_none, show_string_diff 11 | 12 | 13 | def test_function_signature(): 14 | ex_code = """ 15 | def f(x, y: int=3, *, v=3, **kwargs) -> int: 16 | u: int 17 | return 1 18 | """ 19 | 20 | f = cast(cst.FunctionDef, cst.parse_module(ex_code).body[0]) 21 | sig = FunctionSignature.from_function(f, False) 22 | 23 | new_sig = copy.deepcopy(sig) 24 | new_sig.params["x"] = cst.Annotation(cst.parse_expression("list[int]")) 25 | assert ( 26 | "def f(x: list[int], y: int=3, *, v=3, **kwargs) -> int" 27 | in cst.Module([new_sig.apply(f)]).code 28 | ) 29 | 30 | new_sig = copy.deepcopy(sig) 31 | new_sig.params["y"] = cst.Annotation(cst.parse_expression("list[int]")) 32 | assert ( 33 | "def f(x, y: list[int]=3, *, v=3, **kwargs) -> int" 34 | in cst.Module([new_sig.apply(f)]).code 35 | ) 36 | 37 | new_sig = copy.deepcopy(sig) 38 | new_sig.params["v"] = cst.Annotation(cst.parse_expression("list[int]")) 
39 | assert ( 40 | "def f(x, y: int=3, *, v: list[int]=3, **kwargs) -> int" 41 | in cst.Module([new_sig.apply(f)]).code 42 | ) 43 | 44 | new_sig = copy.deepcopy(sig) 45 | new_sig.params["kwargs"] = cst.Annotation(cst.parse_expression("list[int]")) 46 | assert ( 47 | "def f(x, y: int=3, *, v=3, **kwargs: list[int]) -> int" 48 | in cst.Module([new_sig.apply(f)]).code 49 | ) 50 | 51 | new_sig = copy.deepcopy(sig) 52 | new_sig.returns = cst.Annotation(cst.parse_expression("list[int]")) 53 | assert ( 54 | "def f(x, y: int=3, *, v=3, **kwargs) -> list[int]" 55 | in cst.Module([new_sig.apply(f)]).code 56 | ) 57 | 58 | 59 | def test_method_signature(): 60 | ex_code = """ 61 | def f(self, x, y): 62 | u: int 63 | return 1 64 | """ 65 | 66 | f = cast(cst.FunctionDef, cst.parse_module(ex_code).body[0]) 67 | sig = FunctionSignature.from_function(f, False) 68 | assert len(sig.params) == 2 69 | 70 | ex_code2 = """ 71 | def f(a, x, y): 72 | u: int 73 | return 1 74 | """ 75 | 76 | f = cast(cst.FunctionDef, cst.parse_module(ex_code2).body[0]) 77 | sig = FunctionSignature.from_function(f, False) 78 | assert len(sig.params) == 3 79 | 80 | ex_code3 = """ 81 | def f(a=lambda x: x): 82 | return 1 83 | """ 84 | 85 | f = cast(cst.FunctionDef, cst.parse_module(ex_code3).body[0]) 86 | sig = FunctionSignature.from_function(f, False) 87 | assert len(sig.params) == 1 88 | -------------------------------------------------------------------------------- /tests/test_model_creation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoConfig 3 | from transformers.models.t5.configuration_t5 import T5Config 4 | 5 | from typet5.data import CtxArgs 6 | from typet5.model import DecodingArgs, ModelType, ModelWrapper 7 | from typet5.train import TrainingConfig 8 | from typet5.utils import DefaultTokenizer 9 | 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | 12 | 13 | def test_basic_torch_operations(): 14 | x = torch.randn(10).to(device) 15 | assert x.sum() <= 10 16 | 17 | 18 | def test_model_creation(): 19 | config = AutoConfig.from_pretrained(ModelWrapper.get_codet5_path()) 20 | model = ModelType(config).to(device) 21 | 22 | ctx_args = TrainingConfig().dec_ctx_args() 23 | dec_args = DecodingArgs(ctx_args, 32) 24 | wrapper = ModelWrapper(model, DefaultTokenizer, dec_args, set()) 25 | wrapper.to(device) 26 | 27 | ids = DefaultTokenizer.encode( 28 | "def get_count() -> : ...", return_tensors="pt" 29 | ) 30 | batch = {"input_ids": ids, "n_labels": [1]} 31 | out = wrapper.predict_on_batch(batch) 32 | assert True 33 | -------------------------------------------------------------------------------- /tests/test_type_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | from typet5.static_analysis import FunctionSignature, mask_types 8 | from typet5.tokenized_src import PreprocessArgs 9 | from typet5.type_check import MypyResult, PythonType, remove_top_optional 10 | from typet5.type_env import ( 11 | AnnotCat, 12 | AnnotPath, 13 | AnyAnnot, 14 | SelectAnnotations, 15 | TypeInfAction, 16 | annot_path, 17 | apply_annotations, 18 | collect_annots_info, 19 | collect_user_annotations, 20 | mypy_checker, 21 | normalize_type, 22 | parse_type_str, 23 | type_inf_env, 24 | ) 25 | from typet5.utils import ( 26 | SpecialNames, 27 | as_any, 28 | assert_eq, 29 | cst, 30 | proj_root, 31 | read_file, 32 | write_file, 33 
| ) 34 | 35 | os.chdir(proj_root()) 36 | 37 | 38 | def test_annotation_collection(): 39 | parsed = cst.parse_module(read_file("data/code/env_code_2.py")) 40 | annots = collect_annots_info(parsed) 41 | annot_paths = [(a.path, a.cat) for a in annots] 42 | correct_annot_paths: list[tuple[AnnotPath, AnnotCat]] = [ 43 | (annot_path("fib", "n"), AnnotCat.FuncArg), 44 | (annot_path("fib", SpecialNames.Return), AnnotCat.FuncReturn), 45 | (annot_path("foo", "bar"), AnnotCat.FuncArg), 46 | (annot_path("foo", SpecialNames.Return), AnnotCat.FuncReturn), 47 | (annot_path("Bar", "z"), AnnotCat.ClassAtribute), 48 | (annot_path("Bar", "w"), AnnotCat.ClassAtribute), 49 | (annot_path("Bar", "__init__", SpecialNames.Return), AnnotCat.FuncReturn), 50 | (annot_path("Bar", "__init__", "x"), AnnotCat.FuncArg), 51 | (annot_path("Bar", "__init__", "self.x"), AnnotCat.ClassAtribute), 52 | (annot_path("Bar", "__init__", "self.y"), AnnotCat.ClassAtribute), 53 | (annot_path("Bar", "reset", "w0"), AnnotCat.FuncArg), 54 | (annot_path("Bar", "reset", SpecialNames.Return), AnnotCat.FuncReturn), 55 | (annot_path("Bar", "foo", "z"), AnnotCat.FuncArg), 56 | (annot_path("Bar", "foo", SpecialNames.Return), AnnotCat.FuncReturn), 57 | (annot_path("bar"), AnnotCat.GlobalVar), 58 | ] 59 | for pair in correct_annot_paths: 60 | assert pair in annot_paths 61 | for pair in annot_paths: 62 | assert pair in correct_annot_paths 63 | 64 | 65 | def test_self_parameter_annotation(): 66 | code = """ 67 | def foo(self: float, x: int) -> str: 68 | return "1" 69 | """ 70 | parsed = cst.parse_module(code) 71 | _, types = collect_user_annotations(parsed) 72 | 73 | assert_eq(types, [PythonType.from_name("int"), PythonType.from_name("str")]) 74 | n_segs = len(mask_types(parsed).code.split(SpecialNames.TypeMask)) 75 | assert_eq(n_segs, len(types) + 1) 76 | 77 | sig = FunctionSignature.from_function(as_any(parsed.body[0]), False) 78 | assert len(sig.params) == len(types) - 1 79 | 80 | 81 | parsed = cst.parse_module(read_file("data/code/bad_code_1.py")) 82 | 83 | 84 | code_1_patch = { 85 | annot_path("fib", "n"): cst.Annotation(cst.Name("int")), 86 | annot_path("fib", SpecialNames.Return): cst.Annotation(cst.Name("int")), 87 | annot_path("t_add", SpecialNames.Return): cst.Annotation(cst.Name("str")), 88 | annot_path("bad_y"): AnyAnnot, 89 | } 90 | 91 | 92 | def test_annotation_applying(): 93 | old_annots = collect_annots_info(parsed) 94 | old_map = {a.path: a.annot for a in old_annots if a.annot is not None} 95 | new_parsed = apply_annotations(parsed, code_1_patch) 96 | new_annots = collect_annots_info(new_parsed) 97 | new_map = {a.path: a.annot for a in new_annots if a.annot is not None} 98 | 99 | for k, v in code_1_patch.items(): 100 | assert old_map[k].annotation.value != new_map[k].annotation.value # type: ignore 101 | assert new_map[k].annotation.value == v.annotation.value # type: ignore 102 | 103 | 104 | @pytest.mark.skip("Not used.") 105 | def test_mypy_checker_1(): 106 | with mypy_checker(Path("data/code"), wait_before_check=0.0) as checker: 107 | check_r = checker.recheck_project() 108 | assert isinstance(check_r, MypyResult) 109 | assert Path("data/code/bad_code_1.py").resolve() in check_r.error_dict 110 | assert Path("data/code/bad_code_2.py").resolve() in check_r.error_dict 111 | 112 | 113 | @pytest.mark.skip("Not used.") 114 | def test_mypy_checker_2(): 115 | with mypy_checker(Path("data/code_output"), wait_before_check=0.0) as checker: 116 | if Path("data/code_output/bad_code_1.py").exists(): 117 | 
os.remove("data/code_output/bad_code_1.py") 118 | oe = checker.recheck_project().num_errors 119 | write_file("data/code_output/bad_code_1.py", parsed.code) 120 | assert checker.recheck_project().num_errors > oe 121 | new_code = apply_annotations(parsed, code_1_patch).code 122 | write_file( 123 | "data/code_output/bad_code_1.py", 124 | new_code, 125 | ) 126 | c_r = checker.recheck_project() 127 | assert c_r.num_errors == oe, f"mypy_output: {c_r.output_str}\ncode: {new_code}" 128 | 129 | 130 | def test_type_parsing(): 131 | # test quoted types 132 | assert parse_type_str("'Foo[int]'") == parse_type_str("Foo[int]") 133 | assert parse_type_str('"Bar"') == parse_type_str("Bar") 134 | 135 | 136 | def test_type_normalization(): 137 | equiv_pairs: list[tuple[str, str]] = [ 138 | ("list[int]", "List[int]"), 139 | ("dict[str, list]", "Dict[str, List]"), 140 | ("'Foo[int]'", "Foo[int]"), 141 | ("typing.Union[str, List]", "typing.Union[list, str]"), 142 | ("typing.Union[str, typing.Union[str, int]]", "str | int"), 143 | ("typing.Union[str, float, typing.Union[str, int]]", "str | int | float"), 144 | ("Union[str, float, None]", "Optional[Union[str, float]]"), 145 | ("str | None", "Optional[str]"), 146 | ("Any | None", "Optional"), 147 | ("List[Any]", "List"), 148 | ("Dict[Any, Any]", "Dict"), 149 | ] 150 | 151 | for a, b in equiv_pairs: 152 | ta = parse_type_str(a) 153 | tb = parse_type_str(b) 154 | assert normalize_type(ta) == normalize_type(tb) 155 | 156 | nonequiv_pairs: list[tuple[str, str]] = [ 157 | ("Union[str, int]", "Union[str, list]"), 158 | ("typing.List[str]", "t.List[str]"), 159 | ("tuple[str, int]", "tuple[int, str]"), 160 | ("Dict[str, Any]", "Dict"), 161 | ] 162 | 163 | for a, b in nonequiv_pairs: 164 | ta = parse_type_str(a) 165 | tb = parse_type_str(b) 166 | assert normalize_type(ta) != normalize_type(tb) 167 | 168 | dict_ex = PythonType.from_str("Dict[Any, Any] | None") 169 | assert remove_top_optional(normalize_type(dict_ex)) == PythonType.from_name("Dict") 170 | 171 | dict2 = PythonType.from_str("Dict[str, Any]") 172 | assert normalize_type(dict2) == dict2 173 | 174 | 175 | import shutil 176 | 177 | from typet5.data import TokenizedSrcSet, type_check_src, type_check_src_in_project 178 | from typet5.utils import load_tokenizer_spot, proj_root 179 | 180 | 181 | @pytest.mark.skip("Not using type checker for the moment") 182 | def test_mypy_checking(): 183 | simple_dataset = TokenizedSrcSet.from_repos( 184 | proj_root() / "data", 185 | [proj_root() / "data/code"], 186 | PreprocessArgs(drop_comments=True), 187 | ) 188 | 189 | src_to_check = simple_dataset.get_src_by_file(Path("code/bad_code_2.py")) 190 | result_1 = type_check_src(src_to_check, {0: "int"}) 191 | assert len(result_1.feedbacks) == 0 192 | 193 | temp_dir = proj_root() / "mypy_temp/test_dir" 194 | shutil.rmtree(temp_dir, ignore_errors=True) 195 | 196 | with simple_dataset.setup_typechecking( 197 | [src_to_check], 198 | cleanup=True, 199 | skip_pre_fdbks=True, 200 | ) as env: 201 | result_2 = type_check_src_in_project( 202 | src_to_check, 203 | {0: "int"}, 204 | (env.template_root / "code"), 205 | "Skip", 206 | ) 207 | assert isinstance(result_2.feedbacks, list) and len(result_2.feedbacks) == 0 208 | 209 | rs = simple_dataset.type_check_each_file_in_project( 210 | [ 211 | ((simple_dataset.repos_root / "code/bad_code_2.py"), {0: "int"}), 212 | ((simple_dataset.repos_root / "code/bad_code_1.py"), {0: "str"}), 213 | ], 214 | ) 215 | fdbks3 = rs[0].feedbacks 216 | assert isinstance(fdbks3, list) and len(fdbks3) == 0 217 | # 
assert ( 218 | # 'Argument 1 to "fib" has incompatible type "int"; expected "str"' 219 | # in fdbks3[0].message 220 | # ) 221 | --------------------------------------------------------------------------------
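The classes listed above compose into a simple inference workflow. The following is a minimal usage sketch, not a file from this repository; the Hugging Face checkpoint name and the local data paths are assumptions and may need to be adjusted to your setup:

from pathlib import Path

import torch

from typet5.data import TokenizedSrcSet
from typet5.model import ModelWrapper
from typet5.tokenized_src import PreprocessArgs

# Download and load a pretrained wrapper (weights, tokenizer, args.pkl, common_names.pkl).
wrapper = ModelWrapper.load_from_hub("MrVPlusOne/TypeT5-v7")  # assumed checkpoint name
wrapper.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# Build a TokenizedSrcSet from local repos (placeholder paths), then let eval_on_dataset
# chunk the sources with the model's ctx_args and predict the masked annotations.
src_data = TokenizedSrcSet.from_repos(
    Path("data"), [Path("data/ex_repo")], PreprocessArgs()
)
result = wrapper.eval_on_dataset(src_data)  # DatasetPredResult(chunks, preds)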