├── .env
├── .gitattribute
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE.txt
├── Pipfile
├── Pipfile.lock
├── README.md
├── data
├── 9.11.accs_by_project.csv
├── 9.12.accs_by_project.csv
├── 9.14.accs_by_project.csv
├── TypeT5-Workflow.png
├── code
│ ├── bad_code_1.py
│ ├── bad_code_2.py
│ ├── code_with_slash.py
│ ├── dummy
│ │ ├── __init__.py
│ │ ├── dummy_1.py
│ │ └── dummy_2.py
│ ├── env_code_1.py
│ ├── env_code_2.py
│ └── good_code_1.py
├── ex_repo
│ ├── ex_code_1.py
│ └── ex_code_2.py
├── mypy-dependents-by-stars.json
├── repos_split.pkl
└── useful_repos.pkl
├── requirements.txt
├── scripts
├── analyze_dataset.ipynb
├── analyze_decoding_results.ipynb
├── archive
│ ├── analyze_dagger.ipynb
│ ├── code-t5-workflow.ipynb
│ ├── debug_mypy.ipynb
│ ├── fine_tune_t5.ipynb
│ ├── inference_spot.ipynb
│ ├── kill_dmypy.sh
│ ├── test_inf_env.ipynb
│ ├── train_dagger.py
│ └── train_spot.ipynb
├── collect_dataset.ipynb
├── experiments
│ ├── eval_file_model.ipynb
│ ├── eval_func_model.py
│ ├── run_hityper.ipynb
│ ├── run_type4py.ipynb
│ ├── run_typilus.ipynb
│ └── type_check_decoding.ipynb
├── prepare_dataset.ipynb
├── run_func_decoding.ipynb
├── run_typet5.ipynb
├── scratch.ipynb
└── train_model.py
├── setup.py
├── src
└── typet5
│ ├── __init__.py
│ ├── data.py
│ ├── decode.py
│ ├── experiments
│ ├── __init__.py
│ ├── hityper.py
│ ├── type4py.py
│ ├── typet5.py
│ ├── typilus.py
│ └── utils.py
│ ├── function_dataset.py
│ ├── function_decoding.py
│ ├── model.py
│ ├── static_analysis.py
│ ├── tokenized_src.py
│ ├── train.py
│ ├── type_check.py
│ ├── type_env.py
│ ├── utils.py
│ └── visualization.py
└── tests
├── __init__.py
├── test_func_decoding.py
├── test_model_creation.py
├── test_static_analysis.py
└── test_type_env.py
/.env:
--------------------------------------------------------------------------------
1 | LIBCST_PARSER_TYPE=native
2 |
--------------------------------------------------------------------------------
/.gitattribute:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-vendored
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | config
2 | *.egg-info
3 | code_output
4 | temp
5 | wandb
6 | checkpoints
7 | caches
8 | lightning_logs/
9 | output/
10 |
11 | build/
12 | __pycache__
13 | *.py[cod]
14 | *~
15 | /build
16 | /env*/
17 | docs/build/
18 | docs/source/_build
19 | mypyc/doc/_build
20 | *.iml
21 | /out/
22 | .venv
23 | venv/
24 | mypy_temp
25 | .mypy_cache/
26 | .incremental_checker_cache.json
27 | .cache
28 | dmypy.json
29 | .dmypy.json
30 | .coeditor_logs
31 |
32 | # Packages
33 | *.egg
34 | *.egg-info
35 | *.eggs
36 |
37 | # IDEs
38 | .idea
39 | .vscode
40 |
41 | # vim temporary files
42 | .*.sw?
43 | *.sw?
44 |
45 | # Operating Systems
46 | .DS_Store
47 |
48 | # Coverage Files
49 | htmlcov
50 | .coverage*
51 |
52 | # pytest cache
53 | .pytest_cache/
54 |
55 | # virtualenv
56 | .Python
57 | bin/
58 | lib/
59 | include/
60 | .python-version
61 | pyvenv.cfg
62 |
63 | .tox
64 | pip-wheel-metadata
65 |
66 |
67 | test_capi
68 | *.o
69 | *.a
70 | test_capi
71 | /.mypyc-flake8-cache.json
72 | /mypyc/lib-rt/build/
73 | /mypyc/lib-rt/*.so
74 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | exclude: 'tests/testcases'
2 |
3 | default_language_version:
4 | python: python3.11
5 |
6 | repos:
7 | - repo: https://github.com/pre-commit/pre-commit-hooks
8 | rev: v3.2.0
9 | hooks:
10 | - id: trailing-whitespace
11 | - id: end-of-file-fixer
12 | - id: check-yaml
13 | - id: check-added-large-files
14 |
15 | - repo: https://github.com/pycqa/isort
16 | rev: 5.11.4
17 | hooks:
18 | - id: isort
19 | args: ["--profile", "black", "--filter-files"]
20 |
21 | - repo: https://github.com/psf/black
22 | rev: 22.12.0
23 | hooks:
24 | - id: black
25 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2023, Jiayi Wei
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/Pipfile:
--------------------------------------------------------------------------------
1 | [[source]]
2 | url = "https://pypi.org/simple"
3 | verify_ssl = true
4 | name = "pypi"
5 |
6 | [packages]
7 | mypy = "~=0.971"
8 | typet5 = {editable = true, path = "."}
9 | tqdm = "*"
10 | types-all = "*"
11 | dateparser = "~=1.1.1"
12 | pyrsistent = "*"
13 | plotly = "~=5.10"
14 | pandas = "~=1.4"
15 | ipywidgets = "~=8.0"
16 | datasets = "~=2.4.0"
17 | transformers = "~=4.21.3"
18 | wandb = "~=0.13.2"
19 | libcst = "<=0.4.2"
20 | pytorch-lightning = "~=1.7.5"
21 | colored = "~=1.4"
22 | ipykernel = "*"
23 | termcolor = "~=1.0"
24 | prettytable = "~=3.4.1"
25 | huggingface-hub = "*"
26 |
27 | [dev-packages]
28 | pytest = "*"
29 | black = "*"
30 | exceptiongroup = "*"
31 |
32 | [requires]
33 | python_version = "3.10"
34 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TypeT5: Seq2seq Type Inference using Static Analysis
2 |
3 |
4 |
5 | This repo contains the source code for the paper [TypeT5: Seq2seq Type Inference using Static Analysis](https://openreview.net/forum?id=4TyNEhI2GdN&noteId=EX_-kP9xah).
6 |
7 | ```
8 | @inproceedings{Wei2023TypeT5,
9 | title={TypeT5: Seq2seq Type Inference using Static Analysis},
10 | author={Jiayi Wei and Greg Durrett and Isil Dillig},
11 | booktitle={International Conference on Learning Representations},
12 | year={2023},
13 | url={https://openreview.net/forum?id=4TyNEhI2GdN}
14 | }
15 | ```
16 |
17 | ## Installation
18 |
19 | This project uses [pipenv](https://pipenv.pypa.io/en/latest/) to manage the package dependencies. Pipenv tracks the exact package versions and manages the (project-specific) virtual environment for you. To install all dependencies, make sure you have pipenv and Python 3.10 installed, then, at the project root, run the following two commands:
20 | ```bash
21 | pipenv --python 3.10 # create a new environment for this project
22 | pipenv sync --dev # install all specified dependencies
23 | ```
24 |
25 | More about pipenv:
26 | - To add new dependencies into the virtual environment, you can either add them via `pipenv install ..` (using `pipenv`) or `pipenv run pip install ..` (using `pip` from within the virtual environment).
27 | - If your pytorch installation is not working properly, you might need to reinstall it via the `pipenv run pip install` approach rather than `pipenv install`.
28 | - All `.py` scripts below can be run via `pipenv run python <script path>`. For `.ipynb` notebooks, make sure you select the pipenv environment as the kernel. You can run all unit tests by running `pipenv run pytest` at the project root.
29 |
30 | If you are not using pipenv:
31 | - Make sure to add the environment variables in the [.env](.env) file to your shell environment when you run the scripts (needed by the parsing library).
32 | - We also provided a [requirements.txt](requirements.txt) file for you to install the dependencies via `pip install -r requirements.txt`.
33 |
34 |
35 | ## Using the trained model
36 | The notebook [scripts/run_typet5.ipynb](scripts/run_typet5.ipynb) shows you how to download the TypeT5 model from Huggingface and then use it to make type predictions for a specified codebase.
37 |
38 | ## Training a New Model
39 |
40 | - First, run the notebook [scripts/collect_dataset.ipynb](scripts/collect_dataset.ipynb) to download and split the BetterTypes4Py dataset used in our paper.
41 | - The exact list of repos we used for the experiments in paper can be loaded from `data/repos_split.pkl` using `pickle.load`. They can also be downloaded via this [Google Drive link](https://drive.google.com/drive/folders/1lXKtwi7AOI-w4ESgMi7J5YAHRGP-JhG5?usp=sharing).
42 | - Then, run [scripts/train_model.py](scripts/train_model.py) to train a new TypeT5 model. Training takes about 11 hours on a single Quadro RTX 8000 GPU with 48GB memory.
43 |
44 |
45 | ## Development
46 | - Formatter: We use `black` for formatting with the default options.
47 | - Type Checker: We use Pylance to type check this codebase. It's the built-in type checker shipped with the VSCode Python extension and can be enabled by setting `Python > Analysis > Type Checking Mode` to `basic`.
48 |
--------------------------------------------------------------------------------
/data/9.11.accs_by_project.csv:
--------------------------------------------------------------------------------
1 | ,project,no-neighbors,non-incr,double-traversal,label_rate,labels
2 | 0,srittau__FakeSMTPd,0.8362068965517241,0.8793103448275862,0.896551724137931,0.3483483483483483,116
3 | 1,scalableminds__webknossos-connect,0.6757188498402555,0.7507987220447284,0.7763578274760383,0.830238726790451,626
4 | 2,flopp__unicode-explorer,0.8205128205128205,0.8547008547008547,0.9230769230769231,0.5763546798029556,117
5 | 3,eirannejad__calcatime,0.46875,0.53125,0.5,0.8,32
6 | 4,jelford__webwatcher,0.7659574468085106,0.7659574468085106,0.7659574468085106,0.2186046511627907,47
7 | 5,road-master__video-archiver,0.8536585365853658,0.9105691056910569,0.9186991869918699,0.41694915254237286,123
8 | 6,flopp__GpxTrackPoster,0.5498154981549815,0.6568265682656826,0.7158671586715867,0.6545893719806763,271
9 | 7,boompig__book-classics,0.9058823529411765,0.9411764705882353,0.9411764705882353,0.5862068965517241,85
10 | 8,dropbox__sqlalchemy-stubs,0.5952380952380952,0.6666666666666666,0.6428571428571429,0.6086956521739131,42
11 | 9,reddit__baseplate.py-upgrader,0.7548387096774194,0.7612903225806451,0.8129032258064516,0.4305555555555556,155
12 | 10,typeddjango__pytest-mypy-plugins,0.7565217391304347,0.782608695652174,0.782608695652174,0.732484076433121,115
13 | 11,linw1995__data_extractor,0.6,0.676923076923077,0.6538461538461539,0.23090586145648312,130
14 | 12,albertyw__albertyw.com,0.5285714285714286,0.6285714285714286,0.6428571428571429,0.37433155080213903,70
15 | 13,marcosschroh__dataclasses-avroschema,0.5056947608200456,0.5034168564920274,0.4760820045558087,0.42787524366471735,439
16 | 14,antonagestam__collectfast,0.6896551724137931,0.7155172413793104,0.7586206896551724,0.5155555555555555,116
17 | 15,ammarnajjar__pautomate,0.8518518518518519,0.8703703703703703,0.9259259259259259,0.5046728971962616,54
18 | 16,brettkromkamp__topic-db,0.8476190476190476,0.8761904761904762,0.9238095238095239,0.3633217993079585,105
19 | 17,basilisp-lang__basilisp,0.5882352941176471,0.6233979625369701,0.6194544857048965,0.5269264069264069,3043
20 | 18,paulcwatts__drf-json-schema,0.6590909090909091,0.6439393939393939,0.6931818181818182,0.528,264
21 | 19,ShadowTemplate__beautiful-python-3,0.8991596638655462,0.957983193277311,0.9411764705882353,0.3287292817679558,119
22 | 20,AxelVoitier__lookups,0.3424657534246575,0.4155251141552511,0.3789954337899543,0.33435114503816793,219
23 | 21,Gerschtli__teamspeak-update-notifier,0.6,0.6666666666666666,0.6857142857142857,0.5497382198952879,105
24 | 22,joshtemple__lkml,0.7317073170731707,0.7658536585365854,0.7707317073170732,0.39575289575289574,205
25 | 23,kitsuyui__bamboo-crawler,0.6331360946745562,0.6449704142011834,0.6331360946745562,0.6450381679389313,169
26 | 24,amplify-education__python-hcl2,0.6301369863013698,0.6438356164383562,0.821917808219178,0.73,73
27 | 25,nubark__instark,0.8165374677002584,0.8578811369509044,0.8578811369509044,0.3467741935483871,387
28 | 26,albertyw__git-browse,0.5944444444444444,0.7166666666666667,0.7277777777777777,0.46272493573264784,180
29 | 27,kornicameister__axion,0.4909596662030598,0.5090403337969402,0.45479833101529904,0.6301489921121823,719
30 | 28,payscale__fables,0.7338709677419355,0.7903225806451613,0.75,0.3668639053254438,124
31 | 29,nabla-c0d3__sslyze,0.7120085015940489,0.7523910733262487,0.7630180658873539,0.5689238210399032,941
32 | 30,rakitaj__daily-programmer,0.737037037037037,0.7425925925925926,0.7666666666666667,0.5009276437847866,540
33 | 31,silasary__discord_connotations,0.6153846153846154,0.6923076923076923,0.6923076923076923,0.2708333333333333,13
34 | 32,JakobGM__quelf,0.5434782608695652,0.5652173913043478,0.5543478260869565,0.5287356321839081,92
35 | 33,ocf__slackbridge,0.6935483870967742,0.7096774193548387,0.7741935483870968,0.5767441860465117,124
36 | 34,sonic182__aiosonic,0.8393782383419689,0.8393782383419689,0.7927461139896373,0.291981845688351,193
37 | 35,lucaswerkmeister__tool-quickcategories,0.7580645161290323,0.7870967741935484,0.7645161290322581,0.6581740976645435,310
38 | 36,ohjames__babies,0.8507462686567164,0.8805970149253731,0.8805970149253731,0.32524271844660196,67
39 | 37,ClearcodeHQ__mirakuru,0.6907216494845361,0.7422680412371134,0.7731958762886598,0.383399209486166,97
40 | 38,futursolo__magichttp,0.509009009009009,0.6261261261261262,0.6981981981981982,0.35406698564593303,222
41 | 39,ActivityWatch__aw-research,0.8205128205128205,0.8376068376068376,0.8205128205128205,0.3667711598746082,117
42 | 40,lebrice__blurred-GAN,0.7012987012987013,0.6233766233766234,0.6883116883116883,0.18421052631578946,77
43 | 41,webrecorder__browsertrix,0.61328125,0.6875,0.6875,0.30732292917166865,256
44 | 42,seattleflu__id3c,0.7381889763779528,0.8326771653543307,0.8523622047244095,0.4425087108013937,508
45 | 43,TomerFi__aioswitcher,0.7263157894736842,0.8105263157894737,0.8947368421052632,0.336283185840708,190
46 | 44,jfly__jfly.github.io,0.696969696969697,0.8181818181818182,0.9696969696969697,0.23741007194244604,33
47 | 45,jreese__aql,0.5487179487179488,0.5743589743589743,0.5641025641025641,0.437219730941704,195
48 | 46,yeraydiazdiaz__wait_for_it.py,1.0,1.0,1.0,0.375,6
49 | 47,cliffxuan__mew,0.7638888888888888,0.8333333333333334,0.7777777777777778,0.36923076923076925,72
50 | 48,everyclass__everyclass-server,0.7237237237237237,0.7897897897897898,0.7477477477477478,0.3729003359462486,333
51 | 49,knowark__estimark,0.8630705394190872,0.8423236514522822,0.8464730290456431,0.33565459610027853,241
52 |
--------------------------------------------------------------------------------
/data/9.12.accs_by_project.csv:
--------------------------------------------------------------------------------
1 | ,project,no-neighbors,non-incr,double-traversal,label_size,label_rate,labels
2 | 17,basilisp-lang__basilisp,0.5882352941176471,0.6233979625369701,0.6194544857048965,1.2431810713112061,0.5269264069264069,3043
3 | 29,nabla-c0d3__sslyze,0.7120085015940489,0.7523910733262487,0.7630180658873539,1.2699256110520722,0.5689238210399032,941
4 | 27,kornicameister__axion,0.4909596662030598,0.5090403337969402,0.45479833101529904,1.6244784422809457,0.6301489921121823,719
5 | 1,scalableminds__webknossos-connect,0.6757188498402555,0.7507987220447284,0.7763578274760383,1.20926517571885,0.830238726790451,626
6 | 30,rakitaj__daily-programmer,0.737037037037037,0.7425925925925926,0.7666666666666667,1.5277777777777777,0.5009276437847866,540
7 | 42,seattleflu__id3c,0.7381889763779528,0.8326771653543307,0.8523622047244095,1.2066929133858268,0.4425087108013937,508
8 | 13,marcosschroh__dataclasses-avroschema,0.5056947608200456,0.5034168564920274,0.4760820045558087,1.4738041002277904,0.42787524366471735,439
9 | 25,nubark__instark,0.8165374677002584,0.8578811369509044,0.8578811369509044,1.1808785529715762,0.3467741935483871,387
10 | 48,everyclass__everyclass-server,0.7237237237237237,0.7897897897897898,0.7477477477477478,1.2012012012012012,0.3729003359462486,333
11 | 35,lucaswerkmeister__tool-quickcategories,0.7580645161290323,0.7870967741935484,0.7645161290322581,1.4,0.6581740976645435,310
12 | 6,flopp__GpxTrackPoster,0.5498154981549815,0.6568265682656826,0.7158671586715867,1.2472324723247232,0.6545893719806763,271
13 | 18,paulcwatts__drf-json-schema,0.6590909090909091,0.6439393939393939,0.6931818181818182,1.4090909090909092,0.528,264
14 | 41,webrecorder__browsertrix,0.61328125,0.6875,0.6875,1.17578125,0.30732292917166865,256
15 | 49,knowark__estimark,0.8630705394190872,0.8423236514522822,0.8464730290456431,1.3278008298755186,0.33565459610027853,241
16 | 38,futursolo__magichttp,0.509009009009009,0.6261261261261262,0.6981981981981982,1.1486486486486487,0.35406698564593303,222
17 | 20,AxelVoitier__lookups,0.3424657534246575,0.4155251141552511,0.3789954337899543,1.5936073059360731,0.33435114503816793,219
18 | 22,joshtemple__lkml,0.7317073170731707,0.7658536585365854,0.7707317073170732,1.4341463414634146,0.39575289575289574,205
19 | 45,jreese__aql,0.5487179487179488,0.5743589743589743,0.5641025641025641,1.4564102564102563,0.437219730941704,195
20 | 34,sonic182__aiosonic,0.8393782383419689,0.8393782383419689,0.7927461139896373,1.1347150259067358,0.291981845688351,193
21 | 43,TomerFi__aioswitcher,0.7263157894736842,0.8105263157894737,0.8947368421052632,1.1421052631578947,0.336283185840708,190
22 | 26,albertyw__git-browse,0.5944444444444444,0.7166666666666667,0.7277777777777777,1.0777777777777777,0.46272493573264784,180
23 | 23,kitsuyui__bamboo-crawler,0.6331360946745562,0.6449704142011834,0.6331360946745562,1.7514792899408285,0.6450381679389313,169
24 | 9,reddit__baseplate.py-upgrader,0.7548387096774194,0.7612903225806451,0.8129032258064516,1.264516129032258,0.4305555555555556,155
25 | 11,linw1995__data_extractor,0.6,0.676923076923077,0.6538461538461539,1.4230769230769231,0.23090586145648312,130
26 | 33,ocf__slackbridge,0.6935483870967742,0.7096774193548387,0.7741935483870968,1.3709677419354838,0.5767441860465117,124
27 | 28,payscale__fables,0.7338709677419355,0.7903225806451613,0.75,1.6774193548387097,0.3668639053254438,124
28 | 5,road-master__video-archiver,0.8536585365853658,0.9105691056910569,0.9186991869918699,1.1626016260162602,0.41694915254237286,123
29 | 19,ShadowTemplate__beautiful-python-3,0.8991596638655462,0.957983193277311,0.9411764705882353,1.1428571428571428,0.3287292817679558,119
30 | 39,ActivityWatch__aw-research,0.8205128205128205,0.8376068376068376,0.8205128205128205,1.6752136752136753,0.3667711598746082,117
31 | 2,flopp__unicode-explorer,0.8205128205128205,0.8547008547008547,0.9230769230769231,1.1452991452991452,0.5763546798029556,117
32 | 14,antonagestam__collectfast,0.6896551724137931,0.7155172413793104,0.7586206896551724,1.2586206896551724,0.5155555555555555,116
33 | 0,srittau__FakeSMTPd,0.8362068965517241,0.8793103448275862,0.896551724137931,1.2327586206896552,0.3483483483483483,116
34 | 10,typeddjango__pytest-mypy-plugins,0.7565217391304347,0.782608695652174,0.782608695652174,1.6956521739130435,0.732484076433121,115
35 | 21,Gerschtli__teamspeak-update-notifier,0.6,0.6666666666666666,0.6857142857142857,1.161904761904762,0.5497382198952879,105
36 | 16,brettkromkamp__topic-db,0.8476190476190476,0.8761904761904762,0.9238095238095239,1.0952380952380953,0.3633217993079585,105
37 | 37,ClearcodeHQ__mirakuru,0.6907216494845361,0.7422680412371134,0.7731958762886598,1.8556701030927836,0.383399209486166,97
38 | 32,JakobGM__quelf,0.5434782608695652,0.5652173913043478,0.5543478260869565,1.184782608695652,0.5287356321839081,92
39 | 7,boompig__book-classics,0.9058823529411765,0.9411764705882353,0.9411764705882353,1.2705882352941176,0.5862068965517241,85
40 | 40,lebrice__blurred-GAN,0.7012987012987013,0.6233766233766234,0.6883116883116883,1.077922077922078,0.18421052631578946,77
41 | 24,amplify-education__python-hcl2,0.6301369863013698,0.6438356164383562,0.821917808219178,1.0273972602739727,0.73,73
42 | 47,cliffxuan__mew,0.7638888888888888,0.8333333333333334,0.7777777777777778,1.0277777777777777,0.36923076923076925,72
43 | 12,albertyw__albertyw.com,0.5285714285714286,0.6285714285714286,0.6428571428571429,1.6,0.37433155080213903,70
44 | 36,ohjames__babies,0.8507462686567164,0.8805970149253731,0.8805970149253731,1.3582089552238805,0.32524271844660196,67
45 | 15,ammarnajjar__pautomate,0.8518518518518519,0.8703703703703703,0.9259259259259259,1.3888888888888888,0.5046728971962616,54
46 | 4,jelford__webwatcher,0.7659574468085106,0.7659574468085106,0.7659574468085106,1.3191489361702127,0.2186046511627907,47
47 | 8,dropbox__sqlalchemy-stubs,0.5952380952380952,0.6666666666666666,0.6428571428571429,1.2619047619047619,0.6086956521739131,42
48 | 44,jfly__jfly.github.io,0.696969696969697,0.8181818181818182,0.9696969696969697,1.1515151515151516,0.23741007194244604,33
49 | 3,eirannejad__calcatime,0.46875,0.53125,0.5,1.9375,0.8,32
50 | 31,silasary__discord_connotations,0.6153846153846154,0.6923076923076923,0.6923076923076923,1.4615384615384615,0.2708333333333333,13
51 | 46,yeraydiazdiaz__wait_for_it.py,1.0,1.0,1.0,1.3333333333333333,0.375,6
52 |
--------------------------------------------------------------------------------
/data/9.14.accs_by_project.csv:
--------------------------------------------------------------------------------
1 | ,project,non-incr,random-twice,double-traversal,label_size,label_rate,labels
2 | 17,basilisp-lang__basilisp,0.6087098886705959,0.49279633267845446,0.6175507531106745,1.2426326129666012,0.5258264462809917,3054
3 | 29,nabla-c0d3__sslyze,0.7771911298838438,0.7613516367476241,0.7697993664202746,1.2692713833157339,0.5660490137477585,947
4 | 27,kornicameister__axion,0.4884726224783862,0.4654178674351585,0.4812680115273775,1.644092219020173,0.6257889990982868,694
5 | 1,scalableminds__webknossos-connect,0.7231467473524962,0.7367624810892587,0.7609682299546142,1.2420574886535551,0.8304020100502513,661
6 | 35,lucaswerkmeister__tool-quickcategories,0.7504553734061931,0.7723132969034608,0.7905282331511839,1.5100182149362478,0.6052921719955898,549
7 | 42,seattleflu__id3c,0.7093235831809872,0.6983546617915904,0.716636197440585,1.2084095063985374,0.4494658997534922,547
8 | 30,rakitaj__daily-programmer,0.7388888888888889,0.75,0.7666666666666667,1.5277777777777777,0.5009276437847866,540
9 | 25,nubark__instark,0.8877284595300261,0.8903394255874674,0.8772845953002611,1.1723237597911227,0.3647619047619048,383
10 | 16,brettkromkamp__topic-db,0.7261904761904762,0.6934523809523809,0.7232142857142857,1.1458333333333333,0.5989304812834224,336
11 | 48,everyclass__everyclass-server,0.7777777777777778,0.8198198198198198,0.7837837837837838,1.2012012012012012,0.3729003359462486,333
12 | 6,flopp__GpxTrackPoster,0.6863468634686347,0.7011070110701108,0.6826568265682657,1.2472324723247232,0.6545893719806763,271
13 | 13,marcosschroh__dataclasses-avroschema,0.5576923076923077,0.5423076923076923,0.5346153846153846,1.353846153846154,0.29312288613303267,260
14 | 41,webrecorder__browsertrix,0.73046875,0.71484375,0.703125,1.17578125,0.30732292917166865,256
15 | 18,paulcwatts__drf-json-schema,0.6245059288537549,0.6600790513833992,0.6561264822134387,1.4031620553359683,0.5337552742616034,253
16 | 49,knowark__estimark,0.8672199170124482,0.8962655601659751,0.8630705394190872,1.3278008298755186,0.33565459610027853,241
17 | 38,futursolo__magichttp,0.581081081081081,0.6216216216216216,0.6756756756756757,1.1486486486486487,0.35406698564593303,222
18 | 20,AxelVoitier__lookups,0.4155251141552511,0.4794520547945205,0.4474885844748858,1.5936073059360731,0.337442218798151,219
19 | 22,joshtemple__lkml,0.751219512195122,0.8341463414634146,0.8146341463414634,1.4341463414634146,0.39575289575289574,205
20 | 45,jreese__aql,0.645320197044335,0.7044334975369458,0.7044334975369458,1.438423645320197,0.44713656387665196,203
21 | 34,sonic182__aiosonic,0.8031088082901554,0.8238341968911918,0.7305699481865285,1.1347150259067358,0.28550295857988167,193
22 | 43,TomerFi__aioswitcher,0.7684210526315789,0.8421052631578947,0.8789473684210526,1.1421052631578947,0.336283185840708,190
23 | 26,albertyw__git-browse,0.7555555555555555,0.7222222222222222,0.7777777777777778,1.0777777777777777,0.46272493573264784,180
24 | 23,kitsuyui__bamboo-crawler,0.6568047337278107,0.6804733727810651,0.6804733727810651,1.7514792899408285,0.6450381679389313,169
25 | 9,reddit__baseplate.py-upgrader,0.8292682926829268,0.8292682926829268,0.8475609756097561,1.25,0.4385026737967914,164
26 | 11,linw1995__data_extractor,0.6538461538461539,0.6846153846153846,0.6692307692307692,1.4230769230769231,0.2480916030534351,130
27 | 33,ocf__slackbridge,0.7419354838709677,0.75,0.7983870967741935,1.3709677419354838,0.5767441860465117,124
28 | 28,payscale__fables,0.7903225806451613,0.7661290322580645,0.7983870967741935,1.6774193548387097,0.3668639053254438,124
29 | 5,road-master__video-archiver,0.8617886178861789,0.8617886178861789,0.8699186991869918,1.1626016260162602,0.41694915254237286,123
30 | 19,ShadowTemplate__beautiful-python-3,0.9159663865546218,0.9159663865546218,0.9159663865546218,1.1428571428571428,0.3287292817679558,119
31 | 2,flopp__unicode-explorer,0.7777777777777778,0.8632478632478633,0.8632478632478633,1.1452991452991452,0.5763546798029556,117
32 | 39,ActivityWatch__aw-research,0.8290598290598291,0.8376068376068376,0.8547008547008547,1.6752136752136753,0.3667711598746082,117
33 | 14,antonagestam__collectfast,0.7327586206896551,0.7586206896551724,0.7586206896551724,1.2586206896551724,0.5155555555555555,116
34 | 0,srittau__FakeSMTPd,0.8706896551724138,0.9224137931034483,0.9310344827586207,1.2327586206896552,0.3483483483483483,116
35 | 10,typeddjango__pytest-mypy-plugins,0.7217391304347827,0.782608695652174,0.782608695652174,1.6956521739130435,0.732484076433121,115
36 | 21,Gerschtli__teamspeak-update-notifier,0.7619047619047619,0.7714285714285715,0.8571428571428571,1.161904761904762,0.5497382198952879,105
37 | 37,ClearcodeHQ__mirakuru,0.7731958762886598,0.8041237113402062,0.7938144329896907,1.8556701030927836,0.383399209486166,97
38 | 32,JakobGM__quelf,0.45652173913043476,0.45652173913043476,0.4673913043478261,1.184782608695652,0.5287356321839081,92
39 | 7,boompig__book-classics,0.9294117647058824,0.9294117647058824,0.9294117647058824,1.2705882352941176,0.5862068965517241,85
40 | 40,lebrice__blurred-GAN,0.7380952380952381,0.6904761904761905,0.6785714285714286,1.0714285714285714,0.18876404494382024,84
41 | 24,amplify-education__python-hcl2,0.7397260273972602,0.7808219178082192,0.821917808219178,1.0273972602739727,0.73,73
42 | 47,cliffxuan__mew,0.704225352112676,0.5915492957746479,0.5915492957746479,1.028169014084507,0.36597938144329895,71
43 | 12,albertyw__albertyw.com,0.6714285714285714,0.6571428571428571,0.6571428571428571,1.6,0.37433155080213903,70
44 | 36,ohjames__babies,0.8507462686567164,0.8208955223880597,0.8208955223880597,1.3582089552238805,0.32524271844660196,67
45 | 15,ammarnajjar__pautomate,0.9074074074074074,0.9259259259259259,0.9259259259259259,1.3888888888888888,0.5046728971962616,54
46 | 4,jelford__webwatcher,0.7659574468085106,0.7446808510638298,0.7446808510638298,1.3191489361702127,0.2186046511627907,47
47 | 8,dropbox__sqlalchemy-stubs,0.6904761904761905,0.6904761904761905,0.6904761904761905,1.2619047619047619,0.6086956521739131,42
48 | 44,jfly__jfly.github.io,0.7878787878787878,0.8181818181818182,0.8181818181818182,1.1515151515151516,0.23741007194244604,33
49 | 3,eirannejad__calcatime,0.90625,0.90625,0.90625,1.9375,0.8,32
50 | 31,silasary__discord_connotations,0.6666666666666666,0.6666666666666666,0.6666666666666666,1.4,0.28846153846153844,15
51 | 46,yeraydiazdiaz__wait_for_it.py,1.0,1.0,1.0,1.3333333333333333,0.375,6
52 |
--------------------------------------------------------------------------------
/data/TypeT5-Workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utopia-group/TypeT5/d8ff8638f4d00f03042db5780a8d4fa09a72916d/data/TypeT5-Workflow.png
--------------------------------------------------------------------------------
/data/code/bad_code_1.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 |
4 | # A recursive fibonacci function
5 | def fib(n: str) -> list[int]:
6 | if n == 0:
7 | return 0
8 | elif n == 1:
9 | return 1
10 | else:
11 | return fib(n - 1) + fib(n - 2)
12 |
13 |
14 | def t_add(x: str, y: str) -> int:
15 | r = x + y
16 | return r
17 |
18 |
19 | x: int = fib(3)
20 | bad_y: str = 1
21 |
--------------------------------------------------------------------------------
/data/code/bad_code_2.py:
--------------------------------------------------------------------------------
1 | from bad_code_1 import fib
2 |
3 | i: int = 4
4 | fib(i)
5 |
--------------------------------------------------------------------------------
/data/code/code_with_slash.py:
--------------------------------------------------------------------------------
1 | class SlashClass:
2 | def __init__(self, check_interval: int, folder: Path, /) -> None:
3 | self._autolocked: Dict[Path, int] = {}
4 | self._lockers: Dict[Path, "DirectEdit"] = {}
5 | self._to_lock: Items = []
6 |
--------------------------------------------------------------------------------
/data/code/dummy/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utopia-group/TypeT5/d8ff8638f4d00f03042db5780a8d4fa09a72916d/data/code/dummy/__init__.py
--------------------------------------------------------------------------------
/data/code/dummy/dummy_1.py:
--------------------------------------------------------------------------------
1 | def f_int(x: int) -> int:
2 | return x
3 |
--------------------------------------------------------------------------------
/data/code/dummy/dummy_2.py:
--------------------------------------------------------------------------------
1 | from dummy.dummy_1 import f_int
2 |
3 | s: str = f_int(2)
4 |
--------------------------------------------------------------------------------
/data/code/env_code_1.py:
--------------------------------------------------------------------------------
1 | # Env example 1: no existing annotations
2 |
3 |
4 | def fib(n):
5 | if n == 0:
6 | return 0
7 | elif n == 1:
8 | return 1
9 | else:
10 | return fib(n - 1) + fib(n - 2)
11 |
12 |
13 | def foo(bar):
14 | return fib(bar)
15 |
16 |
17 | def int_add(a, b):
18 | return a + b + "c"
19 |
20 |
21 | def int_tripple_add(a, b, c):
22 | return a + b + c
23 |
--------------------------------------------------------------------------------
/data/code/env_code_2.py:
--------------------------------------------------------------------------------
1 | # Env example 2: some existing annotations
2 |
3 | from typing import *
4 |
5 |
6 | def fib(n: int):
7 | if n == 0:
8 | return 0
9 | elif n == 1:
10 | return 1
11 | else:
12 | return fib(n - 1) + fib(n - 2)
13 |
14 |
15 | def foo(bar: int):
16 | return fib(bar)
17 |
18 |
19 | class Bar:
20 | z: str = "hello"
21 | w: str
22 |
23 | def __init__(self, x: int):
24 | self.x: int = x
25 | self.y: Optional[int] = None
26 | self.reset(self.z)
27 |
28 | def reset(self, w0):
29 | self.w = w0
30 |
31 | def foo(self, z: str) -> int:
32 | return self.x + len(z)
33 |
34 |
35 | bar: Bar = Bar(3)
36 |
--------------------------------------------------------------------------------
/data/code/good_code_1.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Any # [added by SPOT]
3 | from typing import Optional
4 |
5 | print(math.sin(4))
6 |
7 | x_str: str = "x"
8 | y: Any = 1
9 | z_str: str = x_str + y
10 |
11 |
12 | class Foo:
13 | def __init__(self, x: int):
14 | self.x: int = x
15 | self.y: Optional[int] = None
16 | self.z = "hello"
17 |
18 | def foo(self, z: str) -> int:
19 | return self.x + len(z)
20 |
--------------------------------------------------------------------------------
/data/ex_repo/ex_code_1.py:
--------------------------------------------------------------------------------
1 | # Env example 1: no existing annotations
2 |
3 | good = 5
4 |
5 |
6 | def fib(n):
7 | if n == 0:
8 | return 0
9 | elif n == 1:
10 | return 1
11 | else:
12 | return fib(n - 1) + fib(n - 2)
13 |
14 |
15 | class Wrapper:
16 | x_elem: int
17 | y: str
18 |
19 | @staticmethod
20 | def foo(bar):
21 | return fib(bar)
22 |
23 | def inc(self):
24 | self.x_elem += 1
25 | return self.y
26 |
27 |
28 | def int_add(a, b):
29 | # this is a strange function
30 | return a + b + "c"
31 |
32 |
33 | def int_tripple_add(a, b, c):
34 | return a + b + c
35 |
--------------------------------------------------------------------------------
/data/ex_repo/ex_code_2.py:
--------------------------------------------------------------------------------
1 | # Env example 2: some existing annotations
2 |
3 | from typing import *
4 | from ex_code_1 import int_add
5 |
6 |
7 | def fib(n: int):
8 | if n == 0:
9 | return 0
10 | elif n == 1:
11 | return 1
12 | else:
13 | return fib(n - 1) + fib(n - 2)
14 |
15 |
16 | def foo(bar: int):
17 | return fib(bar)
18 |
19 |
20 | class Bar:
21 | z: str = "hello"
22 | w: str
23 |
24 | def __init__(self, x: int):
25 | self.x: int = x
26 | self.y: Optional[int] = None
27 | self.reset(self.z)
28 |
29 | def reset(self, w0):
30 | self.w = w0
31 |
32 | def foo(self, z: str) -> int:
33 | return int_add(self.x, len(z))
34 |
35 |
36 | bar: Bar = Bar(3)
37 |
--------------------------------------------------------------------------------
/data/repos_split.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utopia-group/TypeT5/d8ff8638f4d00f03042db5780a8d4fa09a72916d/data/repos_split.pkl
--------------------------------------------------------------------------------
/data/useful_repos.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utopia-group/TypeT5/d8ff8638f4d00f03042db5780a8d4fa09a72916d/data/useful_repos.pkl
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -i https://pypi.org/simple
2 | absl-py==1.4.0 ; python_version >= '3.6'
3 | aiohttp==3.8.4 ; python_version >= '3.6'
4 | aiosignal==1.3.1 ; python_version >= '3.7'
5 | appdirs==1.4.4
6 | asttokens==2.2.1
7 | async-timeout==4.0.2 ; python_version >= '3.6'
8 | attrs==22.2.0 ; python_version >= '3.6'
9 | backcall==0.2.0
10 | cachetools==5.3.0 ; python_version ~= '3.7'
11 | certifi==2022.12.7 ; python_version >= '3.6'
12 | cffi==1.15.1
13 | charset-normalizer==3.0.1 ; python_version >= '3.6'
14 | click==8.1.3 ; python_version >= '3.7'
15 | colored==1.4.4
16 | comm==0.1.2 ; python_version >= '3.6'
17 | cryptography==39.0.1 ; python_version >= '3.6'
18 | datasets==2.4.0
19 | dateparser==1.1.7
20 | debugpy==1.6.6 ; python_version >= '3.7'
21 | decorator==5.1.1 ; python_version >= '3.5'
22 | dill==0.3.5.1 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6'
23 | docker-pycreds==0.4.0
24 | executing==1.2.0
25 | filelock==3.9.0 ; python_version >= '3.7'
26 | frozenlist==1.3.3 ; python_version >= '3.7'
27 | fsspec[http]==2023.1.0 ; python_version >= '3.7'
28 | gitdb==4.0.10 ; python_version >= '3.7'
29 | gitpython==3.1.31 ; python_version >= '3.7'
30 | google-auth==2.16.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'
31 | google-auth-oauthlib==0.4.6 ; python_version >= '3.6'
32 | grpcio==1.51.1 ; python_version >= '3.7'
33 | huggingface-hub==0.12.0
34 | idna==3.4 ; python_version >= '3.5'
35 | ipykernel==6.21.2
36 | ipython==8.10.0 ; python_version >= '3.8'
37 | ipywidgets==8.0.4
38 | jedi==0.18.2 ; python_version >= '3.6'
39 | jupyter-client==8.0.3 ; python_version >= '3.8'
40 | jupyter-core==5.2.0 ; python_version >= '3.8'
41 | jupyterlab-widgets==3.0.5 ; python_version >= '3.7'
42 | libcst==0.4.2
43 | markdown==3.4.1 ; python_version >= '3.7'
44 | markupsafe==2.1.2 ; python_version >= '3.7'
45 | matplotlib-inline==0.1.6 ; python_version >= '3.5'
46 | multidict==6.0.4 ; python_version >= '3.7'
47 | multiprocess==0.70.13 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6'
48 | mypy==0.991
49 | mypy-extensions==1.0.0 ; python_version >= '3.5'
50 | nest-asyncio==1.5.6 ; python_version >= '3.5'
51 | numpy==1.24.2 ; python_version >= '3.8'
52 | nvidia-cublas-cu11==11.10.3.66 ; platform_system == 'Linux'
53 | nvidia-cuda-nvrtc-cu11==11.7.99 ; platform_system == 'Linux'
54 | nvidia-cuda-runtime-cu11==11.7.99 ; platform_system == 'Linux'
55 | nvidia-cudnn-cu11==8.5.0.96 ; platform_system == 'Linux'
56 | oauthlib==3.2.2 ; python_version >= '3.6'
57 | packaging==23.0 ; python_version >= '3.7'
58 | pandas==1.5.3
59 | parso==0.8.3 ; python_version >= '3.6'
60 | pathtools==0.1.2
61 | pexpect==4.8.0 ; sys_platform != 'win32'
62 | pickleshare==0.7.5
63 | platformdirs==3.0.0 ; python_version >= '3.7'
64 | plotly==5.13.0
65 | prettytable==3.4.1
66 | prompt-toolkit==3.0.36 ; python_full_version >= '3.6.2'
67 | protobuf==4.22.0 ; python_version >= '3.10' and sys_platform == 'linux'
68 | psutil==5.9.4 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
69 | ptyprocess==0.7.0
70 | pure-eval==0.2.2
71 | pyarrow==11.0.0 ; python_version >= '3.7'
72 | pyasn1==0.4.8
73 | pyasn1-modules==0.2.8
74 | pycparser==2.21
75 | pydeprecate==0.3.2 ; python_version >= '3.6'
76 | pygments==2.14.0 ; python_version >= '3.6'
77 | pyrsistent==0.19.3
78 | python-dateutil==2.8.2 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
79 | pytorch-lightning==1.7.7
80 | pytz==2022.7.1
81 | pytz-deprecation-shim==0.1.0.post0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'
82 | pyyaml==6.0 ; python_version >= '3.6'
83 | pyzmq==25.0.0 ; python_version >= '3.6'
84 | regex==2022.10.31 ; python_version >= '3.6'
85 | requests==2.28.2 ; python_version >= '3.7' and python_version < '4'
86 | requests-oauthlib==1.3.1 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
87 | responses==0.18.0 ; python_version >= '3.7'
88 | rsa==4.9 ; python_version >= '3.6'
89 | sentry-sdk==1.15.0
90 | setproctitle==1.3.2 ; python_version >= '3.7'
91 | setuptools==67.3.2 ; python_version >= '3.7'
92 | six==1.16.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
93 | smmap==5.0.0 ; python_version >= '3.6'
94 | stack-data==0.6.2
95 | tenacity==8.2.1 ; python_version >= '3.6'
96 | tensorboard==2.12.0 ; python_version >= '3.8'
97 | tensorboard-data-server==0.7.0 ; python_version >= '3.7'
98 | tensorboard-plugin-wit==1.8.1
99 | termcolor==1.1.0
100 | tokenizers==0.12.1
101 | tomli==2.0.1 ; python_version < '3.11'
102 | torch==1.13.1 ; python_full_version >= '3.7.0'
103 | torchmetrics==0.11.1 ; python_version >= '3.7'
104 | tornado==6.2 ; python_version >= '3.7'
105 | tqdm==4.64.1
106 | traitlets==5.9.0 ; python_version >= '3.7'
107 | transformers==4.21.3
108 | types-aiofiles==22.1.0.8
109 | types-all==1.0.0
110 | types-annoy==1.17.8.1
111 | types-atomicwrites==1.4.5.1
112 | types-backports==0.1.3
113 | types-backports-abc==0.5.2
114 | types-bleach==6.0.0.0
115 | types-boto==2.49.18.5
116 | types-cachetools==5.3.0.0
117 | types-certifi==2021.10.8.3
118 | types-cffi==1.15.1.5
119 | types-characteristic==14.3.7
120 | types-chardet==5.0.4.1
121 | types-click==7.1.8
122 | types-click-spinner==0.1.13.2
123 | types-colorama==0.4.15.7
124 | types-contextvars==2.4.7
125 | types-croniter==1.3.2.4
126 | types-cryptography==3.3.23.2
127 | types-dataclasses==0.6.6
128 | types-dateparser==1.1.4.7
129 | types-datetimerange==2.0.0.1
130 | types-decorator==5.1.8.2
131 | types-deprecated==1.2.9
132 | types-docopt==0.6.11.1
133 | types-docutils==0.19.1.4
134 | types-emoji==2.1.0.1
135 | types-enum34==1.1.8
136 | types-fb303==1.0.0
137 | types-filelock==3.2.7
138 | types-first==2.0.5
139 | types-flask==1.1.6
140 | types-freezegun==1.1.10
141 | types-frozendict==2.0.9
142 | types-futures==3.3.8
143 | types-geoip2==3.0.0
144 | types-ipaddress==1.0.8
145 | types-itsdangerous==1.1.6
146 | types-jack-client==0.5.10.5
147 | types-jinja2==2.11.9
148 | types-kazoo==0.1.3
149 | types-markdown==3.4.2.4
150 | types-markupsafe==1.1.10
151 | types-maxminddb==1.5.0
152 | types-mock==5.0.0.4
153 | types-mypy-extensions==1.0.0.1
154 | types-nmap==0.1.6
155 | types-openssl-python==0.1.3
156 | types-orjson==3.6.2
157 | types-paramiko==3.0.0.3
158 | types-pathlib2==2.3.0
159 | types-pillow==9.4.0.12
160 | types-pkg-resources==0.1.3
161 | types-polib==1.1.12.1
162 | types-protobuf==4.21.0.6
163 | types-pyaudio==0.2.16.5
164 | types-pycurl==7.45.2.2
165 | types-pyfarmhash==0.3.1
166 | types-pyjwt==1.7.1
167 | types-pymssql==2.1.0
168 | types-pymysql==1.0.19.3
169 | types-pyopenssl==23.0.0.3
170 | types-pyrfc3339==1.1.1.2
171 | types-pysftp==0.2.17.1
172 | types-python-dateutil==2.8.19.7
173 | types-python-gflags==3.1.7.1
174 | types-python-slugify==8.0.0.0
175 | types-pytz==2022.7.1.0
176 | types-pyvmomi==8.0.0.0
177 | types-pyyaml==6.0.12.6
178 | types-redis==4.5.1.1
179 | types-requests==2.28.11.13
180 | types-retry==0.9.9.1
181 | types-routes==2.5.0
182 | types-scribe==2.0.0
183 | types-simplejson==3.18.0.0
184 | types-singledispatch==4.0.0.0
185 | types-six==1.16.21.4
186 | types-tabulate==0.9.0.0
187 | types-termcolor==1.1.6
188 | types-toml==0.10.8.4
189 | types-tornado==5.1.1
190 | types-typed-ast==1.5.8.3
191 | types-tzlocal==4.2.2.2
192 | types-ujson==5.7.0.0
193 | types-urllib3==1.26.25.6
194 | types-waitress==2.1.4.4
195 | types-werkzeug==1.0.9
196 | types-xxhash==3.0.5.1
197 | -e .
198 | typing-extensions==4.5.0 ; python_version >= '3.7'
199 | typing-inspect==0.8.0
200 | tzdata==2022.7 ; python_version >= '3.6'
201 | tzlocal==4.2 ; python_version >= '3.6'
202 | urllib3==1.26.14 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'
203 | wandb==0.13.10
204 | wcwidth==0.2.6
205 | werkzeug==2.2.3 ; python_version >= '3.7'
206 | wheel==0.38.4 ; python_version >= '3.7'
207 | widgetsnbextension==4.0.5 ; python_version >= '3.7'
208 | xxhash==3.2.0 ; python_version >= '3.6'
209 | yarl==1.8.2 ; python_version >= '3.7'
210 |
--------------------------------------------------------------------------------
/scripts/archive/analyze_dagger.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "import os\n",
13 | "from typing import *\n",
14 | "\n",
15 | "from typet5.utils import proj_root, get_data_dir\n",
16 | "\n",
17 | "os.chdir(proj_root())\n",
18 | "\n",
19 | "datadir = get_data_dir()"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {},
26 | "outputs": [
27 | {
28 | "name": "stderr",
29 | "output_type": "stream",
30 | "text": [
31 | "/home/jiayi/Projects/SPOT/.venv/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory\n",
32 | " warn(f\"Failed to load image Python extension: {e}\")\n"
33 | ]
34 | },
35 | {
36 | "name": "stdout",
37 | "output_type": "stream",
38 | "text": [
39 | "quicktest=False\n",
40 | "Loading datasets: tk_dataset-all_labels-drop_comments\n"
41 | ]
42 | }
43 | ],
44 | "source": [
45 | "# experiment configurations\n",
46 | "\n",
47 | "from typet5.data import (\n",
48 | " TokenizedSrcSet,\n",
49 | " get_dataset_name,\n",
50 | " load_tokenized_srcsets,\n",
51 | " TypeCheckSettings,\n",
52 | ")\n",
53 | "from typet5.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper\n",
54 | "from typet5.train import TrainingConfig, TypeCheckArgs\n",
55 | "\n",
56 | "config = TrainingConfig(\n",
57 | " quicktest=False,\n",
58 | " all_labels=True,\n",
59 | " ctx_size=2048,\n",
60 | " left_margin=1024,\n",
61 | " right_margin=1023,\n",
62 | " modifications=\"no_type_checker\",\n",
63 | ")\n",
64 | "gpu_id = 1\n",
65 | "TypeCheckSettings.temp_path = f\"DAgger-{gpu_id}\"\n",
66 | "\n",
67 | "print(f\"quicktest={config.quicktest}\")\n",
68 | "\n",
69 | "project_name = \"test-SPOT\" if config.quicktest else \"SPOT\"\n",
70 | "train_ctx_args = config.train_ctx_args()\n",
71 | "tc_args = TypeCheckArgs(check_in_isolation=config.check_in_isolation)\n",
72 | "\n",
73 | "datasets_name = get_dataset_name(\n",
74 | " drop_comments=config.drop_comments,\n",
75 | " all_labels=config.all_labels,\n",
76 | " imports_in_preamble=config.imports_in_preamble,\n",
77 | ")\n",
78 | "\n",
79 | "model_name = \"DAgger-model--\" + config.as_name()\n",
80 | "\n",
81 | "tk_dataset = load_tokenized_srcsets(\n",
82 | " datadir,\n",
83 | " datasets_name,\n",
84 | " data_reduction=config.data_reduction,\n",
85 | " quicktest=config.quicktest,\n",
86 | ")"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 7,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "# load the model\n",
96 | "from typet5.model import load_model_spot, DefaultTokenizer\n",
97 | "from typet5.model import ModelWrapper\n",
98 | "from typet5.dagger import DAggerModel\n",
99 | "import torch\n",
100 | "\n",
101 | "dec_args = DecodingArgs(\n",
102 | " sampling_max_tokens=8 * config.ctx_size,\n",
103 | " ctx_args=config.dec_ctx_args(),\n",
104 | " do_sample=True,\n",
105 | " num_beams=None, # try greedy decoding\n",
106 | " top_p=0.9,\n",
107 | ")\n",
108 | "\n",
109 | "wrapper = ModelWrapper.from_pretrained(datadir / f\"checkpoints/saved/{model_name}\")\n",
110 | "device = torch.device(f\"cuda:{gpu_id}\" if torch.cuda.is_available() else \"cpu\")\n",
111 | "wrapper.to(device)\n",
112 | "dmodel = DAggerModel(wrapper)"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": 5,
118 | "metadata": {},
119 | "outputs": [
120 | {
121 | "name": "stderr",
122 | "output_type": "stream",
123 | "text": [
124 | "compute_preexisting_fdbks: 100%|██████████| 50/50 [00:04<00:00, 11.99it/s]\n",
125 | "eval_on_data: 100%|██████████| 16950/16950 [46:45<00:00, 6.04it/s]\n"
126 | ]
127 | },
128 | {
129 | "name": "stdout",
130 | "output_type": "stream",
131 | "text": [
132 | "partial_acc (ImNone): 69.01% (count=16.9k)\n",
133 | "full_acc (ImNone): 64.86% (count=16.9k)\n",
134 | "partial_acc: 67.43% (count=16.9k)\n",
135 | "ast_acc: 56.81% (count=21.3k)\n",
136 | "full_acc: 62.03% (count=16.9k)\n",
137 | "partial_acc_by_cat:\n",
138 | " FuncArg: 62.96% (count=8.0k)\n",
139 | " FuncReturn: 78.14% (count=5.7k)\n",
140 | " ClassAtribute: 58.21% (count=2.7k)\n",
141 | " GlobalVar: 75.96% (count=104)\n",
142 | " LocalVar: 64.22% (count=531)\n",
143 | "partial_acc_by_pos:\n",
144 | " range(0, 1): 80.39% (count=933)\n",
145 | " range(1, 2): 77.13% (count=870)\n",
146 | " range(2, 4): 77.58% (count=1.5k)\n",
147 | " range(4, 8): 74.05% (count=2.4k)\n",
148 | " range(8, 16): 72.80% (count=3.1k)\n",
149 | " range(16, 32): 67.31% (count=3.2k)\n",
150 | " range(32, 64): 63.86% (count=2.3k)\n",
151 | " range(64, 128): 53.42% (count=1.1k)\n",
152 | " range(128, 256): 40.00% (count=735)\n",
153 | " range(256, 512): 32.89% (count=672)\n",
154 | " range(512, 1024): 52.83% (count=53)\n",
155 | "avg_label_size: 1.2589\n",
156 | "avg_pred_size: 1.1258\n"
157 | ]
158 | }
159 | ],
160 | "source": [
161 | "# evaluate (greedy)\n",
162 | "from typet5.utils import pretty_print_dict, pretty_show_dict\n",
163 | "from typet5.visualization import visualize_preds_on_code\n",
164 | "\n",
165 | "eval_r = await dmodel.eval_on_data(tk_dataset[\"test\"])\n",
166 | "pretty_print_dict(eval_r.accuracies)"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 10,
172 | "metadata": {},
173 | "outputs": [
174 | {
175 | "name": "stderr",
176 | "output_type": "stream",
177 | "text": [
178 | "compute_preexisting_fdbks: 100%|██████████| 6/6 [00:02<00:00, 2.16it/s]\n",
179 | "eval_on_data: 100%|██████████| 180/180 [04:57<00:00, 1.65s/it]\n"
180 | ]
181 | },
182 | {
183 | "name": "stdout",
184 | "output_type": "stream",
185 | "text": [
186 | "partial_acc (ImNone): 63.89% (count=180)\n",
187 | "full_acc (ImNone): 61.11% (count=180)\n",
188 | "partial_acc: 65.00% (count=180)\n",
189 | "ast_acc: 49.59% (count=244)\n",
190 | "full_acc: 60.56% (count=180)\n",
191 | "partial_acc_by_cat:\n",
192 | " FuncArg: 58.42% (count=101)\n",
193 | " FuncReturn: 71.23% (count=73)\n",
194 | " GlobalVar: 100.00% (count=3)\n",
195 | " LocalVar: 100.00% (count=3)\n",
196 | "partial_acc_by_pos:\n",
197 | " range(0, 1): 90.91% (count=11)\n",
198 | " range(1, 2): 100.00% (count=8)\n",
199 | " range(2, 4): 100.00% (count=12)\n",
200 | " range(4, 8): 69.23% (count=13)\n",
201 | " range(8, 16): 54.55% (count=11)\n",
202 | " range(16, 32): 43.75% (count=16)\n",
203 | " range(32, 64): 50.00% (count=32)\n",
204 | " range(64, 128): 60.94% (count=64)\n",
205 | " range(128, 256): 76.92% (count=13)\n",
206 | "avg_label_size: 1.3556\n",
207 | "avg_pred_size: 1.1\n"
208 | ]
209 | }
210 | ],
211 | "source": [
212 | "# evaluate\n",
213 | "from numpy import roll\n",
214 | "from typet5.utils import pretty_print_dict, pretty_show_dict\n",
215 | "from typet5.visualization import visualize_preds_on_code\n",
216 | "\n",
217 | "dmodel.wrapper.args = DecodingArgs(\n",
218 | " sampling_max_tokens=8 * config.ctx_size,\n",
219 | " ctx_args=config.dec_ctx_args(),\n",
220 | "    do_sample=True,  # use nucleus sampling during training\n",
221 | " top_p=0.9,\n",
222 | ")\n",
223 | "\n",
224 | "eval_r = await dmodel.eval_on_data(tk_dataset[\"train\"][1:105:10])\n",
225 | "pretty_print_dict(eval_r.accuracies)"
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": 12,
231 | "metadata": {},
232 | "outputs": [
233 | {
234 | "name": "stderr",
235 | "output_type": "stream",
236 | "text": [
237 | "Exporting: 100%|██████████| 11/11 [00:00<00:00, 131.04it/s]\n",
238 | "Computing accuracies: 100%|██████████| 11/11 [00:00<00:00, 9128.88it/s]\n"
239 | ]
240 | }
241 | ],
242 | "source": [
243 | "from typet5.data import TokenizedSrcSet\n",
244 | "from typet5.visualization import export_preds_on_code\n",
245 | "\n",
246 | "viz_ds = TokenizedSrcSet(tk_dataset[\"test\"].repos_root, eval_r.final_srcs)\n",
247 | "viz_preds = eval_r.final_preds\n",
248 | "\n",
249 | "export_preds_on_code(viz_ds, viz_preds, proj_root() / \"caches/DAgger-preds-on-code\")"
250 | ]
251 | }
252 | ],
253 | "metadata": {
254 | "kernelspec": {
255 | "display_name": "Python 3.10.4 ('.venv': pipenv)",
256 | "language": "python",
257 | "name": "python3"
258 | },
259 | "language_info": {
260 | "codemirror_mode": {
261 | "name": "ipython",
262 | "version": 3
263 | },
264 | "file_extension": ".py",
265 | "mimetype": "text/x-python",
266 | "name": "python",
267 | "nbconvert_exporter": "python",
268 | "pygments_lexer": "ipython3",
269 | "version": "3.10.4"
270 | },
271 | "orig_nbformat": 4,
272 | "vscode": {
273 | "interpreter": {
274 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8"
275 | }
276 | }
277 | },
278 | "nbformat": 4,
279 | "nbformat_minor": 2
280 | }
281 |
--------------------------------------------------------------------------------
/scripts/archive/code-t5-workflow.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": []
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "%load_ext autoreload\n",
17 | "%autoreload 2\n",
18 | "\n",
19 | "import os\n",
20 | "import pickle\n",
21 | "from concurrent.futures import ProcessPoolExecutor\n",
22 | "from pathlib import Path\n",
23 | "from typing import *\n",
24 | "\n",
25 | "import pandas as pd\n",
26 | "import plotly.express as px\n",
27 | "\n",
28 | "from typet5.data import GitRepo, ModuleRemapUnpickler\n",
29 | "from typet5.type_env import (\n",
30 | " AnnotPath,\n",
31 | " MypyChecker,\n",
32 | " SelectAnnotations,\n",
33 | " TypeInfAction,\n",
34 | " TypeInfEnv,\n",
35 | " TypeInfState,\n",
36 | " collect_annotations,\n",
37 | " mypy_checker,\n",
38 | ")\n",
39 | "from typet5.utils import cst, proj_root, read_file, seq_flatten, tqdm, write_file\n",
40 | "\n",
41 | "os.chdir(proj_root())\n",
42 | "\n",
43 | "datadir = Path(os.getenv(\"datadir\"))\n",
44 | "repos_dir = datadir / \"SPOT-data/repos\"\n",
45 | "\n",
46 | "useful_repos_path = proj_root() / \"scripts\" / \"useful_repos.pkl\"\n",
47 | "rename_module = lambda n: \"typet5.data\" if n == \"typet5.data_prepare\" else n\n",
48 | "with useful_repos_path.open(\"rb\") as f:\n",
49 | " useful_repos: list[GitRepo] = ModuleRemapUnpickler(f, rename_module).load()"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 2,
55 | "metadata": {},
56 | "outputs": [
57 | {
58 | "name": "stderr",
59 | "output_type": "stream",
60 | "text": [
61 | "/home/jiayi/Projects/SPOT/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1402: UserWarning: positional arguments and argument \"destination\" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n",
62 | " warnings.warn(\n"
63 | ]
64 | }
65 | ],
66 | "source": [
67 | "# loading pre-trained model and tokenizer\n",
68 | "\n",
69 | "model_dir = datadir/\"checkpoints/saved/SPOT-CodeT5-with_margin/\"\n",
70 | "\n",
71 | "import torch\n",
72 | "from transformers import (\n",
73 | " DataCollatorForSeq2Seq,\n",
74 | " RobertaTokenizer,\n",
75 | " T5ForConditionalGeneration,\n",
76 | ")\n",
77 | "from transformers.models.t5 import T5ForConditionalGeneration\n",
78 | "\n",
79 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
80 | "tokenizer: RobertaTokenizer = RobertaTokenizer.from_pretrained(model_dir)\n",
81 | "model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(\n",
82 | " model_dir\n",
83 | ").to(device)\n",
84 | "max_target_length = 128"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 25,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "from typet5.data import mask_type_annots, output_ids_as_types, tokenize_masked\n",
94 | "\n",
95 | "test_code = \"\"\"\n",
96 | "@dataclass\n",
97 | "class GitRepo:\n",
98 | " author: str\n",
99 | " name: str\n",
100 | " url: str\n",
101 | " stars: int\n",
102 | " forks: int\n",
103 | "\n",
104 | " def authorname(self):\n",
105 | " return self.author + \"__\" + self.name\n",
106 | "\n",
107 | " def repo_dir(self, repos_dir: Path) -> Path:\n",
108 | " return repos_dir / \"downloaded\" / self.authorname()\n",
109 | "\n",
110 | " def download(self, repos_dir: Path, timeout=None) -> bool:\n",
111 | " pass\n",
112 | "\"\"\"\n",
113 | "\n",
114 | "\n",
115 | "def run_model(code: str, num_beams=16):\n",
116 | " tks = tokenize_masked(mask_type_annots((Path('no_source'), code)), tokenizer, device)\n",
117 | " input_ids = tks[\"input_ids\"]\n",
118 | " with torch.no_grad():\n",
119 | " loss = model.forward(**tks).loss\n",
120 | " dec = model.generate(\n",
121 | " input_ids,\n",
122 | " max_length=max_target_length,\n",
123 | " num_beams=num_beams,\n",
124 | " # do_sample=True,\n",
125 | " )[0]\n",
126 | " return {\n",
127 | " \"loss\": loss,\n",
128 | " \"predicted_types\": output_ids_as_types(dec, tokenizer),\n",
129 | " \"labels\": output_ids_as_types(tks[\"labels\"][0], tokenizer),\n",
130 | " \"generation\": tokenizer.decode(dec),\n",
131 | " \"input_ids\": input_ids[0],\n",
132 | " \"output_ids\": dec,\n",
133 | " }\n",
134 | "\n",
135 | "\n",
136 | "result = run_model(test_code, num_beams=10)"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 27,
142 | "metadata": {},
143 | "outputs": [
144 | {
145 | "name": "stdout",
146 | "output_type": "stream",
147 | "text": [
148 | "\n",
149 | "@dataclass\n",
150 | "class GitRepo:\n",
151 | " author:\n",
152 | " name:\n",
153 | " url:\n",
154 | " stars:\n",
155 | " forks:\n",
156 | "\n",
157 | " def authorname(self):\n",
158 | " return self.author + \"__\" + self.name\n",
159 | "\n",
160 | " def repo_dir(self, repos_dir:) ->:\n",
161 | " return repos_dir / \"downloaded\" / self.authorname()\n",
162 | "\n",
163 | " def download(self, repos_dir:, timeout=None) ->:\n",
164 | " pass\n",
165 | "\n"
166 | ]
167 | }
168 | ],
169 | "source": [
170 | "# Step 1: Replace all types to predict with special tokens\n",
171 | "print(tokenizer.decode(result['input_ids']))"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 28,
177 | "metadata": {},
178 | "outputs": [
179 | {
180 | "name": "stdout",
181 | "output_type": "stream",
182 | "text": [
183 | "['', 'Ċ', '@', 'data', 'class', 'Ċ', 'class', 'ĠGit', 'Repo', ':', 'Ċ', 'ĠĠĠ', 'Ġauthor', ':', '', 'Ċ', 'ĠĠĠ', 'Ġname', ':', '', 'Ċ', 'ĠĠĠ', 'Ġurl', ':', '', 'Ċ', 'ĠĠĠ', 'Ġstars', ':', '', 'Ċ', 'ĠĠĠ', 'Ġfor', 'ks', ':', '', 'Ċ', 'Ċ', 'ĠĠĠ', 'Ġdef', 'Ġauthor', 'name', '(', 'self', '):', 'Ċ', 'ĠĠĠĠĠĠĠ', 'Ġreturn', 'Ġself', '.', 'author', 'Ġ+', 'Ġ\"__', '\"', 'Ġ+', 'Ġself', '.', 'name', 'Ċ', 'Ċ', 'ĠĠĠ', 'Ġdef', 'Ġrepo', '_', 'dir', '(', 'self', ',', 'Ġrepos', '_', 'dir', ':', '', ')', 'Ġ->', '', ':', 'Ċ', 'ĠĠĠĠĠĠĠ', 'Ġreturn', 'Ġrepos', '_', 'dir', 'Ġ/', 'Ġ\"', 'down', 'loaded', '\"', 'Ġ/', 'Ġself', '.', 'author', 'name', '()', 'Ċ', 'Ċ', 'ĠĠĠ', 'Ġdef', 'Ġdownload', '(', 'self', ',', 'Ġrepos', '_', 'dir', ':', '', ',', 'Ġtimeout', '=', 'None', ')', 'Ġ->', '', ':', 'Ċ', 'ĠĠĠĠĠĠĠ', 'Ġpass', 'Ċ', '']\n"
184 | ]
185 | }
186 | ],
187 | "source": [
188 | "# Step 2: Tokenize using Byte Pair Encoding (BPE)\n",
189 | "print(tokenizer.convert_ids_to_tokens(result['input_ids']))"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": 29,
195 | "metadata": {},
196 | "outputs": [
197 | {
198 | "name": "stdout",
199 | "output_type": "stream",
200 | "text": [
201 | "['', '', '', 'str', '', 'str', '', 'str', '', 'List', '[', 'str', ']', 'Ġ+', 'ĠList', '[', 'str', ']', '', 'List', '[', 'str', ']', 'Ġ+', 'ĠList', '[', 'str', ']', 'Ġ+', 'ĠList', '[', 'str', ']', '', 'Path', '', 'Path', 'Ġ.', 'ĠPath', '', 'Path', 'Ġ.', 'ĠPath', '', 'Path', 'Ġ.', 'ĠPath', 'Ġ[', 'Ġstr', ']', '']\n"
202 | ]
203 | }
204 | ],
205 | "source": [
206 | "# Step 3: Let model predict a sequence of types using BPE\n",
207 | "print(tokenizer.convert_ids_to_tokens(result['output_ids']))"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": 30,
213 | "metadata": {},
214 | "outputs": [
215 | {
216 | "name": "stdout",
217 | "output_type": "stream",
218 | "text": [
219 | "[str, str, str, Any, Any, Path, Path.Path, Path.Path, Path.Path[str]]\n"
220 | ]
221 | }
222 | ],
223 | "source": [
224 | "# Step 4: Extract the predicted types\n",
225 | "print(result['predicted_types'])"
226 | ]
227 | }
228 | ],
229 | "metadata": {
230 | "interpreter": {
231 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8"
232 | },
233 | "kernelspec": {
234 | "display_name": "Python 3.9.7 ('.venv': pipenv)",
235 | "language": "python",
236 | "name": "python3"
237 | },
238 | "language_info": {
239 | "codemirror_mode": {
240 | "name": "ipython",
241 | "version": 3
242 | },
243 | "file_extension": ".py",
244 | "mimetype": "text/x-python",
245 | "name": "python",
246 | "nbconvert_exporter": "python",
247 | "pygments_lexer": "ipython3",
248 | "version": "3.10.4"
249 | },
250 | "orig_nbformat": 4
251 | },
252 | "nbformat": 4,
253 | "nbformat_minor": 2
254 | }
255 |
--------------------------------------------------------------------------------
/scripts/archive/debug_mypy.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 26,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "Daemon stopped\n",
13 | "Daemon started\n",
14 | "---checking code_1---\n",
15 | "Success: no issues found in 1 source file\n",
16 | "---checking code_2---\n",
17 | "code.py:5: error: Incompatible return value type (got \"int\", expected \"str\")\n",
18 | "code.py:7: error: Incompatible return value type (got \"int\", expected \"str\")\n",
19 | "Found 2 errors in 1 file (checked 2 source files)\n",
20 | "---checking code_3---\n",
21 | "code.py:5: error: Incompatible return value type (got \"int\", expected \"str\")\n",
22 | "code.py:7: error: Incompatible return value type (got \"int\", expected \"str\")\n",
23 | "Found 2 errors in 1 file (checked 2 source files)\n",
24 | "test finished.\n",
25 | "---wait and check code_3 again---\n",
26 | "Success: no issues found in 2 source files\n",
27 | "test finished.\n"
28 | ]
29 | }
30 | ],
31 | "source": [
32 | "from pathlib import Path\n",
33 | "import subprocess\n",
34 | "import time\n",
35 | "\n",
36 | "# no type error\n",
37 | "code_1 = '''\n",
38 | "from typing import Any\n",
39 | "def fib(n: int) -> int:\n",
40 | " if n == 0:\n",
41 | " return 0\n",
42 | " elif n == 1:\n",
43 | " return 1\n",
44 | " else:\n",
45 | " return fib(n-1) + fib(n-2)\n",
46 | "'''\n",
47 | "\n",
48 | "# incorrect return type\n",
49 | "code_2 = '''\n",
50 | "from typing import Any\n",
51 | "def fib(n: int) -> str:\n",
52 | " if n == 0:\n",
53 | " return 0\n",
54 | " elif n == 1:\n",
55 | " return 1\n",
56 | " else:\n",
57 | " return fib(n-1) + fib(n-2)\n",
58 | "'''\n",
59 | "\n",
60 | "# changed return type to Any, should not error\n",
61 | "code_3 = '''\n",
62 | "from typing import Any\n",
63 | "def fib(n: int) -> Any:\n",
64 | " if n == 0:\n",
65 | " return 0\n",
66 | " elif n == 1:\n",
67 | " return 1\n",
68 | " else:\n",
69 | " return fib(n-1) + fib(n-2)\n",
70 | "'''\n",
71 | "\n",
72 | "# this should be the dmypy path in the current virtual env\n",
73 | "dmypy_path = '/home/jiayi/Projects/SPOT/.venv/bin/dmypy'\n",
74 | "\n",
75 | "check_dir = Path(\"temp/type_check\")\n",
76 | "check_dir.mkdir(exist_ok=True, parents=True)\n",
77 | "with open(check_dir / \"code.py\", \"w\") as f:\n",
78 | " f.write(code_1)\n",
79 | "subprocess.run(['python', dmypy_path, 'restart', '--', '--follow-imports=skip'],cwd=check_dir)\n",
80 | "\n",
81 | "print('---checking code_1---')\n",
82 | "subprocess.run(['python', dmypy_path, 'check', '.'],cwd=check_dir)\n",
83 | "\n",
84 | "with open(check_dir / \"code.py\", \"w\") as f:\n",
85 | " f.write(code_2)\n",
86 | "print('---checking code_2---')\n",
87 | "subprocess.run(['python', dmypy_path, 'recheck', \"--update\", \"code.py\"],cwd=check_dir)\n",
88 | "\n",
89 | "with open(check_dir / \"code.py\", \"w\") as f:\n",
90 | " f.write(code_3)\n",
91 | "print('---checking code_3---')\n",
92 | "subprocess.run(['python', dmypy_path, 'recheck', \"--update\", \"code.py\"],cwd=check_dir)\n",
93 | "print(\"test finished.\")\n",
94 | "\n",
95 | "print('---wait and check code_3 again---')\n",
96 | "time.sleep(1.0) # will not work if waiting time is shorter\n",
97 | "with open(check_dir / \"code.py\", \"w\") as f: # need this rewriting\n",
98 | " f.write(code_3)\n",
99 | "subprocess.run(['python', dmypy_path, 'recheck', \"--update\", \"code.py\"],cwd=check_dir)\n",
100 | "print(\"test finished.\")"
101 | ]
102 | }
103 | ],
104 | "metadata": {
105 | "interpreter": {
106 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8"
107 | },
108 | "kernelspec": {
109 | "display_name": "Python 3.10.4 ('.venv': pipenv)",
110 | "language": "python",
111 | "name": "python3"
112 | },
113 | "language_info": {
114 | "codemirror_mode": {
115 | "name": "ipython",
116 | "version": 3
117 | },
118 | "file_extension": ".py",
119 | "mimetype": "text/x-python",
120 | "name": "python",
121 | "nbconvert_exporter": "python",
122 | "pygments_lexer": "ipython3",
123 | "version": "3.10.4"
124 | },
125 | "orig_nbformat": 4
126 | },
127 | "nbformat": 4,
128 | "nbformat_minor": 2
129 | }
130 |
--------------------------------------------------------------------------------
/scripts/archive/kill_dmypy.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Kill every running dmypy (mypy daemon) process.
#
# The bracket pattern [d]mypy prevents grep from matching its own
# process entry in the ps output (the literal string "[d]mypy" never
# appears in grep's own command line), so we no longer risk killing
# the grep process or passing its stale PID to kill.
# `xargs -r` skips the kill invocation entirely when no dmypy
# processes are found, instead of calling kill with no arguments.
ps aux | grep -i '[d]mypy' | awk '{print $2}' | xargs -r kill

--------------------------------------------------------------------------------
/scripts/archive/train_dagger.py:
--------------------------------------------------------------------------------
# %%
# Script setup: resolve all relative paths from the repository root and
# locate the data directory used for datasets and checkpoints below.

import asyncio
import os
from typing import *

from typet5.utils import get_data_dir, not_none, proj_root

# Run from the project root so relative paths (checkpoints, caches, wandb
# dirs) are stable regardless of where the script was launched from.
os.chdir(proj_root())

datadir = get_data_dir()
12 |
# %%
# experiment configurations

from termcolor import colored

from typet5.data import TypeCheckSettings, get_tk_dataset_name, load_tokenized_srcsets
from typet5.model import CtxArgs, DecodingArgs, ModelType, ModelWrapper
from typet5.train import TrainingConfig, TypeCheckArgs

# Whether DAgger training queries the type checker for feedback.
use_type_checker = False

config = TrainingConfig(
    quicktest=False,
    ctx_size=2048,
    left_margin=1024,
    right_margin=1023,
    modifications="no_type_checker",
)
gpu_id = 0
# Give this run its own type-checker temp dir, keyed by GPU id, so
# concurrent runs on different GPUs don't clobber each other's files.
TypeCheckSettings.temp_path = f"DAgger-{gpu_id}"

if config.quicktest:
    print(colored("quicktest: True", "red"))

# Quicktest runs log to a separate wandb project to keep real results clean.
project_name = "test-SPOT" if config.quicktest else "SPOT"
train_ctx_args = config.train_ctx_args()
tc_args = TypeCheckArgs(check_in_isolation=config.check_in_isolation)

datasets_name = get_tk_dataset_name(
    drop_comments=config.drop_comments,
)
44 |
# The model name encodes the full training config for traceability.
model_name = "DAgger-model--" + config.as_name()

tk_dataset = load_tokenized_srcsets(
    datadir,
    datasets_name,
    data_reduction=config.data_reduction,
    quicktest=config.quicktest,
)


import torch

from typet5.dagger import DAggerModel

# %%
# initialize the model
from typet5.model import DefaultTokenizer, ModelWrapper, load_model_spot

train_dec_args = DecodingArgs(
    sampling_max_tokens=8 * config.ctx_size,
    ctx_args=config.dec_ctx_args(),
    do_sample=True,  # use nucleus sampling during training
    top_p=0.9,
)

# Start from the pretrained CodeT5-base checkpoint and wrap it with the
# decoding arguments used during DAgger rollouts.
model = load_model_spot("Salesforce/codet5-base")
wrapper = ModelWrapper(model, DefaultTokenizer, train_dec_args)
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
wrapper.to(device)
dmodel = DAggerModel(wrapper, use_type_checker=use_type_checker)
76 |
# %%
# pre-train evaluation
# Kept commented out: uncomment to sanity-check accuracy on a small test
# slice before spending time on training.
# from typet5.utils import pretty_print_dict

# eval_r = asyncio.run(dmodel.eval_on_data(tk_dataset["test"][0:50]))
# pretty_print_dict(eval_r.accuracies)


import shutil

import wandb
88 |
# %%
# train the model
from typet5.dagger import DAggerArgs
from typet5.utils import run_long_task

# Intermediate checkpoints are written under "running/"; the final model is
# moved to "saved/" (or "saved-emergent/" on failure) in the finally block.
ckpt_dir = datadir / f"checkpoints/running/{model_name}"

with run_long_task("DAgger training"):
    wandb.init(
        project=project_name,
        name=model_name,
        config=config.as_dict(),
        dir=str(datadir),
    )

    dargs = DAggerArgs(
        ckpt_dir,
        grad_accum_steps=config.grad_accum_labels,
        replay_buffer_size=1000,
    )

    finished = False
    try:
        asyncio.run(
            dmodel.train_on_data(
                tk_dataset["train"],
                dargs,
                log_fn=lambda t, x: wandb.log({"train/step": t, **x}),
            )
        )
        finished = True
    finally:
        # Always persist the model, even on crash/interrupt; an emergency
        # save goes to a separate directory so it can't be mistaken for a
        # fully trained model. (Fixed misspelled local `save_tpye`; the
        # "saved-emergent" path string is kept unchanged for compatibility
        # with existing checkpoint directories.)
        save_type = "saved" if finished else "saved-emergent"
        save_path = datadir / f"checkpoints/{save_type}/{model_name}"
        print(colored(f"Saving trained model to: {save_path}", "blue"))
        # Remove any stale checkpoint at the destination before saving.
        shutil.rmtree(save_path, ignore_errors=True)
        wrapper.save(save_path)
126 |
# %%
# post-train full evaluation
from typet5.utils import PickleCache, pretty_print_dict, pretty_show_dict
from typet5.visualization import string_to_html

# Evaluate with a larger context and beam search (greedy, 8 beams) instead
# of the sampling settings used during training.
test_dec_args = DecodingArgs(
    sampling_max_tokens=8 * config.ctx_size,
    ctx_args=CtxArgs(
        ctx_size=4096,
        left_margin=2048,
        right_margin=1023,
    ),
    do_sample=False,
    num_beams=8,
)
dmodel.wrapper.args = test_dec_args

# Cache evaluation results next to the saved model so reruns are cheap.
# `save_path` is defined by the training cell above.
eval_cache = PickleCache(save_path / "eval_cache")  # type: ignore

eval_r = eval_cache.cached(
    "eval_test", lambda: asyncio.run(dmodel.eval_on_data(tk_dataset["test"]))
)
pretty_print_dict(eval_r.accuracies)


def wandb_string(s: str):
    """Wrap the string `s` as a `wandb.Html` object so it renders in the W&B UI."""
    return wandb.Html(string_to_html(s))


wandb.log({"test/accuracies": wandb_string(pretty_show_dict(eval_r.accuracies))})
157 |
# %%
# compute valid set performance
import re

from typet5.utils import not_none

# Subsample every 3rd element of the validation set to speed up the sweep.
validset = tk_dataset["valid"][0:-1:3]
# dmodel.wrapper.args = train_dec_args

# Evaluate every intermediate checkpoint (named "step=<n>") on the
# validation subset and log accuracy against its training step.
with run_long_task("DAgger evaluating (valid set)"):
    for model_path in ckpt_dir.glob("step=*"):
        print(colored(f"Evaluating model checkpoint: {model_path}", "blue"))
        # Parse the step number out of the checkpoint directory name.
        m = not_none(re.match("step=(.+)", model_path.name)).groups()[0]
        step = int(m)
        wrapper = ModelWrapper.load(model_path)
        device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
        wrapper.to(device)
        # NOTE(review): unlike the training cell, this does not pass
        # use_type_checker to DAggerModel — presumably relying on its
        # default; confirm the default matches the intended eval setup.
        dmodel = DAggerModel(wrapper)
        eval_r = asyncio.run(dmodel.eval_on_data(validset))
        wandb.log(
            {
                "valid/full_acc": eval_r.accuracies["full_acc"].acc,
                "train/step": step,
            }
        )
183 |
--------------------------------------------------------------------------------
/scripts/collect_dataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "import json\n",
13 | "import logging\n",
14 | "import os\n",
15 | "import shutil\n",
16 | "import subprocess\n",
17 | "import time\n",
18 | "from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed\n",
19 | "from pathlib import Path\n",
20 | "\n",
21 | "import libcst as cst\n",
22 | "from tqdm import tqdm\n",
23 | "\n",
24 | "from typet5.data import GitRepo, get_dataset_dir\n",
25 | "from typet5.type_env import collect_annots_info, mypy_checker\n",
26 | "from typet5.utils import proj_root, read_file, write_file, not_none\n",
27 | "\n",
28 | "os.chdir(proj_root())"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 4,
34 | "metadata": {},
35 | "outputs": [
36 | {
37 | "name": "stdout",
38 | "output_type": "stream",
39 | "text": [
40 | "Repos already downloaded.\n",
41 | "Reading last updates...\n"
42 | ]
43 | },
44 | {
45 | "name": "stderr",
46 | "output_type": "stream",
47 | "text": [
48 | "100%|██████████| 4890/4890 [00:27<00:00, 175.00it/s]"
49 | ]
50 | },
51 | {
52 | "name": "stdout",
53 | "output_type": "stream",
54 | "text": [
55 | "Downloaded 4890/5996 repos.\n"
56 | ]
57 | },
58 | {
59 | "name": "stderr",
60 | "output_type": "stream",
61 | "text": [
62 | "\n"
63 | ]
64 | }
65 | ],
66 | "source": [
67 | "# download all candidate repos\n",
68 | "\n",
69 | "all_repos = json.loads(read_file(\"data/mypy-dependents-by-stars.json\"))\n",
70 | "all_repos = [GitRepo.from_json(r) for r in all_repos]\n",
71 | "# all_repos=all_repos[:10] # for testing\n",
72 | "\n",
73 | "repos_dir = get_dataset_dir(\"ManyTypes4Py\") / \"repos\"\n",
74 | "\n",
75 | "def clear_downloaded_repos(repos_dir):\n",
76 | " shutil.rmtree(repos_dir)\n",
77 | "\n",
78 | "\n",
79 | "def download_repos(\n",
80 | " to_download: list[GitRepo], repos_dir, download_timeout=10.0, max_workers=10\n",
81 | ") -> list[GitRepo]:\n",
82 | " def download_single(repo: GitRepo):\n",
83 | " try:\n",
84 | " if repo.download(repos_dir, timeout=download_timeout):\n",
85 | " repo.read_last_update(repos_dir)\n",
86 | " return repo\n",
87 | " else:\n",
88 | " return None\n",
89 | " except subprocess.TimeoutExpired:\n",
90 | " return None\n",
91 | " except Exception as e:\n",
92 | " logging.warning(f\"Failed to download {repo.name}. Exception: {e}\")\n",
93 | " return None\n",
94 | "\n",
95 | " print(\"Downloading repos from Github...\")\n",
96 | " t_start = time.time()\n",
97 | " with ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
98 | " fs = [executor.submit(download_single, repo) for repo in to_download]\n",
99 | " rs = [f.result() for f in tqdm(as_completed(fs), total=len(fs))]\n",
100 | " print(f\"Downloading took {time.time() - t_start} seconds.\")\n",
101 | " downloaded = [r for r in rs if r is not None]\n",
102 | " return downloaded\n",
103 | "\n",
104 | "\n",
105 | "if not repos_dir.exists():\n",
106 | " (repos_dir / \"downloading\").mkdir(parents=True)\n",
107 | " (repos_dir / \"downloaded\").mkdir(parents=True)\n",
108 | " downloaded_repos = download_repos(all_repos, repos_dir)\n",
109 | " print(\"Deleting failed repos...\")\n",
110 | " shutil.rmtree(repos_dir / \"downloading\")\n",
111 | "else:\n",
112 | " print(\"Repos already downloaded.\")\n",
113 | " downloaded_dirs = set(d.name for d in (repos_dir / \"downloaded\").iterdir())\n",
114 | " downloaded_repos = [r for r in all_repos if r.authorname() in downloaded_dirs]\n",
115 | " print(\"Reading last updates...\")\n",
116 | " for r in tqdm(downloaded_repos):\n",
117 | " r.read_last_update(repos_dir)\n",
118 | "print(f\"Downloaded {len(downloaded_repos)}/{len(all_repos)} repos.\")\n"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": 33,
124 | "metadata": {},
125 | "outputs": [
126 | {
127 | "name": "stdout",
128 | "output_type": "stream",
129 | "text": [
130 | "1218 / 4890 repos are updated within a year.\n"
131 | ]
132 | },
133 | {
134 | "name": "stderr",
135 | "output_type": "stream",
136 | "text": [
137 | "100%|██████████| 1218/1218 [00:05<00:00, 243.41it/s]"
138 | ]
139 | },
140 | {
141 | "name": "stdout",
142 | "output_type": "stream",
143 | "text": [
144 | "1181/1218 repos are within the size limit.\n"
145 | ]
146 | },
147 | {
148 | "name": "stderr",
149 | "output_type": "stream",
150 | "text": [
151 | "\n"
152 | ]
153 | }
154 | ],
155 | "source": [
156 | "# filter out repos that are too old or too big\n",
157 | "\n",
158 | "from datetime import datetime, timezone\n",
159 | "\n",
160 | "date_threshold = datetime(2021, 4, 20)\n",
161 | "new_repos = [r for r in downloaded_repos if not_none(r.last_update) > date_threshold]\n",
162 | "print(f\"{len(new_repos)} / {len(downloaded_repos)} repos are updated within a year.\")\n",
163 | "loc_limit = 50000\n",
164 | "\n",
165 | "small_repos = []\n",
166 | "for rep in tqdm(new_repos):\n",
167 | " try:\n",
168 | " loc = rep.count_lines_of_code(repos_dir)\n",
169 | " if loc < loc_limit:\n",
170 | " small_repos.append(rep)\n",
171 | " except UnicodeDecodeError:\n",
172 | " # nothing we can do\n",
173 | " pass\n",
174 | " except Exception as e:\n",
175 | " logging.warning(f\"Failed to count lines of code for {rep.name}. Exception: {e}\")\n",
176 | "\n",
177 | "print(\n",
178 | " f\"{len(small_repos)}/{len(new_repos)} repos are within the size limit ({loc_limit} LOC).\"\n",
179 | ")\n"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "metadata": {},
186 | "outputs": [],
187 | "source": [
188 | "# filter away repos with too few annotations\n",
189 | "\n",
190 | "def count_repo_annots(rep):\n",
191 | " try:\n",
192 | " rep.count_annotations(repos_dir)\n",
193 | " if rep.n_type_annots / rep.lines_of_code > 0.05:\n",
194 | " return rep\n",
195 | " except Exception as e:\n",
196 | " logging.warning(f\"Failed to count annotations for {rep.name}. Exception: {e}\")\n",
197 | " return None\n",
198 | "\n",
199 | "\n",
200 | "with ProcessPoolExecutor(max_workers=30) as executor:\n",
201 | " fs = [executor.submit(count_repo_annots, rep) for rep in small_repos]\n",
202 | " rs = [f.result() for f in tqdm(as_completed(fs), total=len(fs))]\n",
203 | "useful_repos: list[GitRepo] = [\n",
204 | " r for r in rs if r is not None and \"typeshed\" not in r.name\n",
205 | "]\n",
206 | "\n",
207 | "print(\n",
208 | " f\"{len(useful_repos)}/{len(small_repos)} repos are parsable and have enough portions of type annotations.\"\n",
209 | ")\n"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 35,
215 | "metadata": {},
216 | "outputs": [
217 | {
218 | "name": "stdout",
219 | "output_type": "stream",
220 | "text": [
221 | "Total number of manual annotations: 343595\n",
222 | "Total number of type places: 544497\n",
223 | "Total number of lines of code: 3342911\n"
224 | ]
225 | },
226 | {
227 | "data": {
228 | "text/plain": [
229 | "[GitRepo(author='skorokithakis', name='catt', url='https://github.com/skorokithakis/catt', stars=1740, forks=762, lines_of_code=2036, last_update=datetime.datetime(2022, 4, 10, 1, 30, 43), n_type_annots=140, n_type_places=433),\n",
230 | " GitRepo(author='encode', name='databases', url='https://github.com/encode/databases', stars=769, forks=48, lines_of_code=3124, last_update=datetime.datetime(2022, 3, 6, 12, 25, 10), n_type_annots=323, n_type_places=498),\n",
231 | " GitRepo(author='Curt-Park', name='rainbow-is-all-you-need', url='https://github.com/Curt-Park/rainbow-is-all-you-need', stars=490, forks=110, lines_of_code=107, last_update=datetime.datetime(2022, 1, 13, 23, 4, 48), n_type_annots=26, n_type_places=30),\n",
232 | " GitRepo(author='jreese', name='aiomultiprocess', url='https://github.com/jreese/aiomultiprocess', stars=585, forks=45, lines_of_code=1140, last_update=datetime.datetime(2022, 2, 4, 21, 28, 7), n_type_annots=138, n_type_places=213),\n",
233 | " GitRepo(author='instaloader', name='instaloader', url='https://github.com/instaloader/instaloader', stars=874, forks=134, lines_of_code=5417, last_update=datetime.datetime(2022, 4, 18, 9, 49, 34), n_type_annots=569, n_type_places=843)]"
234 | ]
235 | },
236 | "execution_count": 35,
237 | "metadata": {},
238 | "output_type": "execute_result"
239 | }
240 | ],
241 | "source": [
242 | "# Some summary statistics\n",
243 | "\n",
244 | "# print total number of manual annotations\n",
245 | "n_total_annots = sum(not_none(rep.n_type_annots) for rep in useful_repos)\n",
246 | "print(\"Total number of manual annotations:\", n_total_annots)\n",
247 | "\n",
248 | "# print total number of type places\n",
249 | "n_total_places = sum(not_none(rep.n_type_places) for rep in useful_repos)\n",
250 | "print(\"Total number of type places:\", n_total_places)\n",
251 | "\n",
252 | "# print total number of lines of code\n",
253 | "n_total_lines = sum(not_none(rep.lines_of_code) for rep in useful_repos)\n",
254 | "print(\"Total number of lines of code:\", n_total_lines)\n",
255 | "\n",
256 | "# print average number of type annotations per line of code excluding projects with more than 1000 lines of code\n",
257 | "n_avg_annots = (\n",
258 | " sum(not_none(rep.n_type_annots) for rep in useful_repos if rep.lines_of_code < 1000)\n",
259 | " / n_total_lines\n",
260 | ")\n"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": 5,
266 | "metadata": {},
267 | "outputs": [
268 | {
269 | "name": "stdout",
270 | "output_type": "stream",
271 | "text": [
272 | "[GitRepo(author='typeddjango', name='pytest-mypy-plugins', url='https://github.com/typeddjango/pytest-mypy-plugins', stars=12, forks=0, lines_of_code=1039, last_update=datetime.datetime(2022, 4, 18, 23, 25, 40), n_type_annots=155, n_type_places=158), GitRepo(author='jfly', name='jfly.github.io', url='https://github.com/jfly/jfly.github.io', stars=0, forks=0, lines_of_code=650, last_update=datetime.datetime(2022, 4, 12, 8, 23, 39), n_type_annots=39, n_type_places=122), GitRepo(author='seattleflu', name='id3c', url='https://github.com/seattleflu/id3c', stars=2, forks=0, lines_of_code=8883, last_update=datetime.datetime(2022, 4, 21, 15, 38, 59), n_type_annots=675, n_type_places=1068)]\n"
273 | ]
274 | }
275 | ],
276 | "source": [
277 | "import pickle\n",
278 | "\n",
279 | "useful_repos_path = proj_root() / \"scripts\" / \"useful_repos.pkl\"\n",
280 | "with useful_repos_path.open(\"wb\") as f:\n",
281 | " pickle.dump(useful_repos, f)\n",
282 | "print(f\"Saved {len(useful_repos)} useful repos to {useful_repos_path}.\")\n",
283 | "with useful_repos_path.open(\"rb\") as f:\n",
284 | " print(pickle.load(f)[:3])\n"
285 | ]
286 | },
287 | {
288 | "attachments": {},
289 | "cell_type": "markdown",
290 | "metadata": {},
291 | "source": [
"The cell below tries to split the dataset based on the original dataset split used by the paper. However, since the list of repos returned by the GitHub query above can change over time, some repos might no longer be present. You might consider performing your own splitting if the issue is serious."
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": null,
298 | "metadata": {},
299 | "outputs": [],
300 | "source": [
301 | "from typet5.utils import pickle_load, Path, proj_root, tqdm\n",
302 | "import shutil\n",
303 | "from typet5.data import GitRepo, get_dataset_dir\n",
304 | "\n",
305 | "os.chdir(proj_root())\n",
306 | "\n",
307 | "repos_split = pickle_load(Path(\"data/repos_split.pkl\"))\n",
308 | "repos_dir = get_dataset_dir(\"ManyTypes4Py\") / \"repos\"\n",
309 | "\n",
310 | "for split, repos in repos_split.items():\n",
311 | " for r in tqdm(repos, desc=f\"Moving {split} repos.\"):\n",
312 | " r: GitRepo\n",
313 | " split: str\n",
314 | " src = repos_dir / r.authorname()\n",
315 | " (repos_dir / split).mkdir(parents=True, exist_ok=True)\n",
316 | " dest = repos_dir / split / r.authorname()\n",
317 | " if src.exists():\n",
318 | " shutil.move(src, dest)\n",
319 | " else:\n",
320 | " print(f\"Repo {r.name} not found.\")"
321 | ]
322 | }
323 | ],
324 | "metadata": {
325 | "interpreter": {
326 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8"
327 | },
328 | "kernelspec": {
329 | "display_name": "Python 3.9.7 ('.venv': pipenv)",
330 | "language": "python",
331 | "name": "python3"
332 | },
333 | "language_info": {
334 | "codemirror_mode": {
335 | "name": "ipython",
336 | "version": 3
337 | },
338 | "file_extension": ".py",
339 | "mimetype": "text/x-python",
340 | "name": "python",
341 | "nbconvert_exporter": "python",
342 | "pygments_lexer": "ipython3",
343 | "version": "3.10.8"
344 | },
345 | "orig_nbformat": 4
346 | },
347 | "nbformat": 4,
348 | "nbformat_minor": 2
349 | }
350 |
--------------------------------------------------------------------------------
/scripts/experiments/eval_file_model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 7,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "The autoreload extension is already loaded. To reload it, use:\n",
13 | " %reload_ext autoreload\n"
14 | ]
15 | }
16 | ],
17 | "source": [
18 | "%load_ext autoreload\n",
19 | "%autoreload 2\n",
20 | "\n",
21 | "import asyncio\n",
22 | "import os\n",
23 | "from typing import *\n",
24 | "\n",
25 | "import torch\n",
26 | "import wandb\n",
27 | "from typet5.data import get_tk_dataset_name\n",
28 | "from typet5.function_dataset import data_project_from_dir\n",
29 | "from typet5.model import ModelWrapper\n",
30 | "from typet5.train import TrainingConfig, PreprocessArgs\n",
31 | "from typet5.type_env import AccuracyMetric\n",
32 | "from typet5.utils import (\n",
33 | " PickleCache,\n",
34 | " assert_eq,\n",
35 | " get_dataroot,\n",
36 | " get_dataset_dir,\n",
37 | " get_eval_dir,\n",
38 | " get_gpu_id,\n",
39 | " get_model_dir,\n",
40 | " pickle_dump,\n",
41 | " pmap,\n",
42 | " pretty_print_dict,\n",
43 | " pretty_show_dict,\n",
44 | " proj_root,\n",
45 | " run_long_task,\n",
46 | " write_file,\n",
47 | ")\n",
48 | "from typet5.visualization import string_to_html\n",
49 | "from termcolor import colored\n",
50 | "\n",
51 | "os.chdir(proj_root())\n",
52 | "\n",
53 | "\n",
54 | "def wandb_string(s: str):\n",
55 | " return wandb.Html(string_to_html(s))"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 8,
61 | "metadata": {},
62 | "outputs": [
63 | {
64 | "name": "stdout",
65 | "output_type": "stream",
66 | "text": [
67 | "GPU_ID not set, using: 1\n",
68 | "\u001b[32mUse GPU: 1\u001b[0m\n"
69 | ]
70 | }
71 | ],
72 | "source": [
73 | "# experiment configurations\n",
74 | "quicktest = False\n",
75 | "\n",
76 | "gpu_id = get_gpu_id(1)\n",
77 | "# model_name = \"model-v6--TrainingConfig(func_only=False, left_margin=2048, preamble_size=800, right_margin=1536)\"\n",
78 | "model_name = \"model-v6--TrainingConfig(func_only=False, imports_in_preamble=False, stub_in_preamble=False, left_margin=2048, right_margin=1536)\"\n",
79 | "pre_args = PreprocessArgs(imports_in_preamble=False, stub_in_preamble=False)\n",
80 | "dataset_name = \"ManyTypes4Py\"\n",
81 | "# dataset_name = \"InferTypes4Py\"\n",
82 | "# dataset_name = \"SPOT-src\"\n",
83 | "experiment_name = dataset_name + \": \" + model_name\n",
84 | "\n",
85 | "print(colored(f\"Use GPU: {gpu_id}\", \"green\"))\n"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 9,
91 | "metadata": {},
92 | "outputs": [
93 | {
94 | "name": "stdout",
95 | "output_type": "stream",
96 | "text": [
97 | "Loading TokenizedSrcSets: /mnt/nas/jiayi/SPOT/TokenizedSrcSets/ManyTypes4Py-v5-PreprocessArgs(imports_in_preamble=False, stub_in_preamble=False)\n",
98 | "254M\t/mnt/nas/jiayi/SPOT/TokenizedSrcSets/ManyTypes4Py-v5-PreprocessArgs(imports_in_preamble=False, stub_in_preamble=False)\n"
99 | ]
100 | }
101 | ],
102 | "source": [
103 | "# load test data\n",
104 | "from typet5.data import load_tokenized_srcsets, create_tokenized_srcsets\n",
105 | "\n",
106 | "sdata_name = get_tk_dataset_name(dataset_name, pre_args, func_only=False)\n",
107 | "sdata_path = get_dataroot() / \"TokenizedSrcSets\" / sdata_name\n",
108 | "recreate=False\n",
109 | "if recreate or not sdata_path.exists():\n",
110 | " create_tokenized_srcsets(\n",
111 | " dataset_name,\n",
112 | " sdata_path,\n",
113 | " func_only=False,\n",
114 | " pre_args=pre_args,\n",
115 | " )\n",
116 | "tk_dataset = load_tokenized_srcsets(\n",
117 | " sdata_path,\n",
118 | " quicktest=quicktest,\n",
119 | " sets_to_load=[\"test\"],\n",
120 | ")\n"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 10,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "# model evaluation\n",
130 | "\n",
131 | "from typet5.function_decoding import (\n",
132 | " DecodingOrders,\n",
133 | " EvalResult,\n",
134 | " PreprocessArgs,\n",
135 | " RolloutCtx,\n",
136 | ")\n",
137 | "from typet5.function_dataset import sigmap_from_file_predictions\n",
138 | "from typet5.static_analysis import SignatureErrorAnalysis\n",
139 | "\n",
140 | "# load model\n",
141 | "model = ModelWrapper.from_pretrained(get_model_dir() / model_name)\n",
142 | "device = torch.device(f\"cuda:{gpu_id}\" if torch.cuda.is_available() else \"cpu\")\n",
143 | "model.to(device)\n",
144 | "\n",
145 | "ctx_args = model.args.ctx_args\n",
146 | "model.args.sampling_max_tokens = ctx_args.ctx_size\n",
147 | "model.args.do_sample = False\n",
148 | "model.args.num_beams = 10\n",
149 | "model.args.tokens_per_type = 16\n",
150 | "\n",
151 | "eval_cache = PickleCache(get_eval_dir(dataset_name, model_name) / f\"{pre_args}\")\n",
152 | "# eval_cache.clear()\n",
153 | "pre_r = eval_cache.cached(\n",
154 | " \"DatasetPredResult.pkl\",\n",
155 | " lambda: model.eval_on_dataset(tk_dataset[\"test\"]),\n",
156 | ")\n"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 11,
162 | "metadata": {},
163 | "outputs": [
164 | {
165 | "name": "stderr",
166 | "output_type": "stream",
167 | "text": [
168 | "Loading test projects: 100%|██████████| 50/50 [00:27<00:00, 1.85it/s]\n"
169 | ]
170 | },
171 | {
172 | "name": "stdout",
173 | "output_type": "stream",
174 | "text": [
175 | "Accuracies on all types:\n",
176 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n",
177 | "67.07 & 67.47 & 72.12 & 44.05 & 73.44\n",
178 | "Accuracies on common types:\n",
179 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n",
180 | "76.74 & 78.04 & 82.43 & 53.03 & 82.44\n",
181 | "Accuracies on rare types:\n",
182 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n",
183 | "49.47 & 52.95 & 57.28 & 34.26 & 57.65\n",
184 | "full_acc:\n",
185 | " full_acc: 67.07% (count=15.7k)\n",
186 | " full_acc_by_cat:\n",
187 | " FuncArg: 62.00% (count=8.0k)\n",
188 | " FuncReturn: 77.89% (count=5.8k)\n",
189 | " ClassAtribute: 55.36% (count=1.8k)\n",
190 | " GlobalVar: 63.55% (count=107)\n",
191 | " full_acc_by_simple:\n",
192 | " complex: 41.80% (count=3.3k)\n",
193 | " simple: 73.71% (count=12.4k)\n",
194 | " full_acc_label_size: 1.4194\n",
195 | " full_acc_pred_size: 1.4107\n",
196 | " full_acc_ignored_labels: 0\n",
197 | " n_missing_types: 53\n",
198 | "full_acc_common:\n",
199 | " full_acc_common: 76.74% (count=10.1k)\n",
200 | " full_acc_common_by_cat:\n",
201 | " FuncArg: 76.04% (count=5.3k)\n",
202 | " FuncReturn: 76.74% (count=3.8k)\n",
203 | " ClassAtribute: 79.57% (count=984)\n",
204 | " GlobalVar: 88.10% (count=84)\n",
205 | " full_acc_common_by_simple:\n",
206 | " complex: 48.52% (count=1.8k)\n",
207 | " simple: 82.65% (count=8.4k)\n",
208 | " full_acc_common_label_size: 1.3693\n",
209 | " full_acc_common_pred_size: 1.3379\n",
210 | " full_acc_common_ignored_labels: 5569\n",
211 | " n_missing_types: 53\n",
212 | "full_acc_rare:\n",
213 | " full_acc_rare: 49.47% (count=5.6k)\n",
214 | " full_acc_rare_by_cat:\n",
215 | " FuncArg: 51.05% (count=3.2k)\n",
216 | " FuncReturn: 48.03% (count=2.0k)\n",
217 | " ClassAtribute: 44.10% (count=356)\n",
218 | " GlobalVar: 41.67% (count=36)\n",
219 | " full_acc_rare_by_simple:\n",
220 | " complex: 34.04% (count=1.5k)\n",
221 | " simple: 55.24% (count=4.1k)\n",
222 | " full_acc_rare_label_size: 1.5103\n",
223 | " full_acc_rare_pred_size: 1.543\n",
224 | " full_acc_rare_ignored_labels: 10129\n",
225 | " n_missing_types: 53\n",
226 | "acc:\n",
227 | " acc: 67.47% (count=13.2k)\n",
228 | " acc_by_cat:\n",
229 | " FuncArg: 66.28% (count=6.7k)\n",
230 | " FuncReturn: 68.91% (count=4.9k)\n",
231 | " ClassAtribute: 68.28% (count=1.5k)\n",
232 | " GlobalVar: 63.64% (count=99)\n",
233 | " acc_by_simple:\n",
234 | " complex: 44.05% (count=2.2k)\n",
235 | " simple: 72.12% (count=11.0k)\n",
236 | " acc_label_size: 1.3155\n",
237 | " acc_pred_size: 1.2971\n",
238 | " acc_ignored_labels: 2521\n",
239 | " n_missing_types: 53\n",
240 | "acc_common:\n",
241 | " acc_common: 78.04% (count=7.6k)\n",
242 | " acc_common_by_cat:\n",
243 | " FuncArg: 77.28% (count=4.0k)\n",
244 | " FuncReturn: 78.69% (count=2.7k)\n",
245 | " ClassAtribute: 80.51% (count=816)\n",
246 | " GlobalVar: 64.91% (count=57)\n",
247 | " acc_common_by_simple:\n",
248 | " complex: 53.03% (count=1.1k)\n",
249 | " simple: 82.43% (count=6.5k)\n",
250 | " acc_common_label_size: 1.2978\n",
251 | " acc_common_pred_size: 1.2665\n",
252 | " acc_common_ignored_labels: 8072\n",
253 | " n_missing_types: 53\n",
254 | "acc_rare:\n",
255 | " acc_rare: 52.95% (count=5.6k)\n",
256 | " acc_rare_by_cat:\n",
257 | " FuncArg: 55.36% (count=3.2k)\n",
258 | " FuncReturn: 50.94% (count=2.0k)\n",
259 | " ClassAtribute: 43.26% (count=356)\n",
260 | " GlobalVar: 44.44% (count=36)\n",
261 | " acc_rare_by_simple:\n",
262 | " complex: 34.26% (count=1.0k)\n",
263 | " simple: 57.28% (count=4.5k)\n",
264 | " acc_rare_label_size: 1.3399\n",
265 | " acc_rare_pred_size: 1.3392\n",
266 | " acc_rare_ignored_labels: 10147\n",
267 | " n_missing_types: 53\n",
268 | "base_acc:\n",
269 | " base_acc: 73.44% (count=13.2k)\n",
270 | " base_acc_by_cat:\n",
271 | " FuncArg: 72.19% (count=6.7k)\n",
272 | " FuncReturn: 74.45% (count=4.9k)\n",
273 | " ClassAtribute: 75.77% (count=1.5k)\n",
274 | " GlobalVar: 72.73% (count=99)\n",
275 | " base_acc_ignored_labels: 2521\n",
276 | " n_missing_types: 53\n",
277 | "base_acc_common:\n",
278 | " base_acc_common: 82.44% (count=8.4k)\n",
279 | " base_acc_common_by_cat:\n",
280 | " FuncArg: 82.34% (count=4.4k)\n",
281 | " FuncReturn: 82.62% (count=3.1k)\n",
282 | " ClassAtribute: 82.96% (count=886)\n",
283 | " GlobalVar: 73.02% (count=63)\n",
284 | " base_acc_common_ignored_labels: 7305\n",
285 | " n_missing_types: 53\n",
286 | "base_acc_rare:\n",
287 | " base_acc_rare: 57.65% (count=4.8k)\n",
288 | " base_acc_rare_by_cat:\n",
289 | " FuncArg: 58.01% (count=2.8k)\n",
290 | " FuncReturn: 55.16% (count=1.6k)\n",
291 | " ClassAtribute: 65.96% (count=332)\n",
292 | " GlobalVar: 67.74% (count=31)\n",
293 | " base_acc_rare_ignored_labels: 10914\n",
294 | " n_missing_types: 53\n"
295 | ]
296 | }
297 | ],
298 | "source": [
299 | "repos_dir = get_dataset_dir(dataset_name) / \"repos\" / \"test\"\n",
300 | "test_repo_paths = [f for f in repos_dir.iterdir() if f.is_dir()]\n",
301 | "test_projects = pmap(\n",
302 | " data_project_from_dir,\n",
303 | " test_repo_paths,\n",
304 | " desc=\"Loading test projects\",\n",
305 | ")\n",
306 | "assert len(test_projects) > 0\n",
307 | "\n",
308 | "common_names = ModelWrapper.load_common_type_names(get_model_dir() / model_name)\n",
309 | "pred_map, label_map = sigmap_from_file_predictions(pre_r, test_projects, repos_dir)\n",
310 | "accs = {\n",
311 | " m.name: SignatureErrorAnalysis(pred_map, label_map, m).accuracies\n",
312 | " for m in AccuracyMetric.default_metrics(common_names)\n",
313 | "}\n",
314 | "\n",
315 | "from typet5.experiments.typet5 import accs_as_table_row\n",
316 | "accs_as_table_row(accs)\n",
317 | "pretty_print_dict(accs)"
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "execution_count": 12,
323 | "metadata": {},
324 | "outputs": [
325 | {
326 | "name": "stderr",
327 | "output_type": "stream",
328 | "text": [
329 | "Exporting: 100%|██████████| 1851/1851 [00:18<00:00, 100.04it/s]\n",
330 | "Computing accuracies: 100%|██████████| 1851/1851 [00:00<00:00, 9748.66it/s]\n"
331 | ]
332 | }
333 | ],
334 | "source": [
335 | "from typet5.utils import decode_tokens, Path\n",
336 | "from typet5.visualization import export_preds_on_code\n",
337 | "\n",
338 | "export_to = Path(f\"caches/model_predictions/eval_file_model/{dataset_name}\")\n",
339 | "export_preds_on_code(pre_r.chunks, pre_r.predictions, export_to, AccuracyMetric(common_names))"
340 | ]
341 | }
342 | ],
343 | "metadata": {
344 | "kernelspec": {
345 | "display_name": "Python 3.10.4 ('.venv': pipenv)",
346 | "language": "python",
347 | "name": "python3"
348 | },
349 | "language_info": {
350 | "codemirror_mode": {
351 | "name": "ipython",
352 | "version": 3
353 | },
354 | "file_extension": ".py",
355 | "mimetype": "text/x-python",
356 | "name": "python",
357 | "nbconvert_exporter": "python",
358 | "pygments_lexer": "ipython3",
359 | "version": "3.10.4"
360 | },
361 | "orig_nbformat": 4,
362 | "vscode": {
363 | "interpreter": {
364 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8"
365 | }
366 | }
367 | },
368 | "nbformat": 4,
369 | "nbformat_minor": 2
370 | }
371 |
--------------------------------------------------------------------------------
/scripts/experiments/eval_func_model.py:
--------------------------------------------------------------------------------
1 | # evaluate the model trained on functional dataset using
2 | # incremental decoding.
3 |
4 | # %%
5 | import asyncio
6 | import copy
7 | import os
8 | from typing import *
9 |
10 | import torch
11 | import wandb
12 | from termcolor import colored
13 |
14 | from typet5.experiments.typet5 import TypeT5Configs
15 | from typet5.function_dataset import data_project_from_dir
16 | from typet5.model import ModelWrapper
17 | from typet5.train import PreprocessArgs, TrainingConfig
18 | from typet5.type_env import AccuracyMetric
19 | from typet5.utils import (
20 | assert_eq,
21 | get_dataset_dir,
22 | get_eval_dir,
23 | get_gpu_id,
24 | get_model_dir,
25 | get_modified_args,
26 | pickle_dump,
27 | pickle_load,
28 | pmap,
29 | pretty_show_dict,
30 | proj_root,
31 | run_long_task,
32 | write_file,
33 | )
34 | from typet5.visualization import string_to_html
35 |
36 | os.chdir(proj_root())
37 |
38 |
def wandb_string(s: str):
    """Wrap the string *s* as an HTML object that can be logged to wandb."""
    html = string_to_html(s)
    return wandb.Html(html)
41 |
42 |
# %%

# experiment configurations

# When True, load previously pickled EvalResults from disk instead of
# re-running the (expensive) decoding evaluation.
load_results = True
# Whether the rollout may use oracle (ground-truth) types for context;
# also reflected in the experiment name below.
use_oracle = True
gpu_id = get_gpu_id(1)
train_config = TypeT5Configs.Default

# The model directory name is derived from the training configuration.
model_name = train_config.get_model_name()
# model_name = (
#     "model-v7--TrainingConfig(drop_env_types=False, add_implicit_rel_imports=True)"
# )
dataset_name = "ManyTypes4Py"
# dataset_name = "InferTypes4Py"
# dataset_name = "TinyEval"

# Preprocessing arguments used at test time (same as during training).
test_pre_args = train_config.pre_args
# Tags are prepended to the model name to distinguish experiment variants.
oracle_tag = "(use-oracle) " if use_oracle else ""
# group_tag = "(implicit_imports, new) "
# group_tag = "(ablation) "
group_tag = ""
experiment_name = oracle_tag + group_tag + model_name

print(colored(f"Use GPU: {gpu_id}", "green"))
68 |
# %%

# Load the trained model and move it to the selected device
# (falls back to CPU when CUDA is unavailable).
model = ModelWrapper.load(get_model_dir() / model_name)
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
model.to(device)
# Fixed F541: the original used an f-string with no placeholders here.
print("Model loaded:", model_name)

# Load the test projects from the dataset's `repos/test` directory.
# Parsing is only needed when we are going to (re-)run the evaluation;
# when `load_results` is True, cached results are read from disk instead.
repos_dir = get_dataset_dir(dataset_name) / "repos" / "test"
test_repo_paths = [f for f in repos_dir.iterdir() if f.is_dir()]
if not load_results:
    test_projects = pmap(
        data_project_from_dir,
        test_repo_paths,
        desc="Loading test projects",
    )
    # Fail fast with an explicit error instead of `assert`, which is
    # silently stripped when Python runs with the -O flag.
    if not test_projects:
        raise RuntimeError(f"No test projects found under: {repos_dir}")
87 |
# %%

from typet5.experiments.typet5 import accs_as_table_row
from typet5.function_decoding import DecodingOrders, EvalResult, RolloutCtx

# Decoding hyperparameters: cap labels per context chunk and decode with
# beam search (sampling disabled, 16 beams, 16 tokens per predicted type).
ctx_args = model.args.ctx_args
ctx_args.max_labels = 16
model.args.sampling_max_tokens = ctx_args.ctx_size
model.args.do_sample = False
model.args.num_beams = 16
model.args.tokens_per_type = 16

rctx = RolloutCtx(model=model)

# Decoding orders to evaluate. Commented-out entries are alternative
# strategies that can be re-enabled for comparison runs; the special
# "no-neighbors" key additionally disables caller/callee context below.
decode_orders = {
    # "double-traversal": DecodingOrders.DoubleTraversal(),
    # "reverse-double-traversal": DecodingOrders.Reversed(
    #     DecodingOrders.DoubleTraversal()
    # ),
    # "non-incr": DecodingOrders.IndependentOrder(),
    # "random": DecodingOrders.RandomOrder(),
    # "no-neighbors": DecodingOrders.IndependentOrder(),
    "callee2caller": DecodingOrders.Callee2Caller(),
    # "caller2callee": DecodingOrders.Caller2Callee(),
    # "random-twice": DecodingOrders.RandomTwice(),
}

metrics = AccuracyMetric.default_metrics(model.common_type_names)
with run_long_task("Evaluating different decoding strategy", notify=not load_results):
    results_dir = get_eval_dir(dataset_name, experiment_name)
    results_dir.mkdir(exist_ok=True, parents=True)
    print(colored(f"Results will be saved to: {str(results_dir)}", "green"))

    # Only start a wandb run when actually (re-)running the evaluation.
    if not load_results:
        wandb.init(
            project="SPOT-eval",
            name=dataset_name + ": " + experiment_name,
            dir=str(results_dir),
            config=get_modified_args(model.args),
        )

    # Map from decoding-order name to its evaluation result; also consumed
    # by the comparison-table code further below.
    evals = dict[str, EvalResult]()
    for oname, order in decode_orders.items():
        # Per-order pickle cache of the evaluation result.
        result_path = results_dir / f"{oname}-EvalResult.pkl"
        if not load_results:
            print(f"Evaluating decoding strategy: {oname}")
            # Deep-copy so the ablation below does not leak into other orders.
            pre_args = copy.deepcopy(test_pre_args)
            if oname == "no-neighbors":
                # Ablation: strip caller/callee context from preprocessing.
                pre_args.max_callers = 0
                pre_args.max_callees = 0
            # NOTE: `test_projects` is only defined when load_results is
            # False (see the loading cell above), hence the type: ignore.
            evalr = asyncio.run(
                rctx.evaluate_on_projects(
                    test_projects,  # type: ignore
                    pre_args,
                    order,
                    use_oracle=use_oracle,
                )
            )
            pickle_dump(result_path, evalr)
        else:
            # Load mode: skip orders whose cached result is missing.
            if not result_path.exists():
                print(f"Result file not found, skip: {result_path}")
                continue
            evalr = pickle_load(result_path)
        evals[oname] = evalr
        # Compute per-metric accuracies, save them to disk, and (when
        # running live) mirror them to wandb as HTML.
        accs = {m.name: evalr.error_analysis(None, m).accuracies for m in metrics}
        accs_str = pretty_show_dict(accs)
        write_file(results_dir / f"{oname}-accuracy.txt", accs_str)
        if not load_results:
            wandb.log({f"test/{oname}": wandb_string(accs_str)})
        print(f"========== {oname} ===========")
        print(accs_str)
        accs_as_table_row(accs)
161 |
# %%
if False:
    # Optionally dump every individual prediction for manual inspection.
    for oname, evalr in evals.items():
        print(f"========== {oname} ===========")
        evalr.print_predictions()

import prettytable as pt

# %%
from prettytable import PrettyTable

# Build a summary table: one row per decoding order, one column per metric,
# then print it and save it alongside the other result files.
common_type_names = ModelWrapper.load_common_type_names(get_model_dir() / model_name)
results_table = PrettyTable()
results_table.field_names = ["order", *(m.name for m in metrics)]
results_table.align = "r"
results_table.set_style(pt.SINGLE_BORDER)
results_table.float_format = ".4"

for order_name, eval_result in evals.items():
    row: list = [order_name]
    for metric in metrics:
        analysis = eval_result.error_analysis(None, metric)
        row.append(analysis.accuracies[metric.name].acc)
    results_table.add_row(row)

print(results_table)
write_file(results_dir / "comparison.txt", results_table.get_string())
190 |
--------------------------------------------------------------------------------
/scripts/experiments/run_hityper.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "from typet5.utils import *\n",
13 | "os.chdir(proj_root())\n"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 15,
19 | "metadata": {},
20 | "outputs": [
21 | {
22 | "name": "stderr",
23 | "output_type": "stream",
24 | "text": [
25 | "Removing newer syntax: 100%|██████████| 1594/1594 [00:01<00:00, 873.90it/s] \n"
26 | ]
27 | },
28 | {
29 | "name": "stdout",
30 | "output_type": "stream",
31 | "text": [
32 | "1594 / 1594 files have been rewritten.\n"
33 | ]
34 | }
35 | ],
36 | "source": [
37 | "from typet5.experiments.utils import remove_newer_syntax_for_repo, Path\n",
38 | "from typet5.experiments.type4py import Type4PySupportedSyntax\n",
39 | "from typet5.utils import get_dataset_dir\n",
40 | "import shutil\n",
41 | "\n",
42 | "# dataset_name = \"InferTypes4Py\"\n",
43 | "dataset_name = \"ManyTypes4Py\"\n",
44 | "repos_dir = get_dataset_dir(dataset_name) / \"repos\"\n",
45 | "shutil.rmtree(repos_dir / \"test-hityper\", ignore_errors=True)\n",
46 | "shutil.copytree(repos_dir / \"test\", repos_dir / \"test-hityper\")\n",
47 | "remove_newer_syntax_for_repo(repos_dir / \"test-hityper\", Type4PySupportedSyntax)"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 16,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "from typet5.experiments.hityper import run_hityper, eval_hityper_on_repos\n",
57 | "\n",
58 | "test_repos = [\n",
59 | " p\n",
60 | " for p in (repos_dir / \"test-hityper\").iterdir()\n",
61 | " if p.is_dir()\n",
62 | "]\n",
63 | "\n",
64 | "hityper_path = Path(\"/home/jiayi/anaconda3/envs/hityper/bin/python\")\n",
65 | "out_dir = Path(\"output/hityper\")\n",
66 | "cache = PickleCache(Path(f\"caches/run_hityper\"))\n",
67 | "\n",
68 | "eval_r = cache.cached(f\"{dataset_name}.pkl\", lambda: eval_hityper_on_repos(test_repos, hityper_path, out_dir, max_workers=4))"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": 17,
74 | "metadata": {},
75 | "outputs": [
76 | {
77 | "name": "stdout",
78 | "output_type": "stream",
79 | "text": [
80 | "Accuracies on all types:\n",
81 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n",
82 | "42.53 & 41.95 & 44.86 & 19.02 & 47.88\n",
83 | "Accuracies on common types:\n",
84 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n",
85 | "59.20 & 54.28 & 57.70 & 26.44 & 59.01\n",
86 | "Accuracies on rare types:\n",
87 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n",
88 | "10.30 & 25.51 & 27.59 & 9.79 & 29.33\n",
89 | "----------------------------------------\n",
90 | "full_acc:\n",
91 | " full_acc: 42.53% (count=10.5k)\n",
92 | " full_acc_by_cat:\n",
93 | " FuncArg: 26.80% (count=5.9k)\n",
94 | " FuncReturn: 63.41% (count=4.5k)\n",
95 | " GlobalVar: 18.84% (count=69)\n",
96 | " full_acc_by_simple:\n",
97 | " complex: 14.59% (count=1.7k)\n",
98 | " simple: 47.85% (count=8.8k)\n",
99 | " full_acc_label_size: 1.3993\n",
100 | " full_acc_pred_size: 1.3325\n",
101 | " full_acc_ignored_labels: 0\n",
102 | " n_skipped_types: 3293\n",
103 | " n_missing_types: 1922\n",
104 | "full_acc_common:\n",
105 | " full_acc_common: 59.20% (count=6.9k)\n",
106 | " full_acc_common_by_cat:\n",
107 | " FuncArg: 58.18% (count=4.0k)\n",
108 | " FuncReturn: 60.82% (count=2.9k)\n",
109 | " GlobalVar: 50.00% (count=52)\n",
110 | " full_acc_common_by_simple:\n",
111 | " complex: 20.43% (count=940)\n",
112 | " simple: 65.27% (count=6.0k)\n",
113 | " full_acc_common_label_size: 1.3505\n",
114 | " full_acc_common_pred_size: 1.2762\n",
115 | " full_acc_common_ignored_labels: 3592\n",
116 | " n_skipped_types: 3293\n",
117 | " n_missing_types: 1922\n",
118 | "full_acc_rare:\n",
119 | " full_acc_rare: 10.30% (count=3.6k)\n",
120 | " full_acc_rare_by_cat:\n",
121 | " FuncArg: 10.04% (count=2.2k)\n",
122 | " FuncReturn: 10.80% (count=1.4k)\n",
123 | " GlobalVar: 0.00% (count=9)\n",
124 | " full_acc_rare_by_simple:\n",
125 | " complex: 7.24% (count=746)\n",
126 | " simple: 11.10% (count=2.8k)\n",
127 | " full_acc_rare_label_size: 1.4936\n",
128 | " full_acc_rare_pred_size: 1.4413\n",
129 | " full_acc_rare_ignored_labels: 6944\n",
130 | " n_skipped_types: 3293\n",
131 | " n_missing_types: 1922\n",
132 | "acc:\n",
133 | " acc: 41.95% (count=8.4k)\n",
134 | " acc_by_cat:\n",
135 | " FuncArg: 41.69% (count=4.8k)\n",
136 | " FuncReturn: 42.24% (count=3.5k)\n",
137 | " GlobalVar: 45.61% (count=57)\n",
138 | " acc_by_simple:\n",
139 | " complex: 19.02% (count=941)\n",
140 | " simple: 44.86% (count=7.4k)\n",
141 | " acc_label_size: 1.3238\n",
142 | " acc_pred_size: 1.2303\n",
143 | " acc_ignored_labels: 2180\n",
144 | " n_skipped_types: 3293\n",
145 | " n_missing_types: 1922\n",
146 | "acc_common:\n",
147 | " acc_common: 54.28% (count=4.8k)\n",
148 | " acc_common_by_cat:\n",
149 | " FuncArg: 53.86% (count=2.9k)\n",
150 | " FuncReturn: 54.78% (count=1.8k)\n",
151 | " GlobalVar: 68.00% (count=25)\n",
152 | " acc_common_by_simple:\n",
153 | " complex: 26.44% (count=522)\n",
154 | " simple: 57.70% (count=4.3k)\n",
155 | " acc_common_label_size: 1.3241\n",
156 | " acc_common_pred_size: 1.2162\n",
157 | " acc_common_ignored_labels: 5763\n",
158 | " n_skipped_types: 3293\n",
159 | " n_missing_types: 1922\n",
160 | "acc_rare:\n",
161 | " acc_rare: 25.51% (count=3.6k)\n",
162 | " acc_rare_by_cat:\n",
163 | " FuncArg: 23.96% (count=2.2k)\n",
164 | " FuncReturn: 28.09% (count=1.3k)\n",
165 | " GlobalVar: 22.22% (count=9)\n",
166 | " acc_rare_by_simple:\n",
167 | " complex: 9.79% (count=419)\n",
168 | " simple: 27.59% (count=3.2k)\n",
169 | " acc_rare_label_size: 1.3235\n",
170 | " acc_rare_pred_size: 1.249\n",
171 | " acc_rare_ignored_labels: 6953\n",
172 | " n_skipped_types: 3293\n",
173 | " n_missing_types: 1922\n",
174 | "base_acc:\n",
175 | " base_acc: 47.88% (count=8.4k)\n",
176 | " base_acc_by_cat:\n",
177 | " FuncArg: 47.48% (count=4.8k)\n",
178 | " FuncReturn: 48.34% (count=3.5k)\n",
179 | " GlobalVar: 54.39% (count=57)\n",
180 | " base_acc_ignored_labels: 2180\n",
181 | " n_skipped_types: 3293\n",
182 | " n_missing_types: 1922\n",
183 | "base_acc_common:\n",
184 | " base_acc_common: 59.01% (count=5.2k)\n",
185 | " base_acc_common_by_cat:\n",
186 | " FuncArg: 58.99% (count=3.1k)\n",
187 | " FuncReturn: 59.19% (count=2.1k)\n",
188 | " GlobalVar: 48.28% (count=29)\n",
189 | " base_acc_common_ignored_labels: 5313\n",
190 | " n_skipped_types: 3293\n",
191 | " n_missing_types: 1922\n",
192 | "base_acc_rare:\n",
193 | " base_acc_rare: 29.33% (count=3.1k)\n",
194 | " base_acc_rare_by_cat:\n",
195 | " FuncArg: 29.17% (count=2.0k)\n",
196 | " FuncReturn: 29.44% (count=1.2k)\n",
197 | " GlobalVar: 57.14% (count=7)\n",
198 | " base_acc_rare_ignored_labels: 7403\n",
199 | " n_skipped_types: 3293\n",
200 | " n_missing_types: 1922\n"
201 | ]
202 | }
203 | ],
204 | "source": [
205 | "from typet5.static_analysis import SignatureErrorAnalysis, AccuracyMetric\n",
206 | "from typet5.experiments.typet5 import accs_as_table_row, ModelWrapper\n",
207 | "\n",
208 | "\n",
209 | "common_names = ModelWrapper.load_common_type_names(\n",
210 | " get_model_dir() / \"model-v7--TrainingConfig(drop_env_types=False)\"\n",
211 | ")\n",
212 | "metrics = AccuracyMetric.default_metrics(common_type_names=common_names)\n",
213 | "# acc_metric = AccuracyMetric(common_type_names=ubiq_names)\n",
214 | "\n",
215 | "accs = {\n",
216 | " m.name: SignatureErrorAnalysis(\n",
217 | " eval_r.pred_maps,\n",
218 | " eval_r.label_maps,\n",
219 | " m,\n",
220 | " error_on_mismatched_signature=False,\n",
221 | " ).accuracies\n",
222 | " for m in metrics\n",
223 | "}\n",
224 | "\n",
225 | "accs_as_table_row(accs)\n",
226 | "print(\"-\" * 40)\n",
227 | "pretty_print_dict(accs)"
228 | ]
229 | }
230 | ],
231 | "metadata": {
232 | "kernelspec": {
233 | "display_name": "Python 3.10.4 ('.venv': pipenv)",
234 | "language": "python",
235 | "name": "python3"
236 | },
237 | "language_info": {
238 | "codemirror_mode": {
239 | "name": "ipython",
240 | "version": 3
241 | },
242 | "file_extension": ".py",
243 | "mimetype": "text/x-python",
244 | "name": "python",
245 | "nbconvert_exporter": "python",
246 | "pygments_lexer": "ipython3",
247 | "version": "3.10.4"
248 | },
249 | "orig_nbformat": 4,
250 | "vscode": {
251 | "interpreter": {
252 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8"
253 | }
254 | }
255 | },
256 | "nbformat": 4,
257 | "nbformat_minor": 2
258 | }
259 |
--------------------------------------------------------------------------------
/scripts/experiments/run_typilus.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "from typet5.utils import proj_root, os\n",
13 | "os.chdir(proj_root())"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 2,
19 | "metadata": {},
20 | "outputs": [
21 | {
22 | "name": "stderr",
23 | "output_type": "stream",
24 | "text": [
25 | "Removing newer syntax: 100%|██████████| 1594/1594 [00:01<00:00, 799.31it/s]"
26 | ]
27 | },
28 | {
29 | "name": "stdout",
30 | "output_type": "stream",
31 | "text": [
32 | "1594 / 1594 files have been rewritten.\n"
33 | ]
34 | },
35 | {
36 | "name": "stderr",
37 | "output_type": "stream",
38 | "text": [
39 | "\n"
40 | ]
41 | }
42 | ],
43 | "source": [
44 | "from typet5.experiments.utils import remove_newer_syntax_for_repo, Path\n",
45 | "from typet5.experiments.typilus import eval_typilus_on_repos, TypilusSupportedSyntax\n",
46 | "from typet5.utils import get_dataset_dir\n",
47 | "import shutil\n",
48 | "\n",
49 | "# dataset_name = \"InferTypes4Py\"\n",
50 | "dataset_name = \"ManyTypes4Py\"\n",
51 | "repos_dir = get_dataset_dir(dataset_name) / \"repos\"\n",
52 | "shutil.rmtree(repos_dir / \"test-typilus\", ignore_errors=True)\n",
53 | "shutil.copytree(repos_dir / \"test\", repos_dir / \"test-typilus\")\n",
54 | "remove_newer_syntax_for_repo(repos_dir / \"test-typilus\", TypilusSupportedSyntax)\n"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 5,
60 | "metadata": {},
61 | "outputs": [
62 | {
63 | "name": "stderr",
64 | "output_type": "stream",
65 | "text": [
66 | "Running Typilus: 100%|██████████| 50/50 [00:46<00:00, 1.09it/s]\n",
67 | "Collecting labels: 100%|██████████| 50/50 [00:15<00:00, 3.17it/s]\n",
68 | "WARNING:root:Missing 27 predictions for module: test_fakesmtpd.syntax\n",
69 | "WARNING:root:Missing 1 predictions for module: archive\n",
70 | "WARNING:root:Missing 1 predictions for module: tests.test_utils\n",
71 | "WARNING:root:Missing 1 predictions for module: tests.test_init_import\n",
72 | "WARNING:root:Missing 1 predictions for module: tests.test_axion_plugins\n",
73 | "WARNING:root:Missing 1 predictions for module: typesafety.conftest\n",
74 | "WARNING:root:Missing 1 predictions for module: tests.test_base\n"
75 | ]
76 | },
77 | {
78 | "name": "stdout",
79 | "output_type": "stream",
80 | "text": [
81 | "Accuracies on all types:\n",
82 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n",
83 | "47.10 & 54.05 & 55.12 & 33.23 & 60.37\n",
84 | "Accuracies on common types:\n",
85 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n",
86 | "47.10 & 54.05 & 55.12 & 33.23 & 60.37\n",
87 | "Accuracies on rare types:\n",
88 | "header: ['full.all', 'calibrated.all', 'calibrated.simple', 'calibrated.complex', 'base.all']\n",
89 | "nan & nan & N/A & N/A & nan\n",
90 | "full_acc:\n",
91 | " full_acc: 47.10% (count=6.9k)\n",
92 | " full_acc_by_cat:\n",
93 | " FuncArg: 47.86% (count=4.4k)\n",
94 | " FuncReturn: 46.79% (count=1.3k)\n",
95 | " ClassAtribute: 45.40% (count=1.1k)\n",
96 | " GlobalVar: 20.45% (count=44)\n",
97 | " full_acc_by_simple:\n",
98 | " complex: 20.42% (count=529)\n",
99 | " simple: 49.31% (count=6.4k)\n",
100 | " full_acc_label_size: 1.4992\n",
101 | " full_acc_pred_size: 1.1339\n",
102 | " full_acc_ignored_labels: 0\n",
103 | " n_missing: 3958\n",
104 | " n_skipped_rare: 4892\n",
105 | "full_acc_common:\n",
106 | " full_acc_common: 47.10% (count=6.9k)\n",
107 | " full_acc_common_by_cat:\n",
108 | " FuncArg: 47.86% (count=4.4k)\n",
109 | " FuncReturn: 46.79% (count=1.3k)\n",
110 | " ClassAtribute: 45.40% (count=1.1k)\n",
111 | " GlobalVar: 20.45% (count=44)\n",
112 | " full_acc_common_by_simple:\n",
113 | " complex: 20.42% (count=529)\n",
114 | " simple: 49.31% (count=6.4k)\n",
115 | " full_acc_common_label_size: 1.4992\n",
116 | " full_acc_common_pred_size: 1.1339\n",
117 | " full_acc_common_ignored_labels: 0\n",
118 | " n_missing: 3958\n",
119 | " n_skipped_rare: 4892\n",
120 | "full_acc_rare:\n",
121 | " full_acc_rare: nan% (count=0)\n",
122 | " full_acc_rare_by_cat:\n",
123 | " full_acc_rare_by_simple:\n",
124 | " full_acc_rare_label_size: nan\n",
125 | " full_acc_rare_pred_size: nan\n",
126 | " full_acc_rare_ignored_labels: 6903\n",
127 | " n_missing: 3958\n",
128 | " n_skipped_rare: 4892\n",
129 | "acc:\n",
130 | " acc: 54.05% (count=6.9k)\n",
131 | " acc_by_cat:\n",
132 | " FuncArg: 55.27% (count=4.4k)\n",
133 | " FuncReturn: 50.30% (count=1.3k)\n",
134 | " ClassAtribute: 54.41% (count=1.1k)\n",
135 | " GlobalVar: 36.36% (count=44)\n",
136 | " acc_by_simple:\n",
137 | " complex: 33.23% (count=337)\n",
138 | " simple: 55.12% (count=6.6k)\n",
139 | " acc_label_size: 1.2988\n",
140 | " acc_pred_size: 1.0755\n",
141 | " acc_ignored_labels: 15\n",
142 | " n_missing: 3958\n",
143 | " n_skipped_rare: 4892\n",
144 | "acc_common:\n",
145 | " acc_common: 54.05% (count=6.9k)\n",
146 | " acc_common_by_cat:\n",
147 | " FuncArg: 55.27% (count=4.4k)\n",
148 | " FuncReturn: 50.30% (count=1.3k)\n",
149 | " ClassAtribute: 54.41% (count=1.1k)\n",
150 | " GlobalVar: 36.36% (count=44)\n",
151 | " acc_common_by_simple:\n",
152 | " complex: 33.23% (count=337)\n",
153 | " simple: 55.12% (count=6.6k)\n",
154 | " acc_common_label_size: 1.2988\n",
155 | " acc_common_pred_size: 1.0755\n",
156 | " acc_common_ignored_labels: 15\n",
157 | " n_missing: 3958\n",
158 | " n_skipped_rare: 4892\n",
159 | "acc_rare:\n",
160 | " acc_rare: nan% (count=0)\n",
161 | " acc_rare_by_cat:\n",
162 | " acc_rare_by_simple:\n",
163 | " acc_rare_label_size: nan\n",
164 | " acc_rare_pred_size: nan\n",
165 | " acc_rare_ignored_labels: 6903\n",
166 | " n_missing: 3958\n",
167 | " n_skipped_rare: 4892\n",
168 | "base_acc:\n",
169 | " base_acc: 60.37% (count=6.9k)\n",
170 | " base_acc_by_cat:\n",
171 | " FuncArg: 60.62% (count=4.4k)\n",
172 | " FuncReturn: 61.90% (count=1.3k)\n",
173 | " ClassAtribute: 57.94% (count=1.1k)\n",
174 | " GlobalVar: 47.73% (count=44)\n",
175 | " base_acc_ignored_labels: 15\n",
176 | " n_missing: 3958\n",
177 | " n_skipped_rare: 4892\n",
178 | "base_acc_common:\n",
179 | " base_acc_common: 60.37% (count=6.9k)\n",
180 | " base_acc_common_by_cat:\n",
181 | " FuncArg: 60.62% (count=4.4k)\n",
182 | " FuncReturn: 61.90% (count=1.3k)\n",
183 | " ClassAtribute: 57.94% (count=1.1k)\n",
184 | " GlobalVar: 47.73% (count=44)\n",
185 | " base_acc_common_ignored_labels: 15\n",
186 | " n_missing: 3958\n",
187 | " n_skipped_rare: 4892\n",
188 | "base_acc_rare:\n",
189 | " base_acc_rare: nan% (count=0)\n",
190 | " base_acc_rare_by_cat:\n",
191 | " base_acc_rare_ignored_labels: 6903\n",
192 | " n_missing: 3958\n",
193 | " n_skipped_rare: 4892\n"
194 | ]
195 | }
196 | ],
197 | "source": [
198 | "from typet5.model import ModelWrapper\n",
199 | "from typet5.static_analysis import AccuracyMetric\n",
200 | "from typet5.utils import *\n",
201 | "from typet5.experiments.typet5 import accs_as_table_row\n",
202 | "\n",
203 | "test_repos = [p for p in (repos_dir / \"test-typilus\").iterdir() if p.is_dir()]\n",
204 | "\n",
205 | "common_names = ModelWrapper.load_common_type_names(\n",
206 | " get_model_dir() / \"model-v7--TrainingConfig(drop_env_types=False)\"\n",
207 | ")\n",
208 | "metrics = AccuracyMetric.default_metrics(common_names)\n",
209 | "typilus_path = Path(\"~/Projects/typilus-action/\").expanduser()\n",
210 | "work_dir = Path(\"~/Projects/typilus-action/data_out\").expanduser()\n",
211 | "\n",
212 | "cache = PickleCache(Path(f\"caches/run_typilus\"))\n",
213 | "cache.remove(f\"{dataset_name}.pkl\")\n",
214 | "accs = cache.cached(\n",
215 | " f\"{dataset_name}.pkl\",\n",
216 | " lambda: eval_typilus_on_repos(\n",
217 | " test_repos, metrics, typilus_path, work_dir, max_workers=4\n",
218 | " ),\n",
219 | ")\n",
220 | "\n",
221 | "accs_as_table_row(accs)\n",
222 | "pretty_print_dict(accs)\n"
223 | ]
224 | }
225 | ],
226 | "metadata": {
227 | "kernelspec": {
228 | "display_name": "Python 3.10.4 ('.venv': pipenv)",
229 | "language": "python",
230 | "name": "python3"
231 | },
232 | "language_info": {
233 | "codemirror_mode": {
234 | "name": "ipython",
235 | "version": 3
236 | },
237 | "file_extension": ".py",
238 | "mimetype": "text/x-python",
239 | "name": "python",
240 | "nbconvert_exporter": "python",
241 | "pygments_lexer": "ipython3",
242 | "version": "3.10.4"
243 | },
244 | "orig_nbformat": 4,
245 | "vscode": {
246 | "interpreter": {
247 | "hash": "f6ffc72953da4dd16b2e00785be9c4013ef131f465a8658f3921b6634d4eeec8"
248 | }
249 | }
250 | },
251 | "nbformat": 4,
252 | "nbformat_minor": 2
253 | }
254 |
--------------------------------------------------------------------------------
/scripts/run_typet5.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "import os\n",
13 | "from typing import *\n",
14 | "\n",
15 | "import torch\n",
16 | "\n",
17 | "from typet5.model import ModelWrapper\n",
18 | "from typet5.train import PreprocessArgs\n",
19 | "from typet5.utils import *\n",
20 | "from typet5.function_decoding import (\n",
21 | " RolloutCtx,\n",
22 | " PreprocessArgs,\n",
23 | " DecodingOrders,\n",
24 | " AccuracyMetric,\n",
25 | ")\n",
26 | "from typet5.static_analysis import PythonProject\n",
27 | "\n",
28 | "os.chdir(proj_root())"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 2,
34 | "metadata": {},
35 | "outputs": [
36 | {
37 | "data": {
38 | "application/vnd.jupyter.widget-view+json": {
39 | "model_id": "3852e189479e408a801e24d3e050cff7",
40 | "version_major": 2,
41 | "version_minor": 0
42 | },
43 | "text/plain": [
44 | "Fetching 9 files: 0%| | 0/9 [00:00, ?it/s]"
45 | ]
46 | },
47 | "metadata": {},
48 | "output_type": "display_data"
49 | },
50 | {
51 | "name": "stdout",
52 | "output_type": "stream",
53 | "text": [
54 | "model loaded\n"
55 | ]
56 | }
57 | ],
58 | "source": [
59 | "# download or load the model\n",
60 | "wrapper = ModelWrapper.load_from_hub(\"MrVPlusOne/TypeT5-v7\")\n",
61 | "device = torch.device(f\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
62 | "wrapper.to(device)\n",
63 | "print(\"model loaded\")\n"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 3,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "# set up the rollout parameters\n",
73 | "rctx = RolloutCtx(model=wrapper)\n",
74 | "pre_args = PreprocessArgs()\n",
75 | "# we use the double-traversal decoding order, where the model can make corrections \n",
76 | "# to its previous predictions in the second pass\n",
77 | "decode_order = DecodingOrders.DoubleTraversal()"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 4,
83 | "metadata": {},
84 | "outputs": [
85 | {
86 | "name": "stdout",
87 | "output_type": "stream",
88 | "text": [
89 | "ex_code_1/good: int\n",
90 | "ex_code_1/fib: (n: int) -> int\n",
91 | "ex_code_1/Wrapper.foo: (bar: int) -> int\n",
92 | "ex_code_1/Wrapper.inc: () -> str\n",
93 | "ex_code_1/int_add: (a: int, b: int) -> str\n",
94 | "ex_code_1/int_tripple_add: (a: int, b: int, c: int) -> int\n",
95 | "ex_code_2/fib: (n: int) -> int\n",
96 | "ex_code_2/foo: (bar: int) -> int\n",
97 | "ex_code_2/Bar.x: int\n",
98 | "ex_code_2/Bar.y: int\n",
99 | "ex_code_2/Bar.reset: (w0: str) -> None\n",
100 | "ex_code_2/Bar.__init__: (x: int) -> None\n",
101 | "(updated) ex_code_1/int_add: (a: int, b: int) -> int\n"
102 | ]
103 | }
104 | ],
105 | "source": [
106 | "# Use case 1: Run TypeT5 on a given project, taking advantage of existing user \n",
107 | "# annotations and only make predictions for missing types.\n",
108 | "\n",
109 | "project = PythonProject.parse_from_root(proj_root() / \"data/ex_repo\")\n",
110 | "rollout = await rctx.run_on_project(project, pre_args, decode_order)"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": 5,
116 | "metadata": {},
117 | "outputs": [
118 | {
119 | "name": "stderr",
120 | "output_type": "stream",
121 | "text": [
122 | "evaluate_on_projects: 100%|██████████| 35/35 [00:04<00:00, 7.20it/s]"
123 | ]
124 | },
125 | {
126 | "name": "stdout",
127 | "output_type": "stream",
128 | "text": [
129 | "==================== /home/jiayi/Projects/TypeT5/data/ex_repo ====================\n",
130 | "\tex_code_1/fib: (n: int) -> int\n",
131 | "\tex_code_1/Wrapper.foo: (bar: int) -> int\n",
132 | "\tex_code_1/Wrapper.inc: () -> int\n",
133 | "\tex_code_1/int_add: (a: int, b: int) -> str\n",
134 | "\tex_code_1/int_tripple_add: (a: int, b: int, c: int) -> int\n",
135 | "\tex_code_2/fib: (n: int) -> int\n",
136 | "\tex_code_2/foo: (bar: int) -> int\n",
137 | "\tex_code_2/Bar.__init__: (x: int) -> None\n",
138 | "\tex_code_2/Bar.reset: (w0: int_add) -> None\n",
139 | "\tex_code_2/Bar.foo: (z: str) -> str\n",
140 | "\tex_code_1/good: int\n",
141 | "\tex_code_1/Wrapper.x_elem: int\n",
142 | "\tex_code_1/Wrapper.y: int\n",
143 | "\tex_code_2/Bar.z: str\n",
144 | "\tex_code_2/Bar.w: int_add\n",
145 | "\tex_code_2/Bar.x: int\n",
146 | "\tex_code_2/Bar.y: int\n",
147 | "\tex_code_2/bar: Bar\n"
148 | ]
149 | },
150 | {
151 | "name": "stderr",
152 | "output_type": "stream",
153 | "text": [
154 | "\n"
155 | ]
156 | }
157 | ],
158 | "source": [
159 | "# Use case 2: Run TypeT5 on a test project where all user annotations will be treated as\n",
160 | "# labels and removed before running the model.\n",
161 | "\n",
162 | "eval_r = await rctx.evaluate_on_projects([project], pre_args, decode_order)\n",
163 | "eval_r.print_predictions()"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 6,
169 | "metadata": {},
170 | "outputs": [
171 | {
172 | "name": "stdout",
173 | "output_type": "stream",
174 | "text": [
175 | "full_acc:\n",
176 | " full_acc: 70.00% (count=10)\n",
177 | " full_acc_by_cat:\n",
178 | " FuncArg: 100.00% (count=4)\n",
179 | " FuncReturn: 0.00% (count=1)\n",
180 | " ClassAtribute: 50.00% (count=4)\n",
181 | " GlobalVar: 100.00% (count=1)\n",
182 | " full_acc_by_simple:\n",
183 | " simple: 70.00% (count=10)\n",
184 | " full_acc_label_size: 1\n",
185 | " full_acc_pred_size: 1\n",
186 | " full_acc_ignored_labels: 0\n",
187 | "full_acc_common:\n",
188 | " full_acc_common: 66.67% (count=9)\n",
189 | " full_acc_common_by_cat:\n",
190 | " FuncArg: 100.00% (count=4)\n",
191 | " FuncReturn: 0.00% (count=1)\n",
192 | " ClassAtribute: 50.00% (count=4)\n",
193 | " full_acc_common_by_simple:\n",
194 | " simple: 66.67% (count=9)\n",
195 | " full_acc_common_label_size: 1\n",
196 | " full_acc_common_pred_size: 1\n",
197 | " full_acc_common_ignored_labels: 1\n",
198 | "full_acc_rare:\n",
199 | " full_acc_rare: 100.00% (count=1)\n",
200 | " full_acc_rare_by_cat:\n",
201 | " ClassAtribute: 100.00% (count=1)\n",
202 | " full_acc_rare_by_simple:\n",
203 | " simple: 100.00% (count=1)\n",
204 | " full_acc_rare_label_size: 1\n",
205 | " full_acc_rare_pred_size: 1\n",
206 | " full_acc_rare_ignored_labels: 9\n",
207 | "acc:\n",
208 | " acc: 70.00% (count=10)\n",
209 | " acc_by_cat:\n",
210 | " FuncArg: 100.00% (count=4)\n",
211 | " FuncReturn: 0.00% (count=1)\n",
212 | " ClassAtribute: 50.00% (count=4)\n",
213 | " GlobalVar: 100.00% (count=1)\n",
214 | " acc_by_simple:\n",
215 | " simple: 70.00% (count=10)\n",
216 | " acc_label_size: 1\n",
217 | " acc_pred_size: 1\n",
218 | " acc_ignored_labels: 0\n",
219 | "acc_common:\n",
220 | " acc_common: 66.67% (count=9)\n",
221 | " acc_common_by_cat:\n",
222 | " FuncArg: 100.00% (count=4)\n",
223 | " FuncReturn: 0.00% (count=1)\n",
224 | " ClassAtribute: 50.00% (count=4)\n",
225 | " acc_common_by_simple:\n",
226 | " simple: 66.67% (count=9)\n",
227 | " acc_common_label_size: 1\n",
228 | " acc_common_pred_size: 1\n",
229 | " acc_common_ignored_labels: 1\n",
230 | "acc_rare:\n",
231 | " acc_rare: 100.00% (count=1)\n",
232 | " acc_rare_by_cat:\n",
233 | " ClassAtribute: 100.00% (count=1)\n",
234 | " acc_rare_by_simple:\n",
235 | " simple: 100.00% (count=1)\n",
236 | " acc_rare_label_size: 1\n",
237 | " acc_rare_pred_size: 1\n",
238 | " acc_rare_ignored_labels: 9\n",
239 | "base_acc:\n",
240 | " base_acc: 70.00% (count=10)\n",
241 | " base_acc_by_cat:\n",
242 | " FuncArg: 100.00% (count=4)\n",
243 | " FuncReturn: 0.00% (count=1)\n",
244 | " ClassAtribute: 50.00% (count=4)\n",
245 | " GlobalVar: 100.00% (count=1)\n",
246 | " base_acc_ignored_labels: 0\n",
247 | "base_acc_common:\n",
248 | " base_acc_common: 66.67% (count=9)\n",
249 | " base_acc_common_by_cat:\n",
250 | " FuncArg: 100.00% (count=4)\n",
251 | " FuncReturn: 0.00% (count=1)\n",
252 | " ClassAtribute: 50.00% (count=4)\n",
253 | " base_acc_common_ignored_labels: 1\n",
254 | "base_acc_rare:\n",
255 | " base_acc_rare: 100.00% (count=1)\n",
256 | " base_acc_rare_by_cat:\n",
257 | " ClassAtribute: 100.00% (count=1)\n",
258 | " base_acc_rare_ignored_labels: 9\n"
259 | ]
260 | }
261 | ],
262 | "source": [
263 | "metrics = AccuracyMetric.default_metrics(wrapper.common_type_names)\n",
264 | "for metric in metrics:\n",
265 | " accs = eval_r.error_analysis(None, metric).accuracies\n",
266 | " pretty_print_dict({metric.name: accs})\n",
267 | " "
268 | ]
269 | }
270 | ],
271 | "metadata": {
272 | "kernelspec": {
273 | "display_name": ".venv",
274 | "language": "python",
275 | "name": "python3"
276 | },
277 | "language_info": {
278 | "codemirror_mode": {
279 | "name": "ipython",
280 | "version": 3
281 | },
282 | "file_extension": ".py",
283 | "mimetype": "text/x-python",
284 | "name": "python",
285 | "nbconvert_exporter": "python",
286 | "pygments_lexer": "ipython3",
287 | "version": "3.10.8"
288 | },
289 | "orig_nbformat": 4
290 | },
291 | "nbformat": 4,
292 | "nbformat_minor": 2
293 | }
294 |
--------------------------------------------------------------------------------
/scripts/train_model.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | from typing import *
4 |
5 | from termcolor import colored
6 |
7 | from typet5.data import (
8 | TypeCheckSettings,
9 | create_tokenized_srcsets,
10 | get_tk_dataset_name,
11 | load_tokenized_srcsets,
12 | )
13 | from typet5.experiments.typet5 import TypeT5Configs
14 | from typet5.model import DecodingArgs, ModelWrapper
15 | from typet5.train import TypeCheckArgs
16 | from typet5.utils import *
17 |
# All relative paths below are resolved from the repository root.
os.chdir(proj_root())

# %%
# -----------------------------------------------------------
# experiment configurations

gpu_id = get_gpu_id(0)  # which GPU to use
eval_only = False  # whether to skip training and only evaluate the model
recreate_dataset = False  # whether to recreate the tokenized dataset if found

config = TypeT5Configs.Default  # which model configuration to use
29 |
30 | # %%
31 | # -----------------------------------------------------------
32 |
33 |
34 | TypeCheckSettings.temp_path = f"GPU-{gpu_id}"
35 | print(colored(f"Use GPU: {gpu_id}", "green"))
36 |
37 | if config.quicktest:
38 | print(colored("Quicktest mode", "red"))
39 | if eval_only:
40 | print(colored("Model Evaluating Mode", "blue"))
41 |
42 | project_name = "test-SPOT" if config.quicktest else "SPOT"
43 | train_ctx_args = config.train_ctx_args()
44 | if train_ctx_args.window_size < 100:
45 | print(
46 | colored(
47 | f"[Warning] window size is very small: {train_ctx_args.window_size}", "red"
48 | )
49 | )
50 | tc_args = TypeCheckArgs(check_in_isolation=config.check_in_isolation)
51 |
52 | max_tokens_per_file = config.ctx_size
53 | dec_args = DecodingArgs(
54 | sampling_max_tokens=8 * max_tokens_per_file,
55 | ctx_args=config.dec_ctx_args(),
56 | )
57 |
58 | dataset = config.trained_on
59 | print("Model will be trained on dataset:", colored(dataset, "blue"))
60 |
# Locate the preprocessed, tokenized dataset; (re)create it when missing or
# when recreate_dataset is set.
sdata_name = get_tk_dataset_name(
    dataset, config.pre_args, config.func_only, data_reduction=config.data_reduction
)
sdata_path = get_dataroot() / "TokenizedSrcSets" / sdata_name
if recreate_dataset or not sdata_path.exists():
    create_tokenized_srcsets(
        dataset,
        sdata_path,
        func_only=config.func_only,
        pre_args=config.pre_args,
        data_reduction=config.data_reduction,
    )

tk_dataset = load_tokenized_srcsets(
    sdata_path,
    quicktest=config.quicktest,
)
print("Training set stats:")
tk_dataset["train"].print_stats()

model_name = config.get_model_name()
print(colored(f"Training model: {model_name}", "green"))
83 |
84 | import torch
85 | import wandb
86 |
87 | # %%
88 | # -----------------------------------------------------------
89 | # train the model
90 | from typet5.train import ModelTrainingArgs, TypeCheckArgs, train_spot_model
91 | from typet5.utils import run_long_task
92 |
93 | if not eval_only:
94 | train_args = ModelTrainingArgs(
95 | train_ctx_args,
96 | dec_args,
97 | train_max_tokens=max_tokens_per_file,
98 | eval_max_tokens=2 * max_tokens_per_file,
99 | max_epochs=1,
100 | tc_args=tc_args,
101 | )
102 |
103 | wandb.init(
104 | project=project_name,
105 | name=model_name,
106 | config=config.as_dict(),
107 | dir=str(get_dataroot()),
108 | )
109 |
110 | with run_long_task("Training spot model"):
111 | wrapper = train_spot_model(
112 | tk_dataset,
113 | model_name,
114 | train_args=train_args,
115 | gpus=[gpu_id],
116 | quicktest=config.quicktest,
117 | use_small_model=config.use_small_model,
118 | use_early_stop=False,
119 | )
120 | else:
121 | wrapper = ModelWrapper.load(get_model_dir() / model_name)
122 |
123 | device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
124 | wrapper.to(device)
125 |
126 |
127 | # %%
128 | # -----------------------------------------------------------
129 | # model evaluation
130 |
131 | from typet5.type_env import AccuracyMetric
132 | from typet5.utils import PickleCache
133 | from typet5.visualization import pretty_print_dict
134 |
135 | bs_args = DecodingArgs(
136 | sampling_max_tokens=max_tokens_per_file,
137 | ctx_args=config.dec_ctx_args(),
138 | do_sample=False,
139 | num_beams=16,
140 | )
141 | wrapper.args = bs_args
142 |
143 | eval_cache = PickleCache(get_eval_dir(dataset, model_name) / "eval_cache")
144 | # eval_cache.clear()
145 | eval_r = eval_cache.cached(
146 | "dataset_pred.pkl",
147 | lambda: wrapper.eval_on_dataset(tk_dataset["test"]),
148 | )
149 | common_names = wrapper.common_type_names
150 | metrics = AccuracyMetric.default_metrics(common_names)
151 | r0_accs = {m.name: eval_r.accuracies(m) for m in metrics}
152 | print("Accuracies on all user annotations:")
153 | pretty_print_dict(r0_accs)
154 |
155 |
import wandb

# %%
# -----------------------------------------------------------
# close wandb
from typet5.utils import pretty_show_dict
from typet5.visualization import string_to_html


def wandb_string(s: str):
    # Render a plain string as an HTML panel so wandb displays it verbatim.
    return wandb.Html(string_to_html(s))


if not eval_only:
    wandb.log({f"test/accuracies": wandb_string(pretty_show_dict(r0_accs))})
171 |
from typet5.function_dataset import data_project_from_dir, sigmap_from_file_predictions

# %%
# -----------------------------------------------------------
# compute accuracies on the top-level elements
from typet5.static_analysis import SignatureErrorAnalysis

repos_dir = get_dataset_dir(dataset) / "repos" / "test"
test_repo_paths = [f for f in repos_dir.iterdir() if f.is_dir()]
test_projects = pmap(
    data_project_from_dir,
    test_repo_paths,
    desc="Loading test projects",
)

# Map file-level predictions back onto whole-project signatures so accuracy
# can be scored on the top-level (public API) elements.
# (Removed a no-op `eval_r = eval_r` self-assignment here.)
pred_map, label_map = sigmap_from_file_predictions(eval_r, test_projects, repos_dir)
api_accs = {
    m.name: SignatureErrorAnalysis(pred_map, label_map, m).accuracies
    for m in AccuracyMetric.default_metrics(common_names)
}

print("Accuracies on top-level elements:")
pretty_print_dict(api_accs)
if not eval_only:
    # Plain string key: the previous f-string had no placeholders.
    wandb.log({"test/api_accuracies": wandb_string(pretty_show_dict(api_accs))})
198 |
199 | # %%
200 | # -----------------------------------------------------------
201 | # export the code with inlined predictions as HTML
202 |
203 | from typet5.visualization import export_preds_on_code, proj_root
204 |
205 | export_preds = True
206 |
207 | if export_preds:
208 | max_samples = 500
209 | sample_every = max(1, len(eval_r.chunks) // max_samples)
210 | sub_ids = range(0, len(eval_r.chunks), sample_every)
211 | export_to = proj_root() / "caches" / "model_predictions" / model_name
212 | export_preds_on_code(
213 | eval_r.chunks[sub_ids],
214 | [eval_r.predictions[i] for i in sub_ids],
215 | export_to=export_to,
216 | metric=AccuracyMetric(common_names),
217 | )
218 | print(f"Model predictions exported to '{export_to}'")
219 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# NOTE: distutils is deprecated (PEP 632) and removed from the standard
# library in Python 3.12; prefer setuptools, falling back only for very old
# environments that ship distutils but not setuptools.
try:
    from setuptools import setup
except ImportError:
    from distutils.core import setup

setup(
    name="TypeT5",
    version="0.1",
    packages=["typet5"],
    package_dir={"": "src"},
    license="BSD 3-Clause",
)
10 |
--------------------------------------------------------------------------------
/src/typet5/__init__.py:
--------------------------------------------------------------------------------
"""Package root for TypeT5: re-exports the most commonly used names."""

from .type_env import AnnotCat, PythonType
from .utils import proj_root

# Public API of the package root.
__all__ = ["AnnotCat", "PythonType", "proj_root"]
5 |
--------------------------------------------------------------------------------
/src/typet5/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utopia-group/TypeT5/d8ff8638f4d00f03042db5780a8d4fa09a72916d/src/typet5/experiments/__init__.py
--------------------------------------------------------------------------------
/src/typet5/experiments/hityper.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 | import requests
4 |
5 | from typet5.experiments.type4py import Type4PyEvalResult, Type4PySupportedSyntax
6 | from typet5.experiments.utils import SupportedSyntax, remove_newer_syntax
7 | from typet5.function_dataset import SignatureMap
8 | from typet5.static_analysis import (
9 | ElemSignature,
10 | FunctionSignature,
11 | ModuleName,
12 | ProjectPath,
13 | PythonProject,
14 | VariableSignature,
15 | )
16 | from typet5.type_check import normalize_type, parse_type_expr, parse_type_str
17 | from typet5.type_env import AccuracyMetric
18 | from typet5.utils import *
19 |
20 | PredList = list[list] # of the form [[type1, score1], [type2, score2], ...]
21 |
22 |
def eval_hityper_on_repos(
    repo_roots: list[Path],
    hityper_python: Path,
    work_dir: Path,
    max_workers: int | None = None,
):
    """Run HiTyper over each repo and pair its predicted signatures with the
    ground-truth signatures extracted from the sources."""
    output_dirs = [work_dir / root.name for root in repo_roots]
    for directory in output_dirs:
        directory.mkdir(parents=True, exist_ok=True)
    inferred_maps = pmap(
        run_hityper,
        repo_roots,
        output_dirs,
        [hityper_python] * len(repo_roots),
        desc="Running HiTyper",
        max_workers=max_workers,
    )

    parsed_projects = [
        PythonProject.parse_from_root(root, discard_bad_files=True)
        for root in repo_roots
    ]

    label_signatures: dict[str, SignatureMap] = {}
    for project in parsed_projects:
        label_signatures[project.name] = {
            elem.path: elem.get_signature() for elem in project.all_elems()
        }

    pred_signatures: dict[str, SignatureMap] = {
        name: dict() for name in label_signatures
    }
    pred_signatures.update(
        (project.name, sigmap)
        for project, sigmap in zip(parsed_projects, inferred_maps)
    )

    # The result shape matches Type4Py's, so reuse its container type.
    return Type4PyEvalResult(
        pred_maps=pred_signatures,
        label_maps=label_signatures,
    )
58 |
59 |
class HiTyperResponseParser:
    """Converts HiTyper's JSON inference output into a `SignatureMap`.

    HiTyper keys its results as ``"<name>@<parent>"`` where the parent is
    ``"global"`` for module-level elements; the special key
    ``"global@global"`` holds module-level variables.
    """

    def __init__(self, module: ModuleName):
        self.module = module  # module the parsed results belong to
        self.assignment: SignatureMap = dict()

    def parse(self, res_json: dict[str, list]) -> SignatureMap:
        """Parse one module's results into `ProjectPath -> signature`."""
        # Reset state so a parser instance can be reused across calls.
        self.assignment = dict()

        def parse_var(x: dict) -> tuple[str, cst.Annotation | None]:
            return x["name"], _parse_annot(x["type"])

        for e_name, e_list in res_json.items():
            if e_name == "global@global":
                # Module-level variables are reported with category "local".
                # (Fixed: the comprehension used a pointless walrus binding
                # and was named `vars`, shadowing the builtin.)
                parsed = [parse_var(x) for x in e_list if x["category"] == "local"]
                for name, annot in parsed:
                    path = ProjectPath(self.module, name)
                    self.assignment[path] = VariableSignature(annot, in_class=False)
            else:
                name, parent = e_name.split("@")
                parent = "" if parent == "global" else parent
                path = ProjectPath(self.module, parent).append(name)
                params = [parse_var(x) for x in e_list if x["category"] == "arg"]
                returns = [parse_var(x) for x in e_list if x["category"] == "return"]
                # HiTyper may report several return entries; keep the first.
                rt = returns[0][1] if returns else None
                # NOTE(review): in_class is hard-coded to False even when the
                # parent is a class — confirm whether methods need True here.
                self.assignment[path] = FunctionSignature(
                    {v[0]: v[1] for v in params},
                    rt,
                    in_class=False,
                )

        return self.assignment
91 |
92 |
def run_hityper(repo: Path, out_dir: Path, python_path: Path) -> SignatureMap:
    """Run HiTyper on `repo` (results written to `out_dir`) and parse the
    resulting `inferred_types.json` into a `SignatureMap`."""
    out_dir.mkdir(parents=True, exist_ok=True)
    cmd = [
        python_path,
        "-m",
        "hityper",
        "infer",
        "--type4py",
        "-p",
        repo.resolve(),
        "-d",
        out_dir.resolve(),
    ]
    proc = subprocess.run(cmd, cwd=out_dir, capture_output=True)
    if proc.returncode != 0:
        raise RuntimeError(
            f"HiTyper failed on {repo} with error: {proc.stderr.decode()}"
        )
    results = json.loads(read_file(out_dir / "inferred_types.json"))
    sigmap = SignatureMap()
    for fname, mres in results.items():
        rel_path = Path(fname).relative_to(repo.resolve())
        mname = PythonProject.rel_path_to_module_name(rel_path)
        sigmap.update(HiTyperResponseParser(mname).parse(mres))
    return sigmap
124 |
125 |
def _parse_annot(ts: list[str]) -> cst.Annotation | None:
    """Join candidate type strings into a single annotation; None when the
    list is empty or the joined expression does not parse."""
    if not ts:
        return None
    # Multiple candidates become a union; joining a one-element list with
    # " | " yields the element itself, so both cases are handled at once.
    try:
        return cst.Annotation(cst.parse_expression(" | ".join(ts)))
    except cst.ParserSyntaxError:
        return None
136 |
--------------------------------------------------------------------------------
/src/typet5/experiments/type4py.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 | from typet5.experiments.utils import SupportedSyntax, remove_newer_syntax
4 | from typet5.function_dataset import SignatureMap
5 | from typet5.static_analysis import (
6 | ElemSignature,
7 | FunctionSignature,
8 | ModuleName,
9 | ProjectPath,
10 | PythonProject,
11 | VariableSignature,
12 | reorder_signature_map,
13 | )
14 | from typet5.type_check import normalize_type, parse_type_expr
15 | from typet5.utils import *
16 |
17 | PredList = list[list] # of the form [[type1, score1], [type2, score2], ...]
18 |
19 |
class Type4PyResponseParser:
    """Converts a Type4Py API response into `ProjectPath -> ElemSignature`."""

    def __init__(self, module: ModuleName):
        self.module = module  # module the response describes
        self.assignment: dict[ProjectPath, ElemSignature] = dict()

    def parse(self, res_json: dict) -> dict[ProjectPath, ElemSignature]:
        """Parse module-level variables, functions, and classes."""
        self.assignment = dict()  # reset so the parser can be reused
        res = res_json["response"]

        for name, pred in res["variables_p"].items():
            annot = self.parse_prediction(pred)
            sig = VariableSignature(annot, in_class=False)
            self.assignment[ProjectPath(self.module, name)] = sig
        for f in res["funcs"]:
            self._parse_func(
                f,
                ProjectPath(self.module, ""),
                in_class=False,
            )
        for c in res["classes"]:
            self._parse_cls(c, ProjectPath(self.module, ""))
        return self.assignment

    def _parse_cls(self, cls_json: dict, base: ProjectPath):
        attr_preds: dict[str, PredList] = cls_json["variables_p"]
        new_base = base.append(cls_json["name"])
        for name, pred in attr_preds.items():
            annot = self.parse_prediction(pred)
            sig = VariableSignature(annot, in_class=True)
            self.assignment[new_base.append(name)] = sig
        for func_json in cls_json["funcs"]:
            try:
                self._parse_func(func_json, new_base, in_class=True)
            # Fixed: was a bare `except:`, which would also intercept
            # KeyboardInterrupt/SystemExit before re-raising.
            except Exception:
                print("Failed to parse function")
                print("JSON:")
                display(func_json)
                raise

    @staticmethod
    def parse_prediction(pred: PredList) -> cst.Annotation | None:
        # Predictions are ranked [[type, score], ...]; keep the top-1 type.
        if pred:
            return cst.Annotation(cst.parse_expression(pred[0][0]))
        else:
            return None

    def _parse_func(self, func_json: dict, base: ProjectPath, in_class: bool):
        preds = func_json["params_p"]
        params_pred = {
            v: self.parse_prediction(preds[v]) for v in preds if len(preds[v]) > 0
        }
        ret_pred = None
        if "ret_type_p" in func_json:
            ret_pred = self.parse_prediction(func_json["ret_type_p"])
        if ret_pred is None:
            # Default the return annotation to None when nothing is predicted.
            ret_pred = cst.Annotation(cst.parse_expression("None"))
        sig = FunctionSignature(
            params_pred,
            ret_pred,
            in_class=in_class,
        )
        self.assignment[base.append(func_json["name"])] = sig
82 |
83 |
def run_type4py_request(
    code: str, module: ModuleName
) -> dict[ProjectPath, ElemSignature] | str:
    """POST `code` to the public Type4Py API; return the parsed signature
    assignment, or the server's error message string on failure."""
    response = requests.post("https://type4py.com/api/predict?tc=0", code.encode())
    res = response.json()
    if res["response"] is None:
        return res["error"]
    parser = Type4PyResponseParser(module)
    return parser.parse(res)
91 |
92 |
@dataclass
class Type4PyEvalResult:
    """Predicted vs. ground-truth signature maps, keyed by project name."""

    pred_maps: dict[str, SignatureMap]
    label_maps: dict[str, SignatureMap]

    def __post_init__(self):
        # Reorder predicted function arguments so they line up with labels.
        for pname in self.pred_maps:
            label_map = self.label_maps.get(pname)
            if label_map is None:
                continue
            self.pred_maps[pname] = reorder_signature_map(
                self.pred_maps[pname], label_map
            )
105 |
106 |
# Syntax features that must be rewritten away (see `remove_newer_syntax`)
# before sending code to Type4Py.
Type4PySupportedSyntax = SupportedSyntax(
    pattern_match=False, union_types=False, basic_types=False
)
110 |
111 |
def eval_type4py_on_projects(
    projects: list[PythonProject],
    max_workers: int = 4,
) -> Type4PyEvalResult:
    """Query the Type4Py API on every module of every project and collect
    predicted vs. ground-truth signature maps, keyed by project name."""
    name2project = {p.name: p for p in projects}
    # Source text per (project, module), with unsupported syntax rewritten away.
    module_srcs = {
        (project.name, name): remove_newer_syntax(m.tree, Type4PySupportedSyntax).code
        for project in projects
        for name, m in project.modules.items()
    }
    # NOTE: the zip below relies on pmap preserving the iteration order of
    # module_srcs (dicts iterate in insertion order).
    model_outputs = pmap(
        run_type4py_request,
        list(module_srcs.values()),
        [mname for pname, mname in module_srcs.keys()],
        desc="Calling Type4Py",
        max_workers=max_workers,
    )

    label_signatures: dict[str, SignatureMap] = {
        project.name: {e.path: e.get_signature() for e in project.all_elems()}
        for project in projects
    }

    pred_signatures: dict[str, SignatureMap] = {n: dict() for n in label_signatures}
    for (pname, mname), o in zip(module_srcs.keys(), model_outputs):
        if isinstance(o, str):
            # A string output is the API's error message for that module.
            if list(name2project[pname].modules[mname].all_elements()):
                # only warn for non-empty modules
                logging.warning(
                    f"In project {pname} module {mname}, Type4Py errored: {o}"
                )
        else:
            pred_signatures[pname].update(o)

    return Type4PyEvalResult(
        pred_maps=pred_signatures,
        label_maps=label_signatures,
    )
150 |
--------------------------------------------------------------------------------
/src/typet5/experiments/typet5.py:
--------------------------------------------------------------------------------
1 | from typet5.train import *
2 |
3 |
def accs_as_table_row(accs_dict: dict):
    """Print the accuracies in `accs_dict` as `&`-separated (LaTeX-style)
    table rows: one row each for all, common, and rare types."""

    def retrieve(path: str):
        # Walk a dot-separated path into the nested dict; "N/A" when absent.
        segs = path.split(".")
        target = accs_dict
        for s in segs:
            if s not in target:
                return "N/A"
            target = target[s]
        # Bug fix: the message used to interpolate the expected class
        # (CountedAcc) instead of the type actually found.
        assert isinstance(target, CountedAcc), f"Unexpected type: {type(target)}"
        return f"{target.acc * 100:.2f}"

    def print_row(name: str, postfix: str):
        # Table column -> dotted lookup path into accs_dict.
        row = {
            "full.all": f"full_acc{postfix}.full_acc{postfix}",
            "calibrated.all": f"acc{postfix}.acc{postfix}",
            "calibrated.simple": f"acc{postfix}.acc{postfix}_by_simple.simple",
            "calibrated.complex": f"acc{postfix}.acc{postfix}_by_simple.complex",
            "base.all": f"base_acc{postfix}.base_acc{postfix}",
        }

        nums = [retrieve(path) for path in row.values()]
        print(f"Accuracies on {name} types:")
        print("header: ", list(row.keys()))
        print(" & ".join(nums))

    print_row("all", "")
    print_row("common", "_common")
    print_row("rare", "_rare")
32 |
33 |
class TypeT5Configs:
    """Preset `TrainingConfig`s: the main TypeT5 setup plus its ablations."""

    # Standard setup: preamble and environment types enabled, context split
    # between left (preceding) and right (following) margins.
    Default = TrainingConfig(
        func_only=True,
        pre_args=PreprocessArgs(
            drop_env_types=False,
            add_implicit_rel_imports=True,
        ),
        left_margin=2048,
        right_margin=2048 - 512,
        preamble_size=1000,
    )

    # Ablation: no import/stub preamble.
    NoPreamble = TrainingConfig(
        func_only=True,
        pre_args=PreprocessArgs(
            imports_in_preamble=False,
            stub_in_preamble=False,
            drop_env_types=False,
            add_implicit_rel_imports=True,
        ),
        left_margin=2048,
        right_margin=2048 - 512,
        preamble_size=1000,
    )

    # Ablation: no callee (usee) context; budget shifted to the right margin.
    NoUsees = TrainingConfig(
        func_only=True,
        pre_args=PreprocessArgs(
            max_callees=0,
            drop_env_types=False,
            add_implicit_rel_imports=True,
        ),
        left_margin=512,
        right_margin=2048 + 1024,
        preamble_size=511,
    )

    # Ablation: no caller (user) context; budget shifted to the left margin.
    NoUsers = TrainingConfig(
        func_only=True,
        pre_args=PreprocessArgs(
            max_callers=0,
            drop_env_types=False,
            add_implicit_rel_imports=True,
        ),
        left_margin=2048 + 1024,
        right_margin=512,
        preamble_size=1000,
    )

    # Ablation: drop environment types from the context.
    NoSequential = TrainingConfig(
        func_only=True,
        pre_args=PreprocessArgs(
            drop_env_types=True,
            add_implicit_rel_imports=True,
        ),
        left_margin=2048,
        right_margin=2048 - 512,
        preamble_size=1000,
    )
93 |
--------------------------------------------------------------------------------
/src/typet5/experiments/typilus.py:
--------------------------------------------------------------------------------
1 | import json
2 | import subprocess
3 |
4 | from typet5.experiments.utils import SupportedSyntax
5 | from typet5.function_dataset import collect_public_api_labels
6 | from typet5.static_analysis import LabelInfo, ModuleName, PythonProject
7 | from typet5.type_check import PythonType, parse_type_expr, parse_type_str
8 | from typet5.type_env import AccuracyMetric, type_accuracies
9 | from typet5.utils import *
10 |
11 | JSON = dict
12 |
13 |
def eval_typilus_on_repos(
    repo_roots: list[Path],
    metrics: list[AccuracyMetric],
    typilus_path: Path,
    work_dir: Path,
    max_workers: int | None = None,
):
    """Run Typilus over each repo, then score its predictions."""
    output_dirs = [work_dir / root.name for root in repo_roots]
    for directory in output_dirs:
        directory.mkdir(parents=True, exist_ok=True)
    pmap(
        run_typilus,
        repo_roots,
        output_dirs,
        [typilus_path] * len(repo_roots),
        desc="Running Typilus",
        max_workers=max_workers,
    )

    # Collect each repo's predictions.json written by run_typilus above.
    typilus_outputs = [
        json.loads((work_dir / repo.name / "predictions.json").read_text())
        for repo in repo_roots
    ]
    return analyze_typilus_predictions(
        typilus_outputs,
        repo_roots,
        metrics,
    )
42 |
43 |
def run_typilus(repo: Path, out_dir: Path, typilus_path: Path) -> None:
    """Invoke the external Typilus runner on `repo`, writing its output to
    `out_dir`. Raises RuntimeError when the subprocess fails."""
    import os

    # Interpreter of the environment that has Typilus's dependencies.
    # This used to be a hard-coded machine-specific path; keep that value as
    # the default but allow overriding via the TYPILUS_PYTHON env var.
    typilus_python = os.environ.get(
        "TYPILUS_PYTHON", "/home/jiayi/anaconda3/envs/typilus-torch/bin/python"
    )

    out = subprocess.run(
        [
            typilus_python,
            "-m",
            "run_typilus",
            repo.resolve(),
            out_dir.resolve(),
        ],
        cwd=typilus_path,
        env={"PYTHONPATH": "src"},
        capture_output=True,
    )
    if out.returncode != 0:
        raise RuntimeError(
            f"Typilus failed on {repo} with error: {out.stderr.decode()}"
        )
63 |
64 |
def analyze_typilus_predictions(
    typilus_outputs: list[JSON],
    repo_roots: list[Path],
    metrics: list[AccuracyMetric],
    common_labels_only: bool = True,
):
    """Score Typilus predictions against the labeled public APIs.

    `typilus_outputs[i]` is the predictions JSON for `repo_roots[i]`, mapping
    file paths to lists of {"location": [line, col], "pred": str} entries.
    Returns one accuracy summary per metric, with counts of missing
    predictions and skipped rare labels attached when nonzero.
    """
    assert_eq(
        len({r.name for r in repo_roots}),
        len(repo_roots),
        extra_message=lambda: "repo names must be unique",
    )

    # first, build the label map
    label_maps = pmap(collect_public_api_labels, repo_roots, desc="Collecting labels")

    # then, collect the prediction-label pairs
    pred_maps = list[dict[ModuleName, dict[CodePosition, str]]]()
    for out in typilus_outputs:
        pred_map = dict[ModuleName, dict[CodePosition, str]]()
        for file, preds in out.items():
            assert isinstance(file, str)
            assert isinstance(preds, list)
            if file.startswith("/"):
                # Strip the leading slash so the path is repo-relative.
                file = file[1:]
            file_mod = PythonProject.rel_path_to_module_name(Path(file))
            submap = pred_map[file_mod] = dict()
            for pred in preds:
                line, col = pred["location"]
                submap[CodePosition(line, col)] = pred["pred"]
        pred_maps.append(pred_map)

    # Align predictions with labels by module and source position.
    pred_list, label_list, cat_list = [], [], []
    n_missing = 0
    n_skipped_rare = 0
    for label_map, pred_map in zip(label_maps, pred_maps):
        for mod, labels in label_map.items():
            preds = pred_map.get(mod)
            n_labels = len(labels)
            if preds is None:
                if n_labels > 0:
                    logging.warning(f"Missing {n_labels} predictions for module: {mod}")
                n_missing += n_labels
                continue
            for pos, linfo in labels.items():
                pred = preds.get(pos)
                if pred is None:
                    n_missing += 1
                    continue

                # Skip labels whose annotation cannot be parsed into a type.
                if (ltype := parse_type_expr(linfo.annot.annotation)) is None:
                    continue
                try:
                    pred = parse_type_str(pred)
                except SyntaxError:
                    # Unparsable predictions count as Any rather than a miss.
                    pred = PythonType.Any()

                if common_labels_only and not metrics[0].is_common_type(ltype):
                    n_skipped_rare += 1
                    continue
                label_list.append(ltype)
                cat_list.append(linfo.cat)
                pred_list.append(pred)

    accs = dict[str, dict]()
    for m in metrics:
        accs[m.name] = sub_a = type_accuracies(pred_list, label_list, cat_list, m)
        if n_missing > 0:
            sub_a["n_missing"] = n_missing
        if n_skipped_rare > 0:
            sub_a["n_skipped_rare"] = n_skipped_rare
    return accs
136 |
137 |
# Syntax features that must be rewritten away (see `remove_newer_syntax`)
# before feeding code to Typilus.
TypilusSupportedSyntax = SupportedSyntax(
    pattern_match=False,
    union_types=False,
    basic_types=False,
    named_exprs=False,
)
144 |
--------------------------------------------------------------------------------
/src/typet5/experiments/utils.py:
--------------------------------------------------------------------------------
1 | from typet5.function_decoding import EvalResult
2 | from typet5.static_analysis import (
3 | ElemSignature,
4 | FunctionSignature,
5 | ModuleName,
6 | ProjectPath,
7 | PythonModule,
8 | PythonProject,
9 | SignatureMap,
10 | VariableSignature,
11 | _VisitKind,
12 | is_type_rhs,
13 | )
14 | from typet5.type_check import MypyChecker, MypyFeedback, MypyResult
15 | from typet5.type_env import normalize_type, parse_type_expr
16 | from typet5.utils import *
17 |
# Import statement injected into rewritten modules so typing names resolve;
# the "# SPOT" tag lets `remove_newer_syntax` recognize and deduplicate it.
_DefaultImport = cst.parse_statement(
    "from typing import Any, List, Tuple, Dict, Set, Union, Type, Callable # SPOT"
)
21 |
22 |
@dataclass
class SupportedSyntax:
    """Which newer Python syntax features a target tool can handle; features
    marked False are rewritten away by `remove_newer_syntax`."""

    pattern_match: bool = True  # match/case statements
    union_types: bool = True  # `X | Y` syntax in annotations
    basic_types: bool = True  # when False, typing imports are injected
    named_exprs: bool = True  # walrus `:=` expressions
29 |
30 |
def remove_newer_syntax(m: cst.Module, supported: SupportedSyntax) -> cst.Module:
    """
    Remove or rewrite any newer python features that the target tool (e.g.
    Type4Py) doesn't support, as described by `supported`.
    """

    class PatternRewriter(cst.CSTTransformer):
        # Rewrites a match-pattern node into a plain name/expression node.
        def leave_MatchAs(self, node, updated: cst.MatchAs):
            if updated.pattern:
                return updated.pattern
            elif updated.name:
                return updated.name
            else:
                # wild card pattern
                return cst.Name("_")

    def pattern_to_expr(pattern: cst.MatchPattern):
        np = cast(cst.BaseExpression, pattern.visit(PatternRewriter()))
        # Round-trip through source text to obtain a genuine expression node.
        return cst.parse_expression(m.code_for_node(np))

    class Rewriter(cst.CSTTransformer):
        def leave_Annotation(self, node, updated: "cst.Annotation"):
            if supported.union_types:
                return updated
            ty = parse_type_expr(updated.annotation, silent=True)
            if ty is None:
                # Drop annotations that cannot be parsed at all.
                return cst.RemoveFromParent()
            ty = normalize_type(ty)  # this should get rid of the Union type syntax.
            return updated.with_changes(annotation=cst.parse_expression(str(ty)))

        def leave_Module(self, node, updated: "cst.Module"):
            # Inject the typing imports when builtin generics are unsupported,
            # dropping any pre-existing copy of that exact import line.
            new_lines = [_DefaultImport] if not supported.basic_types else []
            default_import = updated.code_for_node(_DefaultImport)
            for stmt in updated.body:
                if updated.code_for_node(stmt) != default_import:
                    new_lines.append(stmt)
            return updated.with_changes(body=new_lines)

        def leave_Match(self, node, updated: cst.Match):
            if supported.pattern_match:
                return updated
            # Translate a `match` statement into an equivalent if/elif chain
            # comparing the subject against each case pattern with ==.
            subject = updated.subject
            if isinstance(subject, cst.Tuple):
                subject = subject.with_changes(
                    lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                )

            conditions = [
                cst.Comparison(
                    subject,
                    [
                        cst.ComparisonTarget(
                            cst.Equal(),
                            pattern_to_expr(c.pattern),
                        )
                    ],
                )
                for c in updated.cases
            ]
            bodies = [c.body for c in updated.cases]
            if_clauses = None
            # Build the chain from the last case backwards so each `If`
            # becomes the `orelse` of the one before it.
            for cond, body in reversed(list(zip(conditions, bodies))):
                if_clauses = cst.If(cond, body, orelse=if_clauses)
            assert isinstance(if_clauses, cst.If)
            return if_clauses

        def leave_NamedExpr(self, node, updated: "cst.NamedExpr"):
            if supported.named_exprs:
                return updated
            # Replace `(x := v)` with just `v` (the binding is dropped).
            return updated.value

    return m.visit(Rewriter())
102 |
103 |
def remove_newer_syntax_for_file(file: Path, rules: SupportedSyntax) -> bool:
    """Rewrite `file` in place; return True iff the file was modified."""
    original = read_file(file)
    rewritten = remove_newer_syntax(cst.parse_module(original), rules).code
    if rewritten == original:
        return False
    write_file(file, rewritten)
    return True
113 |
114 |
def remove_newer_syntax_for_repo(root: Path, rules: SupportedSyntax) -> None:
    """Apply `remove_newer_syntax_for_file` to every .py file under `root`."""
    py_files = [f for f in root.glob("**/*.py") if f.is_file()]
    changed = pmap(
        remove_newer_syntax_for_file,
        py_files,
        [rules] * len(py_files),
        desc="Removing newer syntax",
    )
    print(f"{sum(changed)} / {len(py_files)} files have been rewritten.")
124 |
125 |
def apply_sigmap(
    m: cst.Module,
    sigmap: SignatureMap,
    module_name: ModuleName,
    add_default_imports=True,
) -> cst.Module:
    """
    Apply the signature map to the module.

    Rewrites functions and (class/global) variable assignments whose
    `ProjectPath` appears in `sigmap` to carry the mapped annotations;
    optionally prepends the default typing imports.
    """

    class Rewriter(cst.CSTTransformer):
        def __init__(self):
            super().__init__()
            # Track the current project path and the kind of scope we are in
            # while descending the tree.
            self.path_stack = [ProjectPath(module_name, "")]
            self.visit_stack = [_VisitKind.Root]

        @property
        def current_path(self) -> ProjectPath:
            return self.path_stack[-1]

        @property
        def current_visit_kind(self) -> _VisitKind:
            return self.visit_stack[-1]

        def enter_(self, name: str, kind: _VisitKind):
            self.path_stack.append(self.current_path.append(name))
            self.visit_stack.append(kind)

        def exit_(self):
            self.path_stack.pop()
            self.visit_stack.pop()

        def visit_FunctionDef(self, node: cst.FunctionDef):
            self.enter_(node.name.value, _VisitKind.Function)

        def leave_FunctionDef(self, node, updated: cst.FunctionDef):
            if isinstance(sig := sigmap.get(self.current_path), FunctionSignature):
                try:
                    updated = sig.apply(updated)
                except LookupError:
                    # Signature doesn't fit this def; leave it unchanged.
                    pass
            self.exit_()
            return updated

        def visit_ClassDef(self, node: "cst.ClassDef") -> Optional[bool]:
            self.enter_(node.name.value, _VisitKind.Class)

        def leave_ClassDef(self, node, updated: cst.ClassDef):
            self.exit_()
            return updated

        def leave_AnnAssign(self, node, updated: cst.AnnAssign):
            # Replace the annotation of already-annotated assignments.
            target = None
            match updated.target:
                case cst.Name(name):
                    target = name
            if (
                target is not None
                and isinstance(
                    sig := sigmap.get(self.current_path.append(target)),
                    VariableSignature,
                )
                and sig.annot is not None
            ):
                updated = updated.with_changes(annotation=sig.annot)
            return updated

        def leave_Assign(self, node, updated: cst.Assign):
            # Annotate plain single-target assignments (turning them into
            # AnnAssign), but never inside function bodies.
            target = None
            if self.current_visit_kind != _VisitKind.Function:
                match updated.targets:
                    case [cst.AssignTarget(target=cst.Name(name))]:
                        target = name
            if (
                target is not None
                and isinstance(
                    sig := sigmap.get(self.current_path.append(target)),
                    VariableSignature,
                )
                and sig.annot is not None
                and not (
                    self.current_visit_kind == _VisitKind.Root
                    and is_type_rhs(updated.value)
                )  # skip annotating type aliases
            ):
                return cst.AnnAssign(cst.Name(target), sig.annot, updated.value)
            return updated

        def leave_Module(self, node, updated: cst.Module):
            if add_default_imports:
                return updated.with_changes(
                    body=[_DefaultImport] + list(updated.body),
                )
            return updated

    return m.visit(Rewriter())
222 |
223 |
def quote_annotations(m: cst.Module, normalize_types: bool = True) -> cst.Module:
    """Replace every type annotation in the module with its quoted (string) form.

    When `normalize_types` is true, annotations that parse as types are
    normalized before being stringified; anything that fails to parse
    falls back to the raw expression text.
    """

    class Rewriter(cst.CSTTransformer):
        def leave_Annotation(self, node, updated: "cst.Annotation"):
            if updated.annotation is None:
                return updated
            text = None
            if normalize_types:
                parsed = parse_type_expr(updated.annotation)
                if parsed is not None:
                    text = repr(str(parsed.normalized()))
            if text is None:
                # Not normalizing, or the annotation did not parse as a type.
                text = repr(show_expr(updated.annotation, quoted=False))
            return updated.with_changes(annotation=cst.SimpleString(text))

    return m.visit(Rewriter())
244 |
245 |
def apply_sigmap_and_typecheck(
    project: PythonProject,
    sigmap: SignatureMap,
    workdir: Path,
    quote_types=True,
    binary_path: Optional[Path] = None,
) -> MypyResult:
    """Write `project` (with `sigmap` applied) into `workdir` and run mypy on it.

    Each module is rendered to `<workdir>/<dotted/module/path>.py` after its
    signatures are applied; with `quote_types`, all annotations are first
    turned into normalized string literals. Raises `RuntimeError` when the
    type checker itself fails to produce a result.
    """
    assert workdir.is_dir(), f"Workdir is not a directory: {workdir}"

    # write the type annotated source files to the workdir
    for name, m in project.modules.items():
        # file = workdir / project.module2src_file[name]
        file = workdir / (name.replace(".", "/") + ".py")
        file.parent.mkdir(parents=True, exist_ok=True)
        m1 = apply_sigmap(m.tree, sigmap, name)
        if quote_types:
            m1 = quote_annotations(m1, normalize_types=True)
        write_file(file, m1.code)
    # handle __init__.py files specially:
    # a module written as `pkg.py` whose name also exists as a directory
    # is really that package's __init__ file, so move it there.
    files = list(workdir.glob("**/*.py"))
    for f in files:
        if (d := f.with_suffix("")).is_dir():
            f.rename(d / "__init__.py")

    # now call the type checker
    r = MypyChecker.check_project(workdir, binary_path)
    if isinstance(r, str):
        # A plain string result means the checker failed to run at all.
        raise RuntimeError(f"Type checking failed: {r}")
    return r
275 |
276 |
def count_type_errors(
    result: Iterable[MypyFeedback],
) -> int:
    """Count the feedback items whose error code is one of the
    type-related mypy error codes this project cares about."""
    counted_codes = frozenset(
        (
            "name-defined",
            "attr-defined",
            "arg-type",
            "return-value",
            "assignment",
        )
    )
    return sum(1 for fb in result if fb.error_code in counted_codes)
293 |
294 |
def collect_project_type_errors(
    proj: PythonProject,
    sigmap: SignatureMap,
    workdir: Path = Path("mypy_temp"),
    binary_path: Path | None = None,
) -> list[MypyFeedback]:
    """Type-check `proj` under the signatures in `sigmap` and return all errors.

    The project is materialized (with `sigmap` applied) into a fresh
    subdirectory of `workdir` and checked with mypy. Returns the flattened
    list of feedback items across all files. Best-effort: if mypy itself
    fails to run, a warning is printed and an empty list is returned.
    """
    workdir = workdir / proj.name
    # Start from a clean slate so stale files from a previous run cannot
    # influence the type-check results.
    shutil.rmtree(workdir, ignore_errors=True)
    workdir.mkdir(exist_ok=True, parents=True)
    try:
        check_r = apply_sigmap_and_typecheck(
            proj, sigmap, workdir, binary_path=binary_path
        )
        return [e for es in check_r.error_dict.values() for e in es]
    except RuntimeError as e:
        # BUGFIX: `e` was captured but never used, so the warning gave no
        # hint of why mypy failed; include the reason in the message.
        print("Warning: mypy failed for project:", proj.name, "error:", e)
        return []
312 |
--------------------------------------------------------------------------------
/src/typet5/model.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import Counter
3 | from copy import copy, deepcopy
4 | from typing import NamedTuple, overload
5 |
6 | import numpy as np
7 | from datasets.arrow_dataset import Dataset
8 | from huggingface_hub import snapshot_download
9 | from mypy_extensions import mypyc_attr
10 | from torch import Tensor
11 | from torch.utils.data import DataLoader, RandomSampler
12 | from transformers.data.data_collator import DataCollatorForSeq2Seq
13 |
14 | from .data import (
15 | ChunkedDataset,
16 | CtxArgs,
17 | TokenizedSrcSet,
18 | output_ids_as_types,
19 | preds_to_accuracies,
20 | )
21 | from .type_env import AccuracyMetric, PythonType
22 | from .utils import *
23 |
24 |
@dataclass
class DecodingArgs:
    """Arguments controlling how the model decodes type predictions."""

    ctx_args: CtxArgs
    sampling_max_tokens: int
    max_workers: int = DefaultWorkers
    # the maximal prediction length = tokens_per_type * num_types + slack_tokens
    tokens_per_type: int = 16
    slack_tokens: int = 10
    do_sample: bool = False
    top_p: float = 0.9
    num_beams: Optional[int] = None
    num_beam_groups: Optional[int] = None
    length_penalty: float = 1.0
    diversity_penalty: float | None = None

    def scale_ctx_size(self, factor: float) -> "DecodingArgs":
        """Scale the context size of the model by the given factor, while keeping the window size the same.
        Also scale down the sampling batch size accordingly."""
        # BUGFIX: this docstring previously appeared *after* the first two
        # statements, making it a dead string literal instead of a docstring.
        result = deepcopy(self)
        assert result.ctx_args is not None
        ctx_size = round(self.ctx_args.ctx_size * factor)
        right_margin = round(self.ctx_args.right_margin * factor)
        # The window occupies whatever remains after both margins.
        left_margin = ctx_size - right_margin - self.ctx_args.window_size
        result.ctx_args.ctx_size = ctx_size
        result.ctx_args.left_margin = left_margin
        result.ctx_args.right_margin = right_margin
        # Attention cost grows roughly quadratically with context length,
        # so the sampling token budget shrinks by factor**2.
        result.sampling_max_tokens = round(self.sampling_max_tokens / factor**2)

        return result

    def __repr__(self) -> str:
        return repr_modified_args(self)
57 |
58 |
@dataclass
class DatasetPredResult(Generic[T1]):
    """Predictions for a chunked dataset, plus optional per-chunk extra info."""

    chunks: ChunkedDataset
    predictions: list[list[PythonType]]
    extra_info: list[T1] = field(default_factory=list)

    def accuracies(self, metric: AccuracyMetric) -> dict:
        """Compute accuracy statistics of the predictions under `metric`."""
        return preds_to_accuracies(self.predictions, self.chunks, metric)

    def group_by_repo(self, repos_dir: Path) -> dict[Path, "DatasetPredResult[T1]"]:
        """Split this result into one `DatasetPredResult` per repository."""
        # Map each chunk to the repo its source file belongs to.
        repo_of_chunk = list[Path]()
        for info in self.chunks.chunks_info:
            src_path = repos_dir / info.src_file
            repo_of_chunk.append(self.chunks.file2repo[src_path])

        repo2indices = groupby(range(len(repo_of_chunk)), lambda i: repo_of_chunk[i])
        chunk_ids = self.chunks.data["chunk_id"]
        grouped = dict()
        for repo, indices in repo2indices.items():
            sub_chunks = self.chunks[(chunk_ids[i] for i in indices)]
            sub_preds = [self.predictions[i] for i in indices]
            sub_extra = (
                [self.extra_info[i] for i in indices] if self.extra_info else []
            )
            grouped[repo] = DatasetPredResult(sub_chunks, sub_preds, sub_extra)
        return grouped
85 |
86 |
@dataclass
class ModelWrapper:
    """A type-prediction model bundled with everything needed to run it:
    tokenizer, decoding arguments, the common type names collected from
    training data, and a task monitor for logging."""

    model: ModelType
    tokenizer: TokenizerType
    args: DecodingArgs
    common_type_names: set[str]
    monitor: TaskMonitor = EmptyLoggingMonitor()

    @staticmethod
    def get_codet5_path(use_small_model: bool = False) -> str:
        """Return the HuggingFace hub id of the pretrained CodeT5 checkpoint."""
        return (
            "Salesforce/codet5-small" if use_small_model else "Salesforce/codet5-base"
        )

    def scale_ctx_size(self, factor) -> "ModelWrapper":
        """Return a copy of this wrapper whose decoding context is scaled by `factor`."""
        r = copy(self)
        r.args = r.args.scale_ctx_size(factor)
        return r

    def predict_on_batch(
        self,
        batch: dict,
        num_return_sequences: int | None = None,
    ) -> tuple[list[list[PythonType]], Tensor]:
        """Run the model on the given batch and return the predicted types for each row."""
        model = self.model
        args = self.args
        n_labels = batch["n_labels"]
        max_labels = max(n_labels)

        div_pen = args.diversity_penalty
        if args.num_beam_groups is not None:
            # Diverse beam search is only valid with a positive diversity penalty.
            assert (
                div_pen is not None and div_pen > 0
            ), "num_beam_groups requires diversity_penalty > 0"

        output_ids = model.generate(
            inputs=batch["input_ids"].to(model.device),
            do_sample=args.do_sample,
            top_p=args.top_p,
            num_beams=args.num_beams,
            num_return_sequences=num_return_sequences,
            num_beam_groups=args.num_beam_groups,
            # Budget: a fixed number of tokens per predicted type plus slack.
            max_length=args.tokens_per_type * max_labels + args.slack_tokens,
            diversity_penalty=div_pen,
            length_penalty=args.length_penalty,
            renormalize_logits=True,
        ).cpu()  # type: ignore
        assert len(output_ids.shape) == 2

        def decode_row(row, n_labels) -> list[PythonType]:
            return output_ids_as_types(row, n_labels)

        n_rows = output_ids.shape[0]
        if num_return_sequences is not None:
            # generate() lays out the sequences for each input row back to back.
            assert_eq(n_rows, num_return_sequences * len(n_labels))
        else:
            num_return_sequences = 1
        types = [
            decode_row(output_ids[i, :], n_labels[i // num_return_sequences])
            for i in range(n_rows)
        ]
        return types, output_ids

    @overload
    def predict(
        self, dataset: Dataset, tqdm_args: dict = {}, num_return_sequences: None = None
    ) -> list[list[PythonType]]:
        ...

    @overload
    def predict(
        self, dataset: Dataset, tqdm_args: dict, num_return_sequences: int
    ) -> list[list[list[PythonType]]]:
        ...

    def predict(
        self,
        dataset: Dataset,
        tqdm_args: dict = {},
        num_return_sequences: Optional[int] = None,
    ):
        """Run the model on the given dataset and return the predicted types
        (or multiple sequences of predicted types if num_return_sequences is not none) for each row."""
        model = self.model
        collator = DataCollatorForSeq2Seq(self.tokenizer, model)
        loader = dynamic_dataloader(
            dataset,  # type: ignore
            max_tokens=self.args.sampling_max_tokens,
            collate_fn=collator,
            shuffle=True,
        )
        device = model.device
        # we use this dict to keep the order of the chunks since it may be permuted by dynamic_dataloader
        pred_types = dict[int, list]()
        with tqdm(
            total=len(dataset), desc="predict", smoothing=0.01, **tqdm_args
        ) as tqdm_bar:
            for batch in loader:
                n_chunks = batch["input_ids"].shape[0]
                batch["input_ids"] = batch["input_ids"].to(device)
                preds, _ = self.predict_on_batch(batch, num_return_sequences)
                for i, c_id in enumerate(batch["chunk_id"]):
                    c_id = int(c_id)
                    if num_return_sequences is None:
                        pred_types[c_id] = preds[i]
                    else:
                        # Row i produced a contiguous group of sequences.
                        pred_types[c_id] = preds[
                            i * num_return_sequences : (i + 1) * num_return_sequences
                        ]
                tqdm_bar.update(n_chunks)
        # Restore the dataset's original chunk order.
        return [pred_types[int(c_id)] for c_id in dataset["chunk_id"]]

    def save(self, path: Path):
        """Save the model to the given path along with its tokenizer and args."""
        self.model.save_pretrained(str(path))
        self.tokenizer.save_pretrained(str(path))
        pickle_dump(path / "args.pkl", self.args)
        pickle_dump(path / "common_names.pkl", self.common_type_names)

    def to(self, device) -> "ModelWrapper":
        """Move the underlying model to `device` (in place) and return self."""
        self.model = self.model.to(device)
        return self

    @classmethod
    def load_from_hub(cls, repo_name: str):
        """Download a model snapshot from the HuggingFace hub and load it."""
        path = snapshot_download(repo_name)
        return cls.load(Path(path))

    @classmethod
    def load(cls, path: Path) -> "ModelWrapper":
        """Load a pretrained model from the given path."""
        model = cast(ModelType, ModelType.from_pretrained(str(path)))
        tokenizer = TokenizerType.from_pretrained(str(path))
        args = pickle_load(path / "args.pkl")
        common_type_names = ModelWrapper.load_common_type_names(path)
        return ModelWrapper(
            model=model,
            tokenizer=tokenizer,
            args=args,
            common_type_names=common_type_names,
            monitor=TaskLoggingMonitor(path.name),
        )

    @classmethod
    def load_common_type_names(cls, model_path: Path) -> set[str]:
        """Load the pickled common type names saved next to the model, if present."""
        if (model_path / "common_names.pkl").exists():
            return pickle_load(model_path / "common_names.pkl")
        else:
            return set()

    def eval_on_dataset(
        self,
        src_data: TokenizedSrcSet,
        max_labels: Optional[int] = None,
        tqdm_args: dict = {},
    ) -> DatasetPredResult:
        """Convenient method to preprocess the src according to the model's ctx_args and evaluate the (R0) accuracy."""
        ctx_args = self.args.ctx_args
        if max_labels is not None:
            # Override the label limit on a copy to avoid mutating shared args.
            ctx_args = copy(ctx_args)
            ctx_args.max_labels = max_labels

        chunks = src_data.to_chunks(ctx_args, tqdm_args=tqdm_args)
        preds = self.predict(
            chunks.data, num_return_sequences=None, tqdm_args=tqdm_args
        )
        return DatasetPredResult(chunks, preds)
255 |
256 |
def dynamic_dataloader(
    dataset: Dataset,
    max_tokens: int,
    collate_fn,
    shuffle: bool = False,
):
    """Build a `DataLoader` whose batch sizes adapt to example length.

    Examples are ordered longest-first and greedily packed so that each
    batch holds roughly `max_tokens` tokens (batch size = `max_tokens`
    divided by the longest example in the batch, at least 1). With
    `shuffle`, ties between equal-length examples and the order of the
    batches themselves are randomized.
    """
    sizes = [len(row) for row in dataset["input_ids"]]
    order = list(range(len(sizes)))
    if shuffle:
        random.shuffle(order)
    # Stable sort: after shuffling, equal-length examples stay in random order.
    order.sort(key=lambda i: sizes[i], reverse=True)
    batches = list[list[int]]()
    remaining = order
    while remaining:
        longest = sizes[remaining[0]]
        take = max(1, max_tokens // longest)
        batches.append(remaining[:take])
        remaining = remaining[take:]
    if shuffle:
        random.shuffle(batches)

    return DataLoader(
        cast(Any, dataset),
        batch_sampler=batches,
        collate_fn=collate_fn,
    )
282 |
--------------------------------------------------------------------------------
/src/typet5/tokenized_src.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import numpy
4 | from libcst.metadata import CodeRange, PositionProvider
5 |
6 | from .static_analysis import remove_comments, remove_imports, stub_from_module
7 | from .type_check import MypyFeedback, PythonType, parse_type_str
8 | from .type_env import (
9 | AnnotInfo,
10 | AnnotPath,
11 | CodePathManager,
12 | apply_annotations,
13 | collect_user_annotations,
14 | )
15 | from .utils import *
16 |
TokenSeq = list[int]  # a sequence of token ids; might need to make this more space-efficient
18 |
19 |
@dataclass
class TokenizedSrc:
    """A src file with certain type annotations masked out."""

    file: Path
    repo: Path
    types: list[PythonType]  # the masked-out (label) types, in source order
    types_pos: list[int]  # the position of the types in tokenized_code.
    types_str: list[str]  # source text of each masked type
    types_tks: list[TokenSeq]  # each masked type, tokenized
    types_info: list[AnnotInfo]  # path/range info for each annotation site
    main_code: str
    tokenized_code: TokenSeq  # with certain types masked out
    preamble_code: str
    tokenized_preamble: TokenSeq
    prev_types: dict[int, PythonType] | None = None  # previously predicted types
    inlined_spans: dict[int, slice] | None = None  # the spans of inlined previous types
    feedbacks: list[MypyFeedback] | None = None

    @staticmethod
    def parse(
        code: str,
        file: Path,
        repo: Path,
        args: "PreprocessArgs",
    ) -> "TokenizedSrc":
        """Preprocess and tokenize `code` into a `TokenizedSrc`."""
        d = preprocess_code(code, args)
        d["file"] = file
        d["repo"] = repo
        return tokenized_src_from_segs(**d)

    def __str__(self):
        segs = [
            "========TokenizedSrc========",
            f"file:{self.file}",
            f"repo:{self.repo}",
            "--------Preamble--------",
            self.preamble_code,
            "--------Main Code--------",
            decode_tokens(self.tokenized_code),
            "========End of TokenizedSrc========",
        ]
        return "\n".join(segs)

    def inline_prev_predictions(
        self, as_comment: bool, prev_types: dict[int, PythonType] | None = None
    ) -> "TokenizedSrc":
        "Inline the previous predictions into the code, either directly or as comments."
        if len(self.types) == 0:
            return copy.copy(self)

        if prev_types is None:
            prev_types = self.prev_types
        assert isinstance(prev_types, dict), f"prev_types has type: {type(prev_types)}"
        assert len(prev_types) > 0

        types_pos = list[int]()
        inlined_spans = dict[int, slice]()
        new_tks = list[int]()
        tokenizer = DefaultTokenizer
        mask_id = tokenizer.mask_token_id
        comment_start = tokenizer.encode("/* ", add_special_tokens=False)
        comment_end = tokenizer.encode(" */", add_special_tokens=False)

        # Walk the token stream, copying the segments between annotation
        # sites and splicing in each predicted type where available.
        start = 0

        for t in range(len(self.types)):
            new_tks.extend(self.tokenized_code[start : self.types_pos[t]])
            types_pos.append(len(new_tks))
            type_tk = self.tokenized_code[self.types_pos[t]]
            if t in prev_types:
                # A predicted site must currently hold the mask token.
                assert type_tk == mask_id
                to_insert = tokenizer.encode(
                    str(prev_types[t]), add_special_tokens=False
                )
                if as_comment:
                    to_insert = comment_start + to_insert + comment_end
                new_tks.extend(to_insert)
                inlined_spans[t] = slice(types_pos[t], len(new_tks))
            else:
                # No prediction for this site: keep the original token.
                new_tks.append(type_tk)
            start = self.types_pos[t] + 1
        new_tks.extend(self.tokenized_code[start:])

        assert prev_types.keys() == inlined_spans.keys()

        return TokenizedSrc(
            file=self.file,
            repo=self.repo,
            types=self.types,
            types_pos=types_pos,
            types_str=self.types_str,
            types_tks=self.types_tks,
            types_info=self.types_info,
            main_code=self.main_code,
            tokenized_code=new_tks,
            preamble_code=self.preamble_code,
            tokenized_preamble=self.tokenized_preamble,
            prev_types=prev_types,
            inlined_spans=inlined_spans,
            feedbacks=self.feedbacks,
        )

    def print_code(self, max_lines: int = 100, body_only: bool = False):
        "Print out the (decoded) token sequence"
        code = decode_tokens(self.tokenized_code)
        if not body_only:
            code = decode_tokens(self.tokenized_preamble) + code
        print_limited(code, max_lines)

    @staticmethod
    def inline_predictions(
        src: "TokenizedSrc",
        as_comment: bool,
        prev_types: dict[int, PythonType] | None = None,
    ):
        """Static convenience wrapper around `inline_prev_predictions`."""
        return src.inline_prev_predictions(as_comment=as_comment, prev_types=prev_types)
137 |
138 |
@dataclass
class PreprocessArgs:
    """Options controlling how source files are preprocessed before tokenization."""

    imports_in_preamble: bool = True
    stub_in_preamble: bool = True
    drop_comments: bool = True
    max_callees: int = 80  # only applicable to functional dataset
    max_callers: int = 20  # only applicable to functional dataset
    drop_env_types: bool = False  # only applicable to functional dataset
    add_override_usages: bool = False  # only applicable to functional dataset
    add_implicit_rel_imports: bool = True  # only applicable to functional dataset
149 |
150 |
def preprocess_code(code: str, args: PreprocessArgs) -> dict:
    """Preprocess the Python code to carve out all the type annotations. The original
    code is split into sequences at the type annotations.

    Returns a dict of keyword arguments suitable for `tokenized_src_from_segs`
    (minus `file` and `repo`, which the caller must add).
    """
    m = cst.parse_module(code)
    preamble_segs = list[str]()

    if args.drop_comments:
        m = remove_comments(m)
    if args.imports_in_preamble:
        # Imports are removed from the main code and moved into the preamble.
        m, imports = remove_imports(m)
        imports_part = cst.Module([cst.SimpleStatementLine([s]) for s in imports])
        preamble_segs.append(imports_part.code)
    if args.stub_in_preamble:
        # A stub of the module provides extra context in the preamble.
        stub_m = stub_from_module(m)
        preamble_segs.append(stub_m.code)

    cst_code = m.code
    annots_info, types = collect_user_annotations(m)
    types_str = [
        m.code_for_node(not_none(info.annot).annotation) for info in annots_info
    ]
    # Replace every user annotation with the special mask name, then split
    # the code at the mask: n annotations yield n + 1 code segments.
    mask_annot = cst.Annotation(cst.Name(SpecialNames.TypeMask))
    replaces = dict()
    for info in annots_info:
        replaces[info.path] = mask_annot
    new_code = apply_annotations(m, replaces).code
    code_segs = new_code.split(SpecialNames.TypeMask)

    assert (
        len(code_segs) == len(types) + 1
    ), f"{len(code_segs)} != {len(types) + 1}. replaces: {replaces}\ncode: {new_code}"
    return {
        "preamble": "".join(preamble_segs),
        "code_segs": code_segs,
        "types": types,
        "types_str": types_str,
        "annots_info": annots_info,
        "cst_code": cst_code,
        "prev_types": None,
    }
191 |
192 |
def tokenized_src_from_segs(
    file: Path,
    repo: Path,
    cst_code: str,
    preamble: str,
    code_segs: list[str],
    types: list[PythonType],
    types_str: list[str],
    annots_info: list[AnnotInfo],
    tokenized_preamble: list[int] | None = None,
    prev_types: dict[int, PythonType] | None = None,
    is_label: list[bool] | None = None,
    left_extra_tks: list[int] | None = None,
    right_extra_tks: list[int] | None = None,
) -> TokenizedSrc:
    """Assemble a `TokenizedSrc` from code segments and the types between them.

    `code_segs` must contain exactly `len(types) + 1` pieces (the code
    surrounding each annotation site). A site is treated as a prediction
    label when `is_label` is None (all sites) or `is_label[i]` is true: it
    is replaced by the tokenizer's mask token and recorded in the `types*`
    fields. Non-label sites keep their type text inline. Optional extra
    tokens can be prepended/appended to the tokenized main code.
    """
    tkn = DefaultTokenizer
    r = TokenizedSrc(
        file=file,
        repo=repo,
        main_code=cst_code,
        tokenized_code=list[int](),
        preamble_code=preamble,
        tokenized_preamble=tkn.encode(preamble, add_special_tokens=False)
        if tokenized_preamble is None
        else tokenized_preamble,
        types=list[PythonType](),
        types_pos=list[int](),
        types_str=list[str](),
        types_info=list[AnnotInfo](),
        types_tks=list[list[int]](),
        prev_types=prev_types,
    )

    mask_id = not_none(tkn.mask_token_id)
    all_tks = r.tokenized_code
    if left_extra_tks:
        all_tks.extend(left_extra_tks)
    for i in range(len(code_segs) - 1):
        all_tks.extend(tkn.encode(code_segs[i], add_special_tokens=False))
        if is_label is None or is_label[i]:
            # Label site: record its position and contents, emit a mask token.
            r.types_pos.append(len(all_tks))
            r.types.append(types[i])
            r.types_tks.append(tkn.encode(str(types[i]), add_special_tokens=False))
            r.types_str.append(types_str[i])
            r.types_info.append(annots_info[i])
            all_tks.append(mask_id)
        else:
            # Non-label site: keep the type text inline in the token stream.
            all_tks.extend(tkn.encode(types_str[i], add_special_tokens=False))
    all_tks.extend(tkn.encode(code_segs[-1], add_special_tokens=False))
    if right_extra_tks:
        all_tks.extend(right_extra_tks)

    return r
246 |
247 |
def feedbacks_to_tokenized_src(
    src: TokenizedSrc,
    current_code: str,
    feedbacks: list[MypyFeedback],
    patch_predictions: bool = False,
) -> TokenizedSrc:
    """Re-tokenize `src` with mypy feedback inlined as comments.

    `current_code` is the source with the previous round of predictions
    applied. Each previously predicted annotation is masked out again
    (optionally keeping the old prediction as a `/* ... */` comment when
    `patch_predictions` is set), every mypy error message is inserted at
    its reported position, and the result is tokenized into a new
    `TokenizedSrc` whose `feedbacks` field is set.
    """
    try:
        m = cst.parse_module(current_code)
    except Exception as e:
        raise RuntimeError(
            f"Failed to parse file: '{src.file}' with content:\n{current_code}"
        ) from e
    m_code = m.code
    # Parsing must round-trip, otherwise the recorded positions would be off.
    assert (
        m_code.rstrip() == current_code.rstrip()
    ), f"String diffferences: {show_string_diff(current_code, m_code)}"
    current_annots, _ = collect_user_annotations(m)
    preds_map = dict[CodeRange, str]()
    types = list[PythonType]()
    prev_types = dict[int, PythonType]()
    types_str = list[str]()
    annots_info = list[AnnotInfo]()
    # Match the current annotations back to the original label sites by path.
    path2label_id = {info.path: i for i, info in enumerate(src.types_info)}

    for a in current_annots:
        if a.path in path2label_id:
            assert (r := a.annot_range) is not None
            assert (annot := a.annot) is not None
            prev_type = preds_map[r] = m.code_for_node(annot.annotation)
            li = path2label_id[a.path]
            prev_types[li] = parse_type_str(prev_type)
            types.append(src.types[li])
            types_str.append(src.types_str[li])
            annots_info.append(a)
    pos_to_msg = {f.position: f.message for f in feedbacks}
    new_code = patch_code_with_extra(
        current_code, preds_map, pos_to_msg, patch_predictions
    )
    code_segs = new_code.split(SpecialNames.TypeMask)
    assert (
        len(code_segs) == len(types) + 1
    ), f"{len(code_segs)} != {len(types)} + 1.\nNew Code:\n{new_code}"

    new_src = tokenized_src_from_segs(
        file=src.file,
        repo=src.repo,
        cst_code=new_code,
        preamble=src.preamble_code,
        tokenized_preamble=src.tokenized_preamble,
        code_segs=code_segs,
        types=types,
        types_str=types_str,
        annots_info=annots_info,
        prev_types=prev_types,
    )
    new_src.feedbacks = feedbacks
    return new_src
305 |
306 |
def patch_code_with_extra(
    code: str,
    predictions: dict[CodeRange, str],
    errors: dict[CodePosition, str],
    patch_predictions: bool,
) -> str:
    """Mask out predicted annotations in `code` and inline extra info as comments.

    Each predicted range is replaced by the type-mask token; when
    `patch_predictions` is set, the previous prediction is inserted just
    before it as a `/* ... */` comment. Error messages are inserted as
    `/* error: ... */` comments at their reported positions.
    """
    # When the ranges overlap, we want to use the order:
    # new_prediction -> prev_prediction -> errors (hence the middle integer,
    # presumably a tie-breaking priority for replace_strs_by_pos).
    edits = []
    for rng, prev_pred in predictions.items():
        edits.append((rng, 1, SpecialNames.TypeMask))
        if patch_predictions:
            edits.append((CodeRange(rng.start, rng.start), 2, f"/* {prev_pred} */"))
    for pos, msg in errors.items():
        edits.append((CodeRange(pos, pos), 3, f"/* error: {msg} */"))
    return replace_strs_by_pos(code, edits)
324 |
--------------------------------------------------------------------------------
/src/typet5/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import warnings
4 | from dataclasses import dataclass
5 | from pathlib import Path
6 | from typing import *
7 |
8 | import pytorch_lightning as pl
9 | import torch
10 | import torch.nn as nn
11 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
12 | from pytorch_lightning.loggers import WandbLogger
13 | from torch.optim import AdamW
14 | from transformers import DataCollatorForSeq2Seq
15 | from transformers.modeling_outputs import Seq2SeqLMOutput
16 |
17 | from .data import ChunkedDataset, TokenizedSrcSet
18 | from .model import (
19 | CtxArgs,
20 | DecodingArgs,
21 | ModelType,
22 | ModelWrapper,
23 | TokenizerType,
24 | dynamic_dataloader,
25 | )
26 | from .tokenized_src import PreprocessArgs
27 | from .type_check import TypeCheckArgs
28 | from .utils import *
29 |
30 |
@dataclass
class ModelTrainingArgs:
    """Hyperparameters consumed by `train_spot_model`."""

    train_ctx_args: CtxArgs  # context window used to chunk the training data
    dec_args: DecodingArgs  # decoding settings stored with the trained model
    train_max_tokens: int  # per-batch token budget during training
    eval_max_tokens: int  # per-batch token budget during validation
    max_epochs: int
    tc_args: TypeCheckArgs
    accumulate_grad_batches: int | dict | None = None  # forwarded to pl.Trainer
40 |
41 |
class TrainingConfig(NamedTuple):
    """Hyperparameters identifying a training run; its repr doubles as the model name."""

    quicktest: bool = False
    func_only: bool = True  # whether to use functional dataset format
    pre_args: PreprocessArgs = PreprocessArgs()
    trained_on: str = "ManyTypes4Py"
    data_reduction: int = 1
    check_in_isolation: bool = False  # DAgger
    inline_prev_gold: bool = False
    ctx_size: int = 4096
    left_margin: int = 2048
    # up to how much of the left_margin to be allocated as preamble
    preamble_size: int = 1000
    right_margin: int = 2048 - 512
    train_max_labels: int = 32
    dec_max_labels: int = 16
    use_small_model: bool = False
    # NOTE: unannotated, so this is a class attribute rather than a NamedTuple
    # field; it is therefore excluded from as_dict (which uses __annotations__).
    grad_accum_labels = 32  # DAgger
    modifications: str = ""

    def as_dict(self) -> dict[str, Any]:
        """Return the annotated fields as a name -> value dict."""
        return {attr: getattr(self, attr) for attr in self.__annotations__}

    def as_name(self) -> str:
        return self.get_model_name()

    def __repr__(self):
        return repr_modified_args(self, flatten=True)

    def get_model_name(self) -> str:
        """Model name: version prefix plus the non-default config values."""
        return "model-v7--" + repr_modified_args(self, flatten=True)

    def train_ctx_args(self) -> CtxArgs:
        """Context-window settings used during training."""
        return CtxArgs(
            ctx_size=self.ctx_size,
            preamble_size=self.preamble_size,
            left_margin=self.left_margin,
            right_margin=self.right_margin,
            max_labels=self.train_max_labels,
            inline_prev_gold=self.inline_prev_gold,
        )

    def get_preprocess_args(self):
        return self.pre_args

    def dec_ctx_args(self) -> CtxArgs:
        """Context-window settings used during decoding (fewer labels per chunk)."""
        r = self.train_ctx_args()
        r.max_labels = self.dec_max_labels
        return r
90 |
91 |
def train_spot_model(
    tk_dataset: dict[str, TokenizedSrcSet],
    model_name: str,
    train_args: ModelTrainingArgs,
    gpus: list[int],
    quicktest=False,
    use_early_stop=False,
    use_small_model=False,
) -> ModelWrapper:
    """Fine-tune a CodeT5 model on the given tokenized datasets.

    Trains from the pretrained CodeT5 checkpoint with PyTorch Lightning,
    logging to the already-initialized wandb run, optionally early-stopping
    on the validation loss, and saving the final model (with tokenizer,
    decoding args, and common type names) under the model directory.

    Args:
        tk_dataset: must contain "train" and "valid" TokenizedSrcSets.
        model_name: directory name used for checkpoints and the saved model.
        train_args: chunking/decoding/training hyperparameters.
        gpus: GPU device ids; an empty list trains on CPU.
        quicktest: validate every step and make callbacks verbose.
        use_early_stop: stop on plateauing validation loss and reload the
            best checkpoint if it beats the final weights.
        use_small_model: start from codet5-small instead of codet5-base.

    Returns:
        A ModelWrapper around the trained model.
    """
    os.chdir(proj_root())
    train_ctx_args = train_args.train_ctx_args
    dec_args = train_args.dec_args

    # Train inside a scratch directory; the final model is copied out at the end.
    running_dir = get_model_dir(False) / model_name
    if running_dir.exists():
        shutil.rmtree(running_dir)
    running_dir.mkdir(parents=True, exist_ok=True)

    print("Disk space left:")
    subprocess.run(["df", "-h", str(running_dir)])

    model_path = ModelWrapper.get_codet5_path(use_small_model)
    lit_model = TrainModelWrapper(model_path, model_saving_path=running_dir / "ckpts")
    tokenizer: TokenizerType = lit_model.tokenizer

    common_type_names = tk_dataset["train"].common_type_names()
    wrapper = ModelWrapper(
        lit_model.model, tokenizer, dec_args, common_type_names=common_type_names
    )

    chunks: dict[str, ChunkedDataset] = {}
    with run_long_task("Preparing chunked datasets", notify=False):
        for n in ["valid", "train"]:
            src = tk_dataset[n]
            chunks[n] = src.to_chunks(train_ctx_args)

    wandb_logger = WandbLogger()  # assuming a run has already been initialized

    collate_fn = DataCollatorForSeq2Seq(lit_model.tokenizer, lit_model.model)
    train_dataloader = dynamic_dataloader(
        cast(Any, chunks["train"].data),
        max_tokens=train_args.train_max_tokens,
        collate_fn=collate_fn,
        shuffle=True,
    )
    valid_dataloader = dynamic_dataloader(
        cast(Any, chunks["valid"].data),
        max_tokens=train_args.eval_max_tokens,
        collate_fn=collate_fn,
        shuffle=True,  # doesn't hurt
    )

    # Validate roughly ten times per epoch, but not more often than every
    # 500 steps (except in quicktest mode).
    ckpt_interval = max(1, len(train_dataloader) // 10)
    val_interval = 1 if quicktest else max(500, ckpt_interval)

    checkpoint_cb = ModelCheckpoint(
        dirpath=running_dir,
        save_top_k=3,
        monitor="valid/loss",
        mode="min",
        save_on_train_epoch_end=False,
        verbose=quicktest,
    )

    trainer = pl.Trainer(
        default_root_dir=str(running_dir),
        # fast_dev_run=6 if quicktest else False,
        # log_every_n_steps=500,
        accelerator="gpu" if gpus else "cpu",
        devices=gpus,
        precision=16,
        max_epochs=train_args.max_epochs,
        logger=wandb_logger,
        val_check_interval=val_interval,
        callbacks=(
            [checkpoint_cb, EarlyStopping("valid/loss", mode="min", verbose=quicktest)]
            if use_early_stop
            else []
        ),
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
        accumulate_grad_batches=train_args.accumulate_grad_batches,
        # track_grad_norm=2,
    )

    warnings.filterwarnings("ignore", "The dataloader.*does not have many workers.*")

    with run_long_task(f"Training {model_name}", notify=False):
        trainer.fit(
            model=lit_model,
            train_dataloaders=train_dataloader,
            val_dataloaders=valid_dataloader,
        )

    save_dir = get_model_dir(True) / model_name

    final_eval = trainer.validate(model=lit_model, dataloaders=valid_dataloader)[0]

    try:
        # If early stopping kept a checkpoint better than the final weights,
        # reload it before saving.
        if (
            use_early_stop
            and (best_loss := checkpoint_cb.best_model_score) is not None
            and best_loss < final_eval["valid/loss"]
        ):
            print(
                f"Loading best model with score {best_loss} from: {checkpoint_cb.best_model_path}"
            )
            wrapper.model = TrainModelWrapper.load_from_checkpoint(
                checkpoint_cb.best_model_path
            ).model
        if save_dir.exists():
            shutil.rmtree(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        wrapper.save(save_dir)
        shutil.rmtree(running_dir)
    except Exception as e:
        # BUGFIX: the message previously had no %-placeholder, so passing `e`
        # as a lazy formatting argument made logging's internal formatting
        # fail and the message was lost.
        logging.error(
            "Error encountered after training, returning partial results... Error:\n%s",
            e,
        )

    return wrapper
214 |
215 |
class TrainModelWrapper(pl.LightningModule):
    "A pytorch lightning module that handles training and evaluation of the SPOT model."

    def __init__(
        self, model_checkpoint: str | Path, *, model_saving_path: Path
    ) -> None:
        """Load the model and tokenizer stored at `model_checkpoint`.

        model_saving_path: directory under which intermediate model snapshots
        are written when `model_saving_interval` is set (see `training_step`).
        """
        super().__init__()
        self.save_hyperparameters()
        self.model: ModelType = load_model_spot(model_checkpoint)
        self.tokenizer: TokenizerType = TokenizerType.from_pretrained(model_checkpoint)
        self.model_saving_path = model_saving_path
        # When not None, snapshot the model every `model_saving_interval`
        # training batches during the first epoch.
        self.model_saving_interval: Optional[int] = None
        # Exponential moving average of the training loss, logged as "train/loss".
        self.avg_loss = MovingAvg(alpha=0.01)
        # Total number of labels seen so far; logged as "train/labels".
        self.labels_trained = 0

    def on_fit_start(self):
        # maps chunk id to the initial predictions made for that chunk immediately
        # before the model was trained on it
        if self.model_saving_interval is not None:
            self.batch_ids: list[list[int]] = []
            self.saving_counter = 0
            # Save the untrained model as the baseline snapshot.
            # (fix: dropped the extraneous `f` prefix — the literal has no placeholder)
            self.model.save_pretrained(self.model_saving_path / "n_batches=0")

    def configure_optimizers(self):
        return _configure_optimizers(self.model)

    def training_step(self, batch, batch_idx):
        # During the first epoch, record which chunks each snapshot has seen and
        # periodically save the model so later analysis can replay training.
        if self.model_saving_interval is not None and self.current_epoch == 0:
            self.batch_ids.append(batch["chunk_id"].tolist())
            self.saving_counter += 1
            if self.saving_counter >= self.model_saving_interval:
                self.saving_counter = 0
                # model can be used for `n_batches` and onward.
                self.model.save_pretrained(
                    self.model_saving_path / f"n_batches={len(self.batch_ids)}"
                )

        outputs = self.model.forward(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        assert isinstance(outputs, Seq2SeqLMOutput)
        loss = not_none(outputs.loss)
        n_labels = batch["n_labels"].sum().item()
        self.labels_trained += n_labels
        self.avg_loss.update(loss.item())
        self.log("train/loss", self.avg_loss.value)
        self.log("train/lr", self.lr_schedulers().get_last_lr()[0])  # type: ignore
        self.log("train/labels", float(self.labels_trained))
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        loss = outputs.loss
        self.log("valid/loss", loss.item())
        # NOTE(review): logs a "train/"-prefixed metric from the validation step —
        # presumably to record training progress at each validation point; confirm intended.
        self.log("train/labels", float(self.labels_trained))
277 |
278 |
def concat_batches(batches: list[dict], keys: list[str]) -> dict:
    """Merge batch dicts: for each key in `keys`, concatenate the tensors
    taken from every batch (in order) along the first dimension."""
    merged: dict = {}
    for key in keys:
        merged[key] = torch.concat([batch[key] for batch in batches])
    return merged
281 |
282 |
283 | def _configure_optimizers(model: nn.Module, base_lr: float = 2e-5):
284 | no_decay = ["bias", "LayerNorm.weight"]
285 | grouped_params = [
286 | {
287 | "params": [
288 | p
289 | for pn, p in model.named_parameters()
290 | if not any(n in pn for n in no_decay)
291 | ],
292 | "weight_decay": 0.01,
293 | },
294 | {
295 | "params": [
296 | p
297 | for pn, p in model.named_parameters()
298 | if any(n in pn for n in no_decay)
299 | ],
300 | "weight_decay": 0.0,
301 | },
302 | ]
303 | optimizer = AdamW(grouped_params, lr=base_lr)
304 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.2)
305 | return [optimizer], [lr_scheduler]
306 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utopia-group/TypeT5/d8ff8638f4d00f03042db5780a8d4fa09a72916d/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_func_decoding.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from pathlib import Path
3 |
4 | import pytest
5 |
6 | from typet5.function_dataset import FunctionSignature
7 | from typet5.static_analysis import cst
8 | from typet5.static_analysis import to_abs_import_path as to_abs
9 | from typet5.utils import *
10 | from typet5.utils import assert_eq, groupby, not_none, show_string_diff
11 |
12 |
def test_function_signature():
    """Annotating a single parameter (or the return) through FunctionSignature
    should change only that one slot in the rendered signature."""
    ex_code = """
def f(x, y: int=3, *, v=3, **kwargs) -> int:
    u: int
    return 1
"""

    f = cast(cst.FunctionDef, cst.parse_module(ex_code).body[0])
    sig = FunctionSignature.from_function(f, False)

    # Expected header after annotating each parameter with `list[int]`.
    # (Was four near-identical copy/paste blocks; now data-driven.)
    expected_headers = {
        "x": "def f(x: list[int], y: int=3, *, v=3, **kwargs) -> int",
        "y": "def f(x, y: list[int]=3, *, v=3, **kwargs) -> int",
        "v": "def f(x, y: int=3, *, v: list[int]=3, **kwargs) -> int",
        "kwargs": "def f(x, y: int=3, *, v=3, **kwargs: list[int]) -> int",
    }
    for param, expected in expected_headers.items():
        new_sig = copy.deepcopy(sig)
        new_sig.params[param] = cst.Annotation(cst.parse_expression("list[int]"))
        assert expected in cst.Module([new_sig.apply(f)]).code

    # Annotating the return type.
    new_sig = copy.deepcopy(sig)
    new_sig.returns = cst.Annotation(cst.parse_expression("list[int]"))
    assert (
        "def f(x, y: int=3, *, v=3, **kwargs) -> list[int]"
        in cst.Module([new_sig.apply(f)]).code
    )
57 |
58 |
def test_method_signature():
    """FunctionSignature.from_function counts declared parameters, excluding a
    parameter literally named `self` but nothing else."""
    # (code, expected number of collected parameters) — was three
    # near-identical parse/assert blocks, now data-driven.
    cases: list[tuple[str, int]] = [
        # `self` is not counted as a parameter.
        (
            """
def f(self, x, y):
    u: int
    return 1
""",
            2,
        ),
        # An ordinary first parameter IS counted.
        (
            """
def f(a, x, y):
    u: int
    return 1
""",
            3,
        ),
        # A lambda default must not confuse parameter collection.
        (
            """
def f(a=lambda x: x):
    return 1
""",
            1,
        ),
    ]
    for code, expected_n_params in cases:
        f = cast(cst.FunctionDef, cst.parse_module(code).body[0])
        sig = FunctionSignature.from_function(f, False)
        assert len(sig.params) == expected_n_params
88 |
--------------------------------------------------------------------------------
/tests/test_model_creation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoConfig
3 | from transformers.models.t5.configuration_t5 import T5Config
4 |
5 | from typet5.data import CtxArgs
6 | from typet5.model import DecodingArgs, ModelType, ModelWrapper
7 | from typet5.train import TrainingConfig
8 | from typet5.utils import DefaultTokenizer
9 |
# Prefer the GPU when available; all tensors/models in these tests are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11 |
12 |
def test_basic_torch_operations():
    """Smoke test: tensor creation and transfer to `device` work."""
    x = torch.randn(10).to(device)
    # The previous check (`x.sum() <= 10`) was flaky: the sum of 10 standard
    # normals exceeds 10 with probability ~8e-4. Assert deterministic
    # invariants instead.
    assert x.shape == (10,)
    assert bool(torch.isfinite(x).all())
16 |
17 |
def test_model_creation():
    """End-to-end smoke test: build a CodeT5 model from its config, wrap it
    in ModelWrapper, and run a single prediction batch."""
    config = AutoConfig.from_pretrained(ModelWrapper.get_codet5_path())
    model = ModelType(config).to(device)

    ctx_args = TrainingConfig().dec_ctx_args()
    dec_args = DecodingArgs(ctx_args, 32)
    wrapper = ModelWrapper(model, DefaultTokenizer, dec_args, set())
    wrapper.to(device)

    ids = DefaultTokenizer.encode(
        "def get_count() -> : ...", return_tensors="pt"
    )
    batch = {"input_ids": ids, "n_labels": [1]}
    out = wrapper.predict_on_batch(batch)
    # The original `assert True` checked nothing; at minimum require that
    # prediction produced a result (NOTE(review): presumably never None —
    # tighten this if predict_on_batch's contract allows).
    assert out is not None
33 |
--------------------------------------------------------------------------------
/tests/test_type_env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from pathlib import Path
4 |
5 | import pytest
6 |
7 | from typet5.static_analysis import FunctionSignature, mask_types
8 | from typet5.tokenized_src import PreprocessArgs
9 | from typet5.type_check import MypyResult, PythonType, remove_top_optional
10 | from typet5.type_env import (
11 | AnnotCat,
12 | AnnotPath,
13 | AnyAnnot,
14 | SelectAnnotations,
15 | TypeInfAction,
16 | annot_path,
17 | apply_annotations,
18 | collect_annots_info,
19 | collect_user_annotations,
20 | mypy_checker,
21 | normalize_type,
22 | parse_type_str,
23 | type_inf_env,
24 | )
25 | from typet5.utils import (
26 | SpecialNames,
27 | as_any,
28 | assert_eq,
29 | cst,
30 | proj_root,
31 | read_file,
32 | write_file,
33 | )
34 |
35 | os.chdir(proj_root())
36 |
37 |
def test_annotation_collection():
    """Every annotation in the example file is collected with the right
    category, and nothing extra is reported."""
    parsed = cst.parse_module(read_file("data/code/env_code_2.py"))
    annots = collect_annots_info(parsed)
    found = [(a.path, a.cat) for a in annots]
    expected: list[tuple[AnnotPath, AnnotCat]] = [
        (annot_path("fib", "n"), AnnotCat.FuncArg),
        (annot_path("fib", SpecialNames.Return), AnnotCat.FuncReturn),
        (annot_path("foo", "bar"), AnnotCat.FuncArg),
        (annot_path("foo", SpecialNames.Return), AnnotCat.FuncReturn),
        (annot_path("Bar", "z"), AnnotCat.ClassAtribute),
        (annot_path("Bar", "w"), AnnotCat.ClassAtribute),
        (annot_path("Bar", "__init__", SpecialNames.Return), AnnotCat.FuncReturn),
        (annot_path("Bar", "__init__", "x"), AnnotCat.FuncArg),
        (annot_path("Bar", "__init__", "self.x"), AnnotCat.ClassAtribute),
        (annot_path("Bar", "__init__", "self.y"), AnnotCat.ClassAtribute),
        (annot_path("Bar", "reset", "w0"), AnnotCat.FuncArg),
        (annot_path("Bar", "reset", SpecialNames.Return), AnnotCat.FuncReturn),
        (annot_path("Bar", "foo", "z"), AnnotCat.FuncArg),
        (annot_path("Bar", "foo", SpecialNames.Return), AnnotCat.FuncReturn),
        (annot_path("bar"), AnnotCat.GlobalVar),
    ]
    # Mutual inclusion: the collected pairs and the expected pairs must match.
    for item in expected:
        assert item in found
    for item in found:
        assert item in expected
63 |
64 |
def test_self_parameter_annotation():
    """An annotation on `self` is ignored by the user-annotation collector and
    by FunctionSignature, while other annotations are still picked up."""
    code = """
def foo(self: float, x: int) -> str:
    return "1"
"""
    module = cst.parse_module(code)
    _, types = collect_user_annotations(module)

    # `self: float` is skipped; only `x` and the return annotation remain.
    assert_eq(types, [PythonType.from_name("int"), PythonType.from_name("str")])

    # Masking must yield exactly one mask token per collected annotation.
    masked_segments = mask_types(module).code.split(SpecialNames.TypeMask)
    assert_eq(len(masked_segments), len(types) + 1)

    sig = FunctionSignature.from_function(as_any(module.body[0]), False)
    assert len(sig.params) == len(types) - 1
79 |
80 |
# Shared fixture: the parsed contents of a deliberately ill-typed example file,
# reused by the annotation-applying and mypy-checker tests below.
parsed = cst.parse_module(read_file("data/code/bad_code_1.py"))


# Annotation patch applied to `parsed` in the tests below: sets/overrides the
# annotations at the given paths.
code_1_patch = {
    annot_path("fib", "n"): cst.Annotation(cst.Name("int")),
    annot_path("fib", SpecialNames.Return): cst.Annotation(cst.Name("int")),
    annot_path("t_add", SpecialNames.Return): cst.Annotation(cst.Name("str")),
    annot_path("bad_y"): AnyAnnot,
}
90 |
91 |
def test_annotation_applying():
    """apply_annotations replaces each annotation targeted by `code_1_patch`."""
    before = {
        a.path: a.annot for a in collect_annots_info(parsed) if a.annot is not None
    }
    patched_module = apply_annotations(parsed, code_1_patch)
    after = {
        a.path: a.annot
        for a in collect_annots_info(patched_module)
        if a.annot is not None
    }

    for path, annot in code_1_patch.items():
        # The annotation at this path actually changed...
        assert before[path].annotation.value != after[path].annotation.value  # type: ignore
        # ...and now matches the patch.
        assert after[path].annotation.value == annot.annotation.value  # type: ignore
102 |
103 |
@pytest.mark.skip("Not used.")
def test_mypy_checker_1():
    """Both known-bad example files appear in a full mypy project check."""
    with mypy_checker(Path("data/code"), wait_before_check=0.0) as checker:
        result = checker.recheck_project()
        assert isinstance(result, MypyResult)
        for bad_file in ("data/code/bad_code_1.py", "data/code/bad_code_2.py"):
            assert Path(bad_file).resolve() in result.error_dict
111 |
112 |
@pytest.mark.skip("Not used.")
def test_mypy_checker_2():
    """Adding the un-annotated bad file should raise the mypy error count above
    the baseline, and applying `code_1_patch` should bring it back down.

    NOTE(review): mutates files under data/code_output — presumably a scratch
    directory; confirm it is expendable.
    """
    with mypy_checker(Path("data/code_output"), wait_before_check=0.0) as checker:
        # Start from a state without the bad file to establish a baseline.
        if Path("data/code_output/bad_code_1.py").exists():
            os.remove("data/code_output/bad_code_1.py")
        oe = checker.recheck_project().num_errors
        # Writing the un-annotated bad code should introduce new errors.
        write_file("data/code_output/bad_code_1.py", parsed.code)
        assert checker.recheck_project().num_errors > oe
        # Applying the annotation patch should remove exactly those errors.
        new_code = apply_annotations(parsed, code_1_patch).code
        write_file(
            "data/code_output/bad_code_1.py",
            new_code,
        )
        c_r = checker.recheck_project()
        assert c_r.num_errors == oe, f"mypy_output: {c_r.output_str}\ncode: {new_code}"
128 |
129 |
def test_type_parsing():
    """Quoted (forward-reference style) types parse the same as unquoted ones."""
    quoted_cases = [
        ("'Foo[int]'", "Foo[int]"),
        ('"Bar"', "Bar"),
    ]
    for quoted, plain in quoted_cases:
        assert parse_type_str(quoted) == parse_type_str(plain)
134 |
135 |
def test_type_normalization():
    """normalize_type maps equivalent type expressions to one canonical form
    while keeping genuinely different types apart."""

    def norm(type_str: str):
        # Parse-then-normalize, shared by both loops below.
        return normalize_type(parse_type_str(type_str))

    equiv_pairs: list[tuple[str, str]] = [
        ("list[int]", "List[int]"),
        ("dict[str, list]", "Dict[str, List]"),
        ("'Foo[int]'", "Foo[int]"),
        ("typing.Union[str, List]", "typing.Union[list, str]"),
        ("typing.Union[str, typing.Union[str, int]]", "str | int"),
        ("typing.Union[str, float, typing.Union[str, int]]", "str | int | float"),
        ("Union[str, float, None]", "Optional[Union[str, float]]"),
        ("str | None", "Optional[str]"),
        ("Any | None", "Optional"),
        ("List[Any]", "List"),
        ("Dict[Any, Any]", "Dict"),
    ]
    for left, right in equiv_pairs:
        assert norm(left) == norm(right)

    nonequiv_pairs: list[tuple[str, str]] = [
        ("Union[str, int]", "Union[str, list]"),
        ("typing.List[str]", "t.List[str]"),
        ("tuple[str, int]", "tuple[int, str]"),
        ("Dict[str, Any]", "Dict"),
    ]
    for left, right in nonequiv_pairs:
        assert norm(left) != norm(right)

    # A top-level Optional can be stripped after normalization.
    dict_ex = PythonType.from_str("Dict[Any, Any] | None")
    assert remove_top_optional(normalize_type(dict_ex)) == PythonType.from_name("Dict")

    # `Dict[str, Any]` is already in normal form.
    dict2 = PythonType.from_str("Dict[str, Any]")
    assert normalize_type(dict2) == dict2
173 |
174 |
175 | import shutil
176 |
177 | from typet5.data import TokenizedSrcSet, type_check_src, type_check_src_in_project
178 | from typet5.utils import load_tokenizer_spot, proj_root
179 |
180 |
@pytest.mark.skip("Not using type checker for the moment")
def test_mypy_checking():
    """Exercise the dataset-level type-checking helpers on the small example repo.

    Covers three entry points: `type_check_src` on a single source,
    `type_check_src_in_project` inside a prepared checking environment, and
    `TokenizedSrcSet.type_check_each_file_in_project` over several files.
    """
    simple_dataset = TokenizedSrcSet.from_repos(
        proj_root() / "data",
        [proj_root() / "data/code"],
        PreprocessArgs(drop_comments=True),
    )

    # Filling label 0 with `int` is expected to type-check with no feedbacks.
    src_to_check = simple_dataset.get_src_by_file(Path("code/bad_code_2.py"))
    result_1 = type_check_src(src_to_check, {0: "int"})
    assert len(result_1.feedbacks) == 0

    # Remove any stale checking environment left over from previous runs.
    temp_dir = proj_root() / "mypy_temp/test_dir"
    shutil.rmtree(temp_dir, ignore_errors=True)

    with simple_dataset.setup_typechecking(
        [src_to_check],
        cleanup=True,
        skip_pre_fdbks=True,
    ) as env:
        # Same check, but run inside the prepared project template.
        result_2 = type_check_src_in_project(
            src_to_check,
            {0: "int"},
            (env.template_root / "code"),
            "Skip",
        )
        assert isinstance(result_2.feedbacks, list) and len(result_2.feedbacks) == 0

    # Batch variant: check several (file, prediction) pairs in one project.
    rs = simple_dataset.type_check_each_file_in_project(
        [
            ((simple_dataset.repos_root / "code/bad_code_2.py"), {0: "int"}),
            ((simple_dataset.repos_root / "code/bad_code_1.py"), {0: "str"}),
        ],
    )
    fdbks3 = rs[0].feedbacks
    assert isinstance(fdbks3, list) and len(fdbks3) == 0
    # assert (
    #     'Argument 1 to "fib" has incompatible type "int"; expected "str"'
    #     in fdbks3[0].message
    # )
221 |
--------------------------------------------------------------------------------