├── .coveragerc ├── .flake8 ├── .gitignore ├── .gitlab-ci.yml ├── .pre-commit-config.yaml ├── .pypirc ├── CHANGELOG.md ├── GrammaTech-CLA-retypd.pdf ├── LICENSE.md ├── README.md ├── delete_remote_packages.py ├── pytest.ini ├── reference ├── contents.md ├── paper.pdf ├── presentation_slides.pdf └── type-recovery.rst ├── requirements-dev.txt ├── requirements.txt ├── setup.py ├── src ├── __init__.py ├── c_type_generator.py ├── c_types.py ├── clattice.py ├── dummylattice.py ├── fast_enfa.py ├── graph.py ├── graph_solver.py ├── loggable.py ├── parser.py ├── pathexpr.py ├── schema.py ├── sketches.py ├── solver.py └── version.py └── test ├── test_endtoend.py ├── test_graph.py ├── test_internals.py ├── test_pathexpr.py ├── test_regressions.py └── test_sketches.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | 3 | show_missing = True 4 | 5 | exclude_lines = 6 | if __name__ == .__main__.: 7 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | # .flake8 2 | # 3 | # DESCRIPTION 4 | # Configuration file for the python linter flake8. 5 | # 6 | # This configuration is based on the generic 7 | # configuration published on GitHub. 8 | # 9 | # AUTHOR 10 | # krnd 11 | # 12 | # VERSION 13 | # 1.0 14 | # 15 | # SEE ALSO 16 | # http://flake8.pycqa.org/en/latest/user/options.html 17 | # http://flake8.pycqa.org/en/latest/user/error-codes.html 18 | # https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes 19 | # https://gist.github.com/krnd 20 | # 21 | 22 | 23 | [flake8] 24 | 25 | ################### PROGRAM ################################ 26 | 27 | # Specify the number of subprocesses that Flake8 will use to run checks in parallel. 28 | jobs = auto 29 | 30 | 31 | ################### OUTPUT ################################# 32 | 33 | ########## VERBOSITY ########## 34 | 35 | # Increase the verbosity of Flake8s output. 36 | verbose = 0 37 | # Decrease the verbosity of Flake8s output. 38 | quiet = 0 39 | 40 | 41 | ########## FORMATTING ########## 42 | 43 | # Select the formatter used to display errors to the user. 44 | format = default 45 | 46 | # Print the total number of errors. 47 | count = True 48 | # Print the source code generating the error/warning in question. 49 | show-source = True 50 | # Count the number of occurrences of each error/warning code and print a report. 51 | statistics = True 52 | 53 | 54 | ########## TARGETS ########## 55 | 56 | # Redirect all output to the specified file. 57 | output-file = .flake8.log 58 | # Also print output to stdout if output-file has been configured. 59 | tee = True 60 | 61 | 62 | ################### FILE PATTERNS ########################## 63 | 64 | # Provide a comma-separated list of glob patterns to exclude from checks. 65 | exclude = 66 | # git folder 67 | .git, 68 | # python cache 69 | __pycache__, 70 | # These files typically just have unused imports 71 | __init__.py, 72 | # Provide a comma-separate list of glob patterns to include for checks. 73 | filename = 74 | *.py 75 | 76 | 77 | ################### LINTING ################################ 78 | 79 | ########## ENVIRONMENT ########## 80 | 81 | # Provide a custom list of builtin functions, objects, names, etc. 82 | builtins = 83 | 84 | 85 | ########## OPTIONS ########## 86 | 87 | # Report all errors, even if it is on the same line as a `# NOQA` comment. 
88 | disable-noqa = False 89 | 90 | # Set the maximum length that any line (with some exceptions) may be. 91 | max-line-length = 100 92 | # Set the maximum allowed McCabe complexity value for a block of code. 93 | max-complexity = 10 94 | # Toggle whether pycodestyle should enforce matching the indentation of the opening brackets line. 95 | # incluences E131 and E133 96 | hang-closing = True 97 | 98 | 99 | ########## RULES ########## 100 | 101 | # ERROR CODES 102 | # 103 | # E/W - PEP8 errors/warnings (pycodestyle) 104 | # F - linting errors (pyflakes) 105 | # C - McCabe complexity error (mccabe) 106 | 107 | # Specify a list of codes to ignore. 108 | ignore = D4 109 | # Specify the list of error codes you wish Flake8 to report. 110 | select = F, E9 111 | # Enable off-by-default extensions. 112 | enable-extensions = 113 | 114 | 115 | ########## DOCSTRING ########## 116 | 117 | # Enable PyFlakes syntax checking of doctests in docstrings. 118 | doctests = False 119 | 120 | # Specify which files are checked by PyFlakes for doctest syntax. 121 | include-in-doctest = 122 | # Specify which files are not to be checked by PyFlakes for doctest syntax. 123 | exclude-in-doctest = 124 | 125 | # tell flake8-rst-docstrings plugin to ignore some custom roles 126 | # this prevents FP RST303 and RST304 127 | rst-roles = py:meth, py:class, py:obj 128 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .*.sw? 2 | *~ 3 | __pycache__ 4 | .mypy_cache 5 | .flake8.log 6 | build 7 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | stages: 2 | - test 3 | - export 4 | 5 | test_module: 6 | stage: test 7 | image: python:3.8-slim 8 | script: 9 | - pip install -r requirements-dev.txt 10 | - flake8 src/ 11 | - pip install -r requirements.txt 12 | - pip install . 
13 | - pytest -m commit --cov=retypd --cov-config=.coveragerc test/ 14 | tags: 15 | - shared 16 | rules: 17 | - if: '$CI_COMMIT_BRANCH || $CI_MERGE_REQUEST_REF_PATH' 18 | 19 | export_module: 20 | stage: export 21 | image: python:3.8-slim 22 | script: 23 | - pip install -r requirements-dev.txt 24 | - python3 setup.py bdist_wheel --dist-dir=$CI_PROJECT_DIR/dist 25 | - ls $CI_PROJECT_DIR/dist/*.whl | xargs $CI_PROJECT_DIR/delete_remote_packages.py $GL_PKG_API_TOKEN 26 | - sed "s/password = /password = $GL_PKG_API_TOKEN/" $CI_PROJECT_DIR/.pypirc > ~/.pypirc 27 | - python3 -m twine upload --verbose --repository repypi $CI_PROJECT_DIR/dist/*.whl 28 | tags: 29 | - shared 30 | rules: 31 | - if: '$CI_COMMIT_BRANCH == "master"' 32 | - if: '$CI_COMMIT_REF_NAME =~ /^release-.*/' 33 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-added-large-files 6 | - id: check-docstring-first 7 | - id: check-executables-have-shebangs 8 | - id: check-json 9 | - id: check-merge-conflict 10 | - id: check-yaml 11 | - id: debug-statements 12 | - id: end-of-file-fixer 13 | - id: flake8 14 | - id: mixed-line-ending 15 | - id: trailing-whitespace 16 | - repo: https://github.com/psf/black 17 | rev: 22.3.0 18 | hooks: 19 | - id: black 20 | args: 21 | - --line-length=79 22 | -------------------------------------------------------------------------------- /.pypirc: -------------------------------------------------------------------------------- 1 | [distutils] 2 | index-servers = 3 | repypi 4 | [repypi] 5 | repository = https://git.grammatech.com/api/v4/projects/1587/packages/pypi 6 | username = __token__ 7 | password = 8 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # 0.3 4 | 5 | - Implement `InferShapes` as in the paper to propagate all type capabilities. 6 | - Two implementations of constraint simplification: naive path exploration and path expressions. 7 | - Reworked anonymous variable introduction to keep recursive type information. 8 | - Constraint simplification is used to infer type schemes and to generate primitive constraints. 9 | - Fixes in constraint graph generation: 10 | - Only add recall edges for left-hand sides of constraints and forget edges for right-hand sides. 11 | - Add Left/Right marking to interesting nodes. 12 | 13 | # 0.2 14 | -------------------------------------------------------------------------------- /GrammaTech-CLA-retypd.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GrammaTech/retypd/8f7f72be9a567731bb82636cc91d70a3551050bf/GrammaTech-CLA-retypd.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | A tool for recovering type information in binaries. It is designed with a simple front-end schema so that it can be used with any disassembler.
**This analysis is in active development and the details of its API are expected to change.** 2 | 3 | Intended use is best demonstrated in `test/test_endtoend.py`'s `test_end_to_end` method: Create a `ConstraintSet` object and populate it with facts from the disassembled binary. Instantiate a `Solver` object with the populated `ConstraintSet` and a collection of interesting variables (as strings or `DerivedTypeVariable` objects). Call the `Solver` object (it needs no arguments after it is instantiated). Inferred constraints are stored in an attribute called `constraints`. 4 | 5 | Several additional details are included in `reference/type-recovery.rst`, including explanations of some of the more complex concepts from the paper and an outline of the type recovery algorithm. 6 | 7 | ## Copyright and Acknowledgments 8 | 9 | Copyright (C) 2021 GrammaTech, Inc. 10 | 11 | This code is licensed under the GPLv3 license. See the LICENSE file in the project root for license terms. 12 | 13 | This project is sponsored by the Office of Naval Research, One Liberty Center, 875 N. Randolph Street, Arlington, VA 22203 under contract #N68335-17-C-0700. The content of the information does not necessarily reflect the position or policy of the Government and no official endorsement should be inferred. 14 | -------------------------------------------------------------------------------- /delete_remote_packages.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Helper script for working around a "feature" in GitLab's PyPI package repositories. 4 | PyPI does not support uploading packages of the same version; 5 | they believe the version should be bumped for any change. 6 | We're going to delete packages instead. 7 | 8 | Call this script with positional args. 9 | The first is the token with read and write access to the PyPI repository. 10 | This can be your personal access token or $CI_DEPLOY_PASSWORD. 11 | $CI_JOB_TOKEN does not work; it doesn't have write access. 12 | The remainder is an unlimited list of wheels to be uploaded. 13 | 14 | If any local wheel finds a match in name and version in the remote PyPI repository, the remote package will be deleted.
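For example, mirroring the invocation in .gitlab-ci.yml (the token value is supplied by CI): ls dist/*.whl | xargs ./delete_remote_packages.py $GL_PKG_API_TOKEN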
15 | """ 16 | import sys 17 | import requests 18 | import pkginfo 19 | 20 | TOKEN = sys.argv[1] 21 | for wheel_loc in sys.argv[2:]: 22 | wheel = pkginfo.Wheel(wheel_loc) 23 | # Get packages, handling pagination 24 | responses = [] 25 | response = requests.get( 26 | f"https://git.grammatech.com/api/v4/projects/1587/packages?package_name={wheel.name}", 27 | headers={"PRIVATE-TOKEN": TOKEN}, 28 | ) 29 | responses.append(response) 30 | while response.links.get("next"): 31 | response = requests.get( 32 | response.links.get("next")["url"], headers={"PRIVATE-TOKEN": TOKEN} 33 | ) 34 | if response.status_code != 200: 35 | raise Exception( 36 | f"{response.status_code} status code while requesting package listings filtered by local name: {wheel.name}" 37 | ) 38 | responses.append(response) 39 | 40 | packages = [ 41 | package for response in responses for package in response.json() 42 | ] 43 | # Delete all matching packages 44 | for package in packages: 45 | if ( 46 | wheel.version == package["version"] 47 | and wheel.name == package["name"] 48 | ): 49 | print(f'Deleting {package["name"]} {package["version"]}.') 50 | response = requests.delete( 51 | f'https://git.grammatech.com/api/v4/projects/1587/packages/{package["id"]}', 52 | headers={"PRIVATE-TOKEN": TOKEN}, 53 | ) 54 | if response.status_code != 204: 55 | raise Exception( 56 | f'{response.status_code} status code while deleting this package: {package["name"]} {package["version"]}' 57 | ) 58 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | commit: Tests run for each commit 4 | nightly: Tests run only in nightly tests 5 | -------------------------------------------------------------------------------- /reference/contents.md: -------------------------------------------------------------------------------- 1 | The presentation slides can be found [here](https://raw.githubusercontent.com/emeryberger/PLDI-2016/master/presentations/pldi16-presentation241.pdf) and the paper can be found [here](https://arxiv.org/pdf/1603.05495.pdf). 2 | -------------------------------------------------------------------------------- /reference/paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GrammaTech/retypd/8f7f72be9a567731bb82636cc91d70a3551050bf/reference/paper.pdf -------------------------------------------------------------------------------- /reference/presentation_slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GrammaTech/retypd/8f7f72be9a567731bb82636cc91d70a3551050bf/reference/presentation_slides.pdf -------------------------------------------------------------------------------- /reference/type-recovery.rst: -------------------------------------------------------------------------------- 1 | *************************************** 2 | Type inference (in the style of retypd) 3 | *************************************** 4 | 5 | Several people have implemented or attempted to implement retypd. It is a powerful system but 6 | difficult to understand. My goal in this document is to explain what I have learned as I have 7 | studied the paper and implemented the analysis. 8 | 9 | References to the paper are to the extended form which is included in this directory and available 10 | on `ArXiv <https://arxiv.org/abs/1603.05495>`_.
References to the slides, which are also 11 | included in this directory, are from `Matt's PLDI talk 12 | <https://raw.githubusercontent.com/emeryberger/PLDI-2016/master/presentations/pldi16-presentation241.pdf>`_. 13 | The slides, as a PDF, lack animations and I will reference them by "physical" number, not by the 14 | numbers printed in the slides themselves. 15 | 16 | ############### 17 | What is retypd? 18 | ############### 19 | 20 | Retypd is a polymorphic type inference scheme, suitable for type recovery from binaries. Its primary 21 | contribution is a subtlety in how it treats pointer types; some uses of pointers are covariant and 22 | some uses are contravariant. By managing this variance effectively, it is able to reason much more 23 | precisely than unmodified unification-based type inference algorithms. 24 | 25 | Retypd was originally published and presented at PLDI 2016 by Matthew Noonan, who was a GrammaTech 26 | employee at the time. His presentation is very dense and subtle (but quite rigorous). This document 27 | hopes to distill some of the details learned by studying the paper, the presentation, and by 28 | creating an implementation. 29 | 30 | ############################# 31 | Subtleties and misconceptions 32 | ############################# 33 | 34 | ----------- 35 | Scaffolding 36 | ----------- 37 | 38 | Many of the mathematical constructs in this paper exist for scaffolding; they develop the reader's 39 | intuition about how the operations work. However, they are not actually used in the algorithm. The 40 | steps outlined in the final section summarize the portions of the paper and slides that actually 41 | should be implemented. 42 | 43 | ---------------- 44 | Subtyping styles 45 | ---------------- 46 | 47 | The paper asserts things about several subtyping modalities. While I cannot speak for the author, 48 | discussions with a collaborator on the project suggest that "structural subtyping" may refer to 49 | `depth subtyping <https://en.wikipedia.org/wiki/Subtyping>`_ and 50 | "non-structural subtyping" to width subtyping (same link as depth). We're not certain what "physical 51 | subtyping" means. 52 | 53 | ------------------------ 54 | Non-structural subtyping 55 | ------------------------ 56 | 57 | Throughout the paper, the subset relation (⊑) is used simply to mean that one datum's type is a 58 | subtype of another. Most frequently, it is used to model depth (structural) subtyping. However, at 59 | least one occurrence (p.7, α ⊑ F.in_stack0) indicates width (non-structural) subtyping. The paper 60 | states on multiple occasions that instantiation, such as on call site boundaries, allows this 61 | non-structural subtyping to occur. 62 | 63 | The key is not function calls; nothing in the analysis makes function calls different from other 64 | components of a program. However, using functions as "interesting" variables does create boundaries. 65 | For example, a path in the graph that reaches some function F via its in_0 capability will generate 66 | a constraint whose right-hand side begins with "F.in_0", indicating something about the type of a 67 | value passed to F. A path from F to something else via its in_0 capability will generate a 68 | constraint whose left-hand side begins with "F.in_0". Because these type constraints belong to 69 | different sketches, any free type variables are implicitly instantiated at this boundary. 70 | 71 | This non-structural subtyping can skip these boundaries, but only in well-defined ways.
For example, 72 | if F accepts a struct with two members and G passes a compatible struct that adds a third member, it 73 | is possible for G's sketch to include information from paths downstream of the F.in_0 vertex it 74 | encounters. However, it can only infer things about the first two members from the uses in F; F 75 | cannot infer anything about the third member because it would require going back up the graph. 76 | (IS THIS TRUE? CAN WE PROVE IT FOR ALL CASES?) 77 | 78 | ----------------- 79 | Variance notation 80 | ----------------- 81 | 82 | The notation that writes variance as if it were a capability (e.g., F.⊕) indicates not the variance 83 | of the variable that precedes it but rather summarizes the variance of everything on the stack *after* 84 | F. In other words, it is the variance of the elided portions of the stack. 85 | 86 | For the formation of the graph (shown in slides 68-78), it is especially important to remember that 87 | the variance of the empty string (⟨ε⟩) is ⊕. Since type constraints from C elide no part of their 88 | stack, the elided portion is ε. In other words, all of the symbols copied directly from C into nodes 89 | in the graph have variance ⊕. 90 | 91 | ------------------- 92 | Special graph nodes 93 | ------------------- 94 | 95 | Start# and End# and the L and R subscripts are conveniences for ensuring that we only examine paths 96 | from interesting variables to interesting variables whose internal nodes are exclusively 97 | *uninteresting* variables. For many implementations, there is no need to encode these things 98 | directly. 99 | 100 | ------------------------------ 101 | Finding the graph in the paper 102 | ------------------------------ 103 | 104 | The slides use a different notation than does the paper. The graph construction summarized in slides 105 | 68-78 corresponds to Appendices C and D (pp. 20-24). 106 | 107 | ---------------- 108 | The type lattice 109 | ---------------- 110 | 111 | The lattice Λ of types is somewhat arbitrary, but C types are too general to be useful. Instead, the 112 | TSL retypd implementation uses types as users think about them. A file descriptor, for example, is 113 | implemented with an int but is conceptually different. It ought to be possible to recover the basic 114 | types from the TSL code. 115 | 116 | The two JSON files in this directory include the schemas from the TSL implementation. The presence 117 | of any functions included in these schemas will yield types with semantic meaning. 118 | 119 | ------------------------ 120 | S-Pointer and S-Field⊕/⊖ 121 | ------------------------ 122 | 123 | Most of the inference rules in Figure 3 (p. 5) become redundant because of the structure of the 124 | graph: 125 | * T-Left and T-Right are implicit because edges cannot exist without vertices. 126 | * T-Prefix is expressed by the existence of forget edges. 127 | * As discussed in the paper, T-InheritL and T-InheritR are redundant, as a combination of 128 | T-Left/T-Right and S-Field⊕/S-Field⊖ can produce the same facts. 129 | * S-Refl and S-Trans are implicit. 130 | 131 | Once the initial graph is created, edges corresponding to S-Field⊕ and S-Field⊖ are added as 132 | shortcuts between existing nodes. The lazy instantiation of S-Pointer is presented in Algorithm D.2, 133 | lines 20-27. It is also shown in Figure 14 (in one of two possible forms). The form shown uses an 134 | instantiation of S-Field⊖ for a store, then an instantiation of S-Pointer, and finally an 135 | instantiation of S-Field⊕ for a store. The rightmost edge in Figure 14 shows the contravariant edge 136 | generated (lazily) by S-Pointer that would be used if the load and store were swapped. In either 137 | case, the first instantiation requires three or more edges and the last of these edges must be a pop 138 | (or recall) edge. The head of the pop edge always comes from a node with a contravariant elided 139 | prefix (in the figure, p⊖). The target of the first edge required by the last instantiation is 140 | always a node with the same derived type variable but with inverted variance (in the figure, p⊕). 141 | N.B. this triple instantiation of rules does not create any new nodes. 142 | 143 | As a result, saturation adds edges for S-Field⊕ and S-Field⊖ between nodes that already exist. It 144 | also adds edges for S-Pointer (combined with the other two rules). This limits these rules' 145 | instantiations so that they never create additional nodes in the graph. Consequently, saturation 146 | converges. I have not yet proven that this guarantees that all useful instantiations of these rules 147 | occur in this limited context, but I think that the proof in Appendix B proves this property.
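To make the saturation step concrete, here is an illustrative sketch of the fixed point discussed above (restated as step 5 of the outline below). This is not the implementation this repository uses (see the solver sources under ``src/``), the edge-label encoding is an assumption invented for this example, and the lazy S-Pointer shortcuts are omitted::

    import networkx as nx

    def saturate(g: nx.DiGraph) -> None:
        # Assumed encoding (illustrative only): edge attribute "label" is None
        # for unlabeled edges, ("forget", l) for forget/push edges, and
        # ("recall", l) for recall/pop edges.
        changed = True
        while changed:
            changed = False
            # Unlabeled-only reachability closure: q in closure[p] iff p
            # reaches q using unlabeled edges alone.
            plain = nx.DiGraph()
            plain.add_nodes_from(g.nodes)
            plain.add_edges_from(
                (u, v) for u, v, d in g.edges(data="label") if d is None
            )
            closure = {n: nx.descendants(plain, n) | {n} for n in plain}
            for u, v, d in list(g.edges(data="label")):
                if not (isinstance(d, tuple) and d[0] == "forget"):
                    continue
                # p -*-> u -(forget l)-> v -*-> q -(recall l)-> r  gives  p -> r
                for p in (x for x in g if u in closure[x]):
                    for q in closure[v]:
                        for _, r, d2 in list(g.out_edges(q, data="label")):
                            if d2 == ("recall", d[1]) and not g.has_edge(p, r):
                                g.add_edge(p, r, label=None)
                                changed = True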
148 | 149 | ##################### 150 | Type recovery outline 151 | ##################### 152 | 153 | The following steps, in order, implement retypd. The steps after saturation reflect recent updates 154 | to retypd and not the original paper. 155 | 156 | #. Generate base constraints (slides 18-27 or Appendix A). Call this set of constraints C. 157 | #. Do **not** fix the set of constraints over the inference rules from Figure 3 (see also slide 28); 158 | this diverges in the presence of recursive types. The remainder of the algorithm accomplishes the 159 | same thing as the fixed point but without diverging. 160 | #. Build a graph Δ from C; a ⊑ b becomes a.⊕ → b.⊕ *and* b.⊖ → a.⊖ (Δ_c on p. 21). Each of these 161 | edges is unlabeled. 162 | #. For every node with capabilities (e.g., a.c.⊕), create "forget" and "recall" edges. For our 163 | example node, let us assume that c is contravariant (i.e., ⟨c⟩ = ⊖). Produce an edge with the 164 | label "forget c" from a.c.⊕ → a.⊖ and an edge with the label "recall c" in the opposite 165 | direction. This may or may not create additional nodes. Forget and recall edges are used in the 166 | slides and, respectively, are called push and pop edges in the paper (see step 2 of D.2 on page 167 | 22). **N.B. forgetting is equated with pushing because the elided capability is pushed onto the 168 | stack.** 169 | #. Saturate by finding *sequences* of edges that are all unlabeled except for a single forget edge 170 | (say, "forget *l*") that reach nodes with outgoing edges with a corresponding recall edge 171 | ("recall *l*"). If the sequence begins at p and reaches q and if the recall edge is from q to r, 172 | create an edge from p to r without a label. Repeat to a fixed point. Additionally, create 173 | shortcut edges as shown in Figure 14 for S-Field/S-Pointer/S-Field instantiations. 174 | #. Remove self loops; the graph represents a reflexive relation, so edges from a vertex to itself 175 | are not informative. 176 | #. Identify cycles (strongly connected components) in the graph that do not include both forget and 177 | recall edges. Identify nodes in these cycles that have predecessors outside of the SCC. Eliminate 178 | duplicates (there is no need to include A.load if A is already in the set).
Create a new type 179 | variable for each remaining node and add each of these nodes to the set of interesting variables. 180 | #. Split the graph into two subgraphs, copying recall and unlabeled edges but not forget edges to 181 | the new subgraph. Change the tails of existing recall edges to the nodes in the new subgraph. 182 | This ensures that paths can never include forget edges after recall edges. 183 | #. Starting at each node associated with an interesting variable, find paths to other interesting 184 | variables. Record the edge labels. For each path found, generate constraints: append the forget 185 | labels to the interesting variable at the beginning of the path and the recall labels to the 186 | interesting variable at the end of the path. If both of the resulting derived type variables have 187 | a covariant suffix and if they are not equal to each other, emit a constraint. 188 | #. If desired, generate sketches from the type constraints. 189 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | flake8 2 | pkginfo 3 | pytest 4 | pytest-cov 5 | requests 6 | sphinx 7 | sphinx_rtd_theme 8 | twine 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dataclasses ~= 0.6 2 | graphviz ~= 0.20 3 | networkx ~= 3.1.0 4 | tqdm ~= 4.64.0 5 | pyformlang ~= 1.0.1 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Retypd - machine code type inference 4 | # Copyright (C) 2021 GrammaTech, Inc. 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 18 | # 19 | # This project is sponsored by the Office of Naval Research, One Liberty 20 | # Center, 875 N. Randolph Street, Arlington, VA 22203 under contract # 21 | # N68335-17-C-0700. The content of the information does not necessarily 22 | # reflect the position or policy of the Government and no official 23 | # endorsement should be inferred. 
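# NOTE: "imp" is deprecated (removed in Python 3.12); importlib is the modern replacement.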
24 | 25 | from imp import load_source 26 | from os import path 27 | from setuptools import setup 28 | 29 | PKGINFO = load_source("pkginfo.version", "src/version.py") 30 | __version__ = PKGINFO.__version__ 31 | __packagename__ = PKGINFO.__packagename__ 32 | 33 | here = path.abspath(path.dirname(__file__)) 34 | 35 | # get the dependencies and installs 36 | with open(path.join(here, "requirements.txt"), encoding="utf-8") as f: 37 | all_reqs = f.read().split("\n") 38 | 39 | install_requires = [x.strip() for x in all_reqs if "git+" not in x] 40 | 41 | setup( 42 | name=__packagename__, 43 | version=__version__, 44 | description="An implementation of retypd in Python3", 45 | author="GrammaTech, Inc.", 46 | python_requires=">=3.8.0", 47 | package_dir={__packagename__: "src"}, 48 | packages=[__packagename__], 49 | install_requires=install_requires, 50 | ) 51 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # Retypd - machine code type inference 2 | # Copyright (C) 2021 GrammaTech, Inc. 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | # 17 | # This project is sponsored by the Office of Naval Research, One Liberty 18 | # Center, 875 N. Randolph Street, Arlington, VA 22203 under contract # 19 | # N68335-17-C-0700. The content of the information does not necessarily 20 | # reflect the position or policy of the Government and no official 21 | # endorsement should be inferred. 22 | 23 | """An implementation of retypd based on the paper and slides included in the reference subdirectory. 24 | 25 | To invoke, create a Program, which requires a lattice of atomic types, a set of global variables of 26 | interest, a mapping from functions to constraints generated from them, and a call graph. Then, 27 | instantiate a Solver with the Program. Lastly, invoke the solver. The result of calling the solver 28 | object is the set of constraints generated from the analysis. 
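A minimal sketch of that sequence (the names and the Program argument order are illustrative assumptions taken from the description above, not authoritative; see schema.py and solver.py for the actual signatures):

    from retypd import DummyLattice, Program, Solver

    # global_vars, constraint_map (per-function ConstraintSets), and callgraph
    # are assumed to exist already, as described above.
    program = Program(DummyLattice(), global_vars, constraint_map, callgraph)
    solver = Solver(program)
    generated_constraints = solver()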
29 | """ 30 | 31 | from .graph import EdgeLabel, Node 32 | from .dummylattice import DummyLattice, DummyLatticeCTypes 33 | from .schema import ( 34 | ConstraintSet, 35 | DerefLabel, 36 | DerivedTypeVariable, 37 | InLabel, 38 | LoadLabel, 39 | OutLabel, 40 | Program, 41 | StoreLabel, 42 | SubtypeConstraint, 43 | Variance, 44 | Lattice, 45 | LatticeCTypes, 46 | ) 47 | from .solver import Solver, SolverConfig 48 | from .parser import SchemaParser 49 | from .c_type_generator import CTypeGenerator, CTypeGenerationError 50 | from .clattice import CLattice, CLatticeCTypes 51 | from .c_types import ( 52 | CType, 53 | VoidType, 54 | IntType, 55 | FloatType, 56 | CharType, 57 | BoolType, 58 | ArrayType, 59 | PointerType, 60 | FunctionType, 61 | Field, 62 | CompoundType, 63 | StructType, 64 | UnionType, 65 | ) 66 | from .graph_solver import GraphSolverConfig 67 | from .sketches import Sketch 68 | from .loggable import LogLevel 69 | -------------------------------------------------------------------------------- /src/c_type_generator.py: -------------------------------------------------------------------------------- 1 | # Retypd - machine code type inference 2 | # Copyright (C) 2021 GrammaTech, Inc. 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | # 17 | # This project is sponsored by the Office of Naval Research, One Liberty 18 | # Center, 875 N. Randolph Street, Arlington, VA 22203 under contract # 19 | # N68335-17-C-0700. The content of the information does not necessarily 20 | # reflect the position or policy of the Government and no official 21 | # endorsement should be inferred. 22 | 23 | from .schema import ( 24 | AccessPathLabel, 25 | InLabel, 26 | OutLabel, 27 | LoadLabel, 28 | StoreLabel, 29 | DerefLabel, 30 | DerivedTypeVariable, 31 | LatticeCTypes, 32 | Lattice, 33 | ) 34 | from .sketches import ( 35 | SketchNode, 36 | Sketch, 37 | LabelNode, 38 | SkNode, 39 | ) 40 | from .c_types import ( 41 | CType, 42 | PointerType, 43 | FunctionType, 44 | ArrayType, 45 | StructType, 46 | IntType, 47 | UnionType, 48 | Field, 49 | FloatType, 50 | CharType, 51 | ) 52 | from .loggable import Loggable, LogLevel 53 | from typing import Set, Dict, Optional, List, Tuple 54 | from collections import defaultdict 55 | import itertools 56 | 57 | 58 | class CTypeGenerationError(Exception): 59 | """ 60 | Exception raised when an unexpected situation occurs during CType generation. 61 | """ 62 | 63 | pass 64 | 65 | 66 | class CTypeGenerator(Loggable): 67 | """ 68 | Generate C-like types from sketches. 
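A typical invocation, using the signatures defined below (the lattice pair is e.g. CLattice/CLatticeCTypes from clattice.py; the 4/8 sizes are in bytes and assume an illustrative 64-bit target):

    gen = CTypeGenerator(
        sketch_map, CLattice(), CLatticeCTypes(),
        default_int_size=4, default_ptr_size=8,
    )
    dtv_to_ctype = gen()  # or gen(filter_to={...}) to restrict output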
69 | """ 70 | 71 | def __init__( 72 | self, 73 | sketch_map: Dict[DerivedTypeVariable, Sketch], 74 | lattice: Lattice, 75 | lattice_ctypes: LatticeCTypes, 76 | default_int_size: int, 77 | default_ptr_size: int, 78 | verbose: LogLevel = LogLevel.QUIET, 79 | ): 80 | super(CTypeGenerator, self).__init__(verbose) 81 | self.default_int_size = default_int_size 82 | self.default_ptr_size = default_ptr_size 83 | self.sketch_map = sketch_map 84 | self.struct_types = {} 85 | self.dtv2type = defaultdict(dict) 86 | self.lattice = lattice 87 | self.lattice_ctypes = lattice_ctypes 88 | 89 | def union_types( 90 | self, a: Optional[CType], b: Optional[CType] 91 | ) -> Optional[CType]: 92 | """ 93 | This function decides how to merge two CTypes for the same access path. This differs from 94 | the lattice, which only considers "atomic" or "terminal" types. This can take, for example, 95 | two pointers to structs. Or an integer and a pointer to a struct. 96 | """ 97 | if a is None: 98 | return b 99 | if b is None: 100 | return a 101 | at = type(a) 102 | bt = type(b) 103 | if at == IntType and bt in (PointerType, StructType, ArrayType): 104 | return b 105 | if bt == IntType and at in (PointerType, StructType, ArrayType): 106 | return a 107 | if at == bt: 108 | if at == IntType: 109 | if a.width == b.width: 110 | return IntType(a.width, a.signed or b.signed) 111 | elif at in (FloatType, CharType) and a.width == b.width: 112 | return a 113 | elif at == ArrayType: 114 | am = a.member_type 115 | bm = b.member_type 116 | if a.length == b.length and self.union_types(am, bm) in ( 117 | am, 118 | bm, 119 | ): 120 | return a 121 | elif at == PointerType: 122 | ap = a.target_type 123 | bp = b.target_type 124 | 125 | if self.union_types(ap, bp) in (ap, bp): 126 | return a 127 | 128 | unioned_types = [] 129 | if at == UnionType: 130 | unioned_types.extend(a.fields) 131 | else: 132 | unioned_types.append(Field(a)) 133 | if bt == UnionType: 134 | unioned_types.extend(b.fields) 135 | else: 136 | unioned_types.append(Field(b)) 137 | self.debug("Unioning: %s", unioned_types) 138 | return UnionType(unioned_types) 139 | 140 | def resolve_label(self, sketch: Sketch, node: SkNode) -> SketchNode: 141 | if isinstance(node, LabelNode): 142 | self.info("Resolved label: %s", node) 143 | return sketch.lookup(node.target) 144 | return node 145 | 146 | def _succ_no_loadstore( 147 | self, 148 | curr_access_path: List[AccessPathLabel], 149 | sketch: Sketch, 150 | node: SketchNode, 151 | seen: Set[SketchNode], 152 | ) -> List[Tuple[List[AccessPathLabel], SketchNode]]: 153 | """ 154 | Collect a list of SketchNode successors of `node` 155 | in the sketch graph `sketch` by traversing load and store 156 | labels. Each `SketchNode` successor is accompanied 157 | by the access path that leads to it. 158 | The access path will coincide with the node's DTV unless 159 | LabelNodes have been traversed. 
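For example (schematically), from the node for F.in_0 whose successor F.in_0.load has in turn a dereference successor, the intermediate load node is skipped and the dereference node is returned with an access path containing both traversed labels.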
160 | """ 161 | successors = [] 162 | if node not in seen: 163 | seen.add(node) 164 | for n in sketch.sketches.successors(node): 165 | new_label = sketch.sketches[node][n]["label"] 166 | n = self.resolve_label(sketch, n) 167 | if n is None: 168 | continue 169 | if n.dtv.tail in (StoreLabel.instance(), LoadLabel.instance()): 170 | successors.extend( 171 | self._succ_no_loadstore( 172 | curr_access_path + [new_label], sketch, n, seen 173 | ) 174 | ) 175 | else: 176 | successors.append((curr_access_path + [new_label], n)) 177 | self.debug("Successors %s --> %s", node, successors) 178 | return successors 179 | 180 | def merge_counts(self, count_set: Set[int]) -> int: 181 | """ 182 | Given a set of element counts from a node, merge them into a single count. 183 | """ 184 | if not count_set: 185 | return 1 186 | elif len(count_set) == 1: 187 | return count_set.pop() 188 | elif DerefLabel.COUNT_NULLTERM in count_set: 189 | return DerefLabel.COUNT_NULLTERM 190 | return DerefLabel.COUNT_NOBOUND 191 | 192 | def c_type_from_nodeset( 193 | self, 194 | base_dtv: DerivedTypeVariable, 195 | sketch: Sketch, 196 | ns: Set[SketchNode], 197 | ) -> Optional[CType]: 198 | """ 199 | Given a derived type var, sketches, and a set of nodes, produce a C-like type. The 200 | set of nodes must be all for the same access path (e.g., foo[0]), but can be different 201 | _ways_ of accessing it (e.g., foo.load.s8@0 and foo.store.s8@0). 202 | """ 203 | ns = set([self.resolve_label(sketch, n) for n in ns]) 204 | assert None not in ns 205 | 206 | # Check cache 207 | for n in ns: 208 | if n.dtv in self.dtv2type[base_dtv]: 209 | self.info("Already cached (recursive type): %s", n.dtv) 210 | return self.dtv2type[base_dtv][n.dtv] 211 | 212 | children = list( 213 | itertools.chain( 214 | *[self._succ_no_loadstore([], sketch, n, set()) for n in ns] 215 | ) 216 | ) 217 | available_tails = {access_paths[-1] for access_paths, _ in children} 218 | 219 | if len(children) == 0: 220 | # Compute the atomic type bounds and size bound 221 | lb = self.lattice.bottom 222 | ub = self.lattice.top 223 | sz = 0 224 | counts = set() 225 | for n in ns: 226 | tail = n.dtv.tail 227 | if tail is not None and isinstance(tail, DerefLabel): 228 | byte_size = tail.size 229 | counts.add(tail.count) 230 | else: 231 | byte_size = self.default_int_size 232 | lb = self.lattice.join(lb, n.lower_bound) 233 | ub = self.lattice.meet(ub, n.upper_bound) 234 | sz = max(sz, byte_size) 235 | 236 | # Convert it to a CType 237 | rv = self.lattice_ctypes.atom_to_ctype(lb, ub, sz) 238 | count = self.merge_counts(counts) 239 | if count > 1: 240 | rv = ArrayType(rv, count) 241 | elif count == DerefLabel.COUNT_NULLTERM: 242 | # C type for null terminated string is [w]char[_t]* 243 | if rv.size in (1, 2, 4): # Valid character sizes 244 | rv = CharType(rv.size) 245 | else: 246 | self.info( 247 | "Unexpected character size for null-terminated string: %d", 248 | rv.size, 249 | ) 250 | # In C, unbounded arrays are represented as pointers, which is what this deref 251 | # will be represented as default. XXX in future we could change ArrayType to allow 252 | # for representing unboundedness. 
253 | 254 | for n in ns: 255 | self.dtv2type[base_dtv][n.dtv] = rv 256 | self.debug("Terminal type: %s -> %s", ns, rv) 257 | elif all( 258 | isinstance(tail, (InLabel, OutLabel)) for tail in available_tails 259 | ): 260 | outputs = set() 261 | inputs = defaultdict(set) 262 | 263 | ptr_func = PointerType( 264 | FunctionType(None, []), self.default_ptr_size 265 | ) 266 | 267 | for n in ns: 268 | self.dtv2type[base_dtv][n.dtv] = ptr_func 269 | 270 | for access_path, child in children: 271 | tail = access_path[-1] 272 | 273 | if isinstance(tail, OutLabel): 274 | outputs.add(child) 275 | elif isinstance(tail, InLabel): 276 | inputs[tail.index].add(child) 277 | else: 278 | assert False, "Unreachable type generation state" 279 | 280 | ptr_func.target_type.return_type = self.c_type_from_nodeset( 281 | base_dtv, sketch, outputs 282 | ) 283 | ptr_func.target_type.params = [ 284 | self.c_type_from_nodeset(base_dtv, sketch, inputs[index]) 285 | if len(inputs[index]) > 0 286 | else self.lattice_ctypes.atom_to_ctype( 287 | self.lattice.bottom, 288 | self.lattice.top, 289 | self.default_int_size, 290 | ) 291 | for index in range(max(inputs.keys(), default=-1) + 1) 292 | ] 293 | 294 | return ptr_func 295 | else: 296 | # We could recurse on types below, so we populate the struct _first_ 297 | s = StructType() 298 | self.struct_types[s.name] = s 299 | count = self.merge_counts( 300 | { 301 | n.dtv.tail.count 302 | for n in ns 303 | if isinstance(n.dtv.tail, DerefLabel) 304 | } 305 | ) 306 | if count > 1: 307 | rv = ArrayType(s, count) 308 | else: 309 | rv = PointerType(s, self.default_ptr_size) 310 | for n in ns: 311 | self.dtv2type[base_dtv][n.dtv] = rv 312 | 313 | self.debug("%s has %d children", ns, len(children)) 314 | children_bases = { 315 | ( 316 | access_path[-1].offset, 317 | access_path[-1].offset 318 | + access_path[-1].count * access_path[-1].size, 319 | ): child 320 | for access_path, child in children 321 | if isinstance(access_path[-1], DerefLabel) 322 | and access_path[-1].count != 1 323 | } 324 | 325 | children_to_base = {} 326 | 327 | for access_path, child in children: 328 | if not isinstance(access_path[-1], DerefLabel): 329 | continue 330 | 331 | for ((start, end), base) in children_bases.items(): 332 | if access_path[-1].offset > start: 333 | if end < start: 334 | children_to_base[child] = start 335 | elif child.offset < end: 336 | children_to_base[child] = start 337 | 338 | children_by_offset = defaultdict(set) 339 | 340 | for access_path, c in children: 341 | tail = access_path[-1] 342 | if not isinstance(tail, DerefLabel): 343 | self.info(f"WARNING: {c.dtv} does not end in DerefLabel") 344 | continue 345 | if c in children_to_base: 346 | children_by_offset[children_to_base[c]].add(c) 347 | else: 348 | children_by_offset[tail.offset].add(c) 349 | 350 | fields = [] 351 | for offset, siblings in children_by_offset.items(): 352 | child_type = self.c_type_from_nodeset( 353 | base_dtv, sketch, siblings 354 | ) 355 | fields.append(Field(child_type, offset=offset)) 356 | s.set_fields(fields=fields) 357 | return rv 358 | 359 | def _simplify_pointers(self, typ: CType, seen: Set[CType]) -> CType: 360 | """ 361 | Look for all Pointer(Struct(FieldType)) patterns where the struct has a single field at 362 | offset = 0 and convert it to Pointer(FieldType). 
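For example (illustrative), PointerType(StructType([Field(IntType(4, True), offset=0)]), 8) simplifies to PointerType(IntType(4, True), 8).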
363 | """ 364 | if typ in seen: 365 | return typ 366 | seen.add(typ) 367 | if isinstance(typ, Field): 368 | return Field(self._simplify_pointers(typ.ctype, seen), typ.offset) 369 | elif isinstance(typ, ArrayType): 370 | return ArrayType( 371 | self._simplify_pointers(typ.member_type, seen), typ.length 372 | ) 373 | elif isinstance(typ, PointerType): 374 | if isinstance(typ.target_type, StructType): 375 | s = typ.target_type 376 | if len(s.fields) == 1 and s.fields[0].offset == 0: 377 | rv = PointerType( 378 | self._simplify_pointers(s.fields[0].ctype, seen), 379 | self.default_ptr_size, 380 | ) 381 | self.info("Simplified pointer: %s", rv) 382 | return rv 383 | return PointerType( 384 | self._simplify_pointers(typ.target_type, seen), 385 | self.default_ptr_size, 386 | ) 387 | elif isinstance(typ, FunctionType): 388 | params = [self._simplify_pointers(t, seen) for t in typ.params] 389 | rt = self._simplify_pointers(typ.return_type, seen) 390 | return FunctionType(rt, params, name=typ.name) 391 | elif isinstance(typ, StructType): 392 | s = StructType(name=typ.name) 393 | s.set_fields( 394 | [self._simplify_pointers(t, seen) for t in typ.fields] 395 | ) 396 | return s 397 | return typ 398 | 399 | def __call__( 400 | self, 401 | simplify_pointers: bool = True, 402 | filter_to: Optional[Set[DerivedTypeVariable]] = None, 403 | ) -> Dict[DerivedTypeVariable, CType]: 404 | """ 405 | Generate CTypes. 406 | 407 | :param simplify_pointers: By default pointers to single-field structs are simplified to 408 | just be pointers to the base type of the first field. Set this to False to keep types 409 | normalized to always use structs to contain pointed-to data. 410 | :param filter_to: If specified, only emit types for the given base DerivedTypeVariables 411 | (typically globals or functions). If None (default), emit all types. 412 | """ 413 | dtv_to_type = {} 414 | for dtv, sketch in self.sketch_map.items(): 415 | if filter_to is not None and dtv not in filter_to: 416 | continue 417 | 418 | node = sketch.lookup(dtv) 419 | 420 | if node is None: 421 | continue 422 | 423 | maybe_ptr_func = self.c_type_from_nodeset(dtv, sketch, {node}) 424 | 425 | if isinstance(maybe_ptr_func, PointerType) and isinstance( 426 | maybe_ptr_func.target_type, FunctionType 427 | ): 428 | # Consumers expect the function itself, not a pointer to it 429 | maybe_ptr_func.target_type.name = dtv.base 430 | dtv_to_type[dtv] = maybe_ptr_func.target_type 431 | else: 432 | dtv_to_type[dtv] = maybe_ptr_func 433 | 434 | if simplify_pointers: 435 | self.debug("Simplifying pointers") 436 | new_dtv_to_type = {} 437 | for dtv, typ in dtv_to_type.items(): 438 | new_dtv_to_type[dtv] = self._simplify_pointers(typ, set()) 439 | dtv_to_type = new_dtv_to_type 440 | 441 | return dtv_to_type 442 | -------------------------------------------------------------------------------- /src/c_types.py: -------------------------------------------------------------------------------- 1 | # Retypd - machine code type inference 2 | # Copyright (C) 2021 GrammaTech, Inc. 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | # 17 | # This project is sponsored by the Office of Naval Research, One Liberty 18 | # Center, 875 N. Randolph Street, Arlington, VA 22203 under contract # 19 | # N68335-17-C-0700. The content of the information does not necessarily 20 | # reflect the position or policy of the Government and no official 21 | # endorsement should be inferred. 22 | 23 | from abc import ABC 24 | from typing import Iterable, Optional, Sequence 25 | import os 26 | 27 | 28 | class CType(ABC): 29 | @property 30 | def size(self) -> int: 31 | pass 32 | 33 | @property 34 | def comment(self) -> Optional[str]: 35 | return None 36 | 37 | def pretty_print(self, name: str) -> str: 38 | return f"{self} {name}" 39 | 40 | 41 | class VoidType(CType): 42 | @property 43 | def size(self) -> int: 44 | return 0 45 | 46 | def __str__(self) -> str: 47 | return "void" 48 | 49 | 50 | class IntType(CType): 51 | def __init__(self, width: int, signed: bool) -> None: 52 | self.width = width 53 | self.signed = signed 54 | 55 | @property 56 | def size(self) -> int: 57 | return self.width 58 | 59 | def __str__(self) -> str: 60 | signed_tag = "" 61 | if not self.signed: 62 | signed_tag = "u" 63 | return f"{signed_tag}int{self.width*8}_t" 64 | 65 | 66 | class FloatType(CType): 67 | def __init__(self, width: int) -> None: 68 | self.width = width 69 | 70 | @property 71 | def size(self) -> int: 72 | return self.width 73 | 74 | def __str__(self) -> str: 75 | return f"float{self.width}_t" 76 | 77 | 78 | class CharType(CType): 79 | # no chars aren't _always_ 8 bits 80 | def __init__(self, width: int) -> None: 81 | self.width = width 82 | 83 | @property 84 | def size(self) -> int: 85 | return self.width 86 | 87 | def __str__(self) -> str: 88 | return f"char{self.width}_t" 89 | 90 | 91 | class BoolType(CType): 92 | def __init__(self, width: int) -> None: 93 | self.width = width 94 | 95 | @property 96 | def size(self) -> int: 97 | return self.width 98 | 99 | def __str__(self) -> str: 100 | return f"bool{self.width}_t" 101 | 102 | 103 | class ArrayType(CType): 104 | next_id = 0 105 | 106 | def __init__(self, member_type: CType, length: int) -> None: 107 | self.member_type = member_type 108 | self.length = length 109 | self.id = ArrayType.next_id 110 | ArrayType.next_id += 1 111 | 112 | @property 113 | def size(self) -> int: 114 | return self.member_type.size * self.length 115 | 116 | def __str__(self) -> str: 117 | return f"{self.member_type}[{self.length}]" 118 | 119 | def pretty_print(self, name: str) -> str: 120 | return f"{self.member_type} {name}[{self.length}]" 121 | 122 | 123 | class PointerType(CType): 124 | def __init__(self, target_type: CType, width: int) -> None: 125 | self.target_type = target_type 126 | self.width = width 127 | 128 | @property 129 | def size(self) -> int: 130 | return self.width 131 | 132 | def __str__(self) -> str: 133 | return f"{self.target_type}*" 134 | 135 | 136 | class FunctionType(CType): 137 | next_id = 0 138 | 139 | def __init__( 140 | self, 141 | return_type: CType, 142 | params: Sequence[CType], 143 | name: Optional[str] = None, 144 | ) -> None: 145 | self.return_type = return_type 146 | self.params = params 147 | if name: 148 | self.name = name 149 | else: 150 | self.name = f"function_{FunctionType.next_id}" 151 | FunctionType.next_id += 1 152 | 153 | @property 154 | def size(self) -> int: 155 | raise NotImplementedError() 
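# A bare function type has no storage size of its own; generated types refer to functions through PointerType, whose width is well-defined.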
156 | 157 | def __str__(self) -> str: 158 | return self.name 159 | 160 | def pretty_print(self, _name: str) -> str: 161 | return f'{self.return_type} {self.name}({", ".join(map(str, self.params))});' 162 | 163 | 164 | class Field: 165 | def __init__( 166 | self, ctype: CType, offset: Optional[int] = None, name: str = "" 167 | ) -> None: 168 | self.ctype = ctype 169 | self.offset = offset 170 | self.name = name 171 | 172 | @property 173 | def size(self) -> int: 174 | return self.ctype.size 175 | 176 | 177 | class CompoundType(CType): 178 | @property 179 | def compound_type(self) -> str: 180 | return "compound" 181 | 182 | @property 183 | def fields(self) -> Iterable[Field]: 184 | return [] 185 | 186 | @property 187 | def name(self) -> str: 188 | return "UNKNOWN" 189 | 190 | def __str__(self) -> str: 191 | return f"{self.compound_type} {self.name}" 192 | 193 | def pretty_print(self, name: str) -> str: 194 | nt = f"{os.linesep}\t" 195 | result = f"{self.compound_type} {name} {{" 196 | for index, field in enumerate(self.fields): 197 | name = f"field_{index}" 198 | result += f"{nt}{field.ctype.pretty_print(name)};" 199 | if field.offset is not None: 200 | result += f" // offset {field.offset}" 201 | return f"{result}{os.linesep}}};" 202 | 203 | 204 | class StructType(CompoundType): 205 | next_id = 0 206 | 207 | def __init__( 208 | self, fields: Iterable[Field] = [], name: Optional[str] = None 209 | ) -> None: 210 | self.set_fields(fields) 211 | if name: 212 | self._name = name 213 | else: 214 | self._name = f"struct_{StructType.next_id}" 215 | StructType.next_id += 1 216 | 217 | def set_fields(self, fields: Iterable[Field]): 218 | """ 219 | We need to be able to construct a Struct before populating it so that we can 220 | represent recursive types. 221 | """ 222 | self._fields = sorted(fields, key=lambda f: f.offset) 223 | 224 | @property 225 | def size(self) -> int: 226 | if not self.fields: 227 | return 0 228 | return self.fields[-1].offset + self.fields[-1].size 229 | 230 | @property 231 | def compound_type(self) -> str: 232 | return "struct" 233 | 234 | @property 235 | def name(self) -> str: 236 | return self._name 237 | 238 | @property 239 | def fields(self) -> Iterable[Field]: 240 | return self._fields 241 | 242 | 243 | class UnionType(CompoundType): 244 | next_id = 0 245 | 246 | def __init__( 247 | self, fields: Iterable[Field], name: Optional[str] = None 248 | ) -> None: 249 | self._fields = fields 250 | if name: 251 | self._name = name 252 | else: 253 | self._name = f"union_{UnionType.next_id}" 254 | UnionType.next_id += 1 255 | 256 | @property 257 | def size(self) -> int: 258 | return max(map(lambda t: t.size, self._fields)) 259 | 260 | @property 261 | def compound_type(self) -> str: 262 | return "union" 263 | 264 | @property 265 | def name(self) -> str: 266 | return self._name 267 | 268 | @property 269 | def fields(self) -> Iterable[Field]: 270 | return self._fields 271 | -------------------------------------------------------------------------------- /src/clattice.py: -------------------------------------------------------------------------------- 1 | from .schema import DerivedTypeVariable, Lattice, LatticeCTypes 2 | from .c_types import ( 3 | ArrayType, 4 | BoolType, 5 | CharType, 6 | FloatType, 7 | IntType, 8 | VoidType, 9 | ) 10 | from typing import FrozenSet 11 | import networkx 12 | 13 | 14 | class CLattice(Lattice[DerivedTypeVariable]): 15 | INT_SIZES = [8, 16, 32, 64] 16 | 17 | # Unsized C integers 18 | _int = DerivedTypeVariable("int") 19 | _int_size = 
[DerivedTypeVariable(f"int{z}") for z in INT_SIZES] 20 | _uint = DerivedTypeVariable("uint") 21 | _uint_size = [DerivedTypeVariable(f"uint{z}") for z in INT_SIZES] 22 | 23 | # Floats 24 | _float = DerivedTypeVariable("float") 25 | _double = DerivedTypeVariable("double") 26 | 27 | # Special types 28 | _void = DerivedTypeVariable("void") 29 | _char = DerivedTypeVariable("char") 30 | _bool = DerivedTypeVariable("bool") 31 | 32 | _top = DerivedTypeVariable("┬") 33 | _bottom = DerivedTypeVariable("┴") 34 | 35 | _internal = ( 36 | frozenset( 37 | { 38 | _int, 39 | _uint, 40 | _float, 41 | _double, 42 | _void, 43 | _char, 44 | _bool, 45 | } 46 | ) 47 | | frozenset(_int_size) 48 | | frozenset(_uint_size) 49 | ) 50 | _endcaps = frozenset({_top, _bottom}) 51 | 52 | def __init__(self) -> None: 53 | self.graph = networkx.DiGraph() 54 | self.graph.add_edge(self._uint, self.top) 55 | 56 | for dtv in self._uint_size: 57 | self.graph.add_edge(dtv, self._uint) 58 | 59 | self.graph.add_edge(self._int, self._uint) 60 | 61 | for dtv in self._int_size: 62 | self.graph.add_edge(dtv, self._int) 63 | 64 | for int_dtv, uint_dtv in zip(self._int_size, self._uint_size): 65 | self.graph.add_edge(int_dtv, uint_dtv) 66 | 67 | self.graph.add_edge(self._float, self.top) 68 | self.graph.add_edge(self._double, self.top) 69 | 70 | # char is a int8_t with some semantic information. NOTE: This assumes 71 | # that INT_SIZES[0] == 8 72 | self.graph.add_edge(self._char, self._int_size[0]) 73 | self.graph.add_edge(self._void, self.top) 74 | self.graph.add_edge(self._bool, self.top) 75 | 76 | for dtv in self._int_size: 77 | self.graph.add_edge(self._bottom, dtv) 78 | 79 | self.graph.add_edge(self._bottom, self._int) 80 | self.graph.add_edge(self._bottom, self._float) 81 | self.graph.add_edge(self._bottom, self._double) 82 | self.graph.add_edge(self._bottom, self._void) 83 | self.graph.add_edge(self._bottom, self._char) 84 | self.graph.add_edge(self._bottom, self._bool) 85 | 86 | assert all( 87 | len(self.graph.out_edges(dtv)) > 0 88 | and len(self.graph.in_edges(dtv)) > 0 89 | for dtv in self._internal 90 | ) 91 | 92 | try: 93 | networkx.find_cycle(self.graph) 94 | assert False, "Lattice cannot be circular" 95 | except networkx.NetworkXNoCycle: 96 | pass 97 | 98 | self.revgraph = self.graph.reverse() 99 | 100 | @property 101 | def atomic_types(self) -> FrozenSet[DerivedTypeVariable]: 102 | return CLattice._internal | CLattice._endcaps 103 | 104 | @property 105 | def internal_types(self) -> FrozenSet[DerivedTypeVariable]: 106 | return CLattice._internal 107 | 108 | @property 109 | def top(self) -> DerivedTypeVariable: 110 | return CLattice._top 111 | 112 | @property 113 | def bottom(self) -> DerivedTypeVariable: 114 | return CLattice._bottom 115 | 116 | def meet( 117 | self, t: DerivedTypeVariable, v: DerivedTypeVariable 118 | ) -> DerivedTypeVariable: 119 | return networkx.lowest_common_ancestor(self.graph, t, v) 120 | 121 | def join( 122 | self, t: DerivedTypeVariable, v: DerivedTypeVariable 123 | ) -> DerivedTypeVariable: 124 | return networkx.lowest_common_ancestor(self.revgraph, t, v) 125 | 126 | 127 | class CLatticeCTypes(LatticeCTypes): 128 | def atom_to_ctype(self, lower_bound, upper_bound, byte_size): 129 | if upper_bound == CLattice._top: 130 | atom = lower_bound 131 | elif lower_bound == CLattice._bottom: 132 | atom = upper_bound 133 | else: 134 | atom = lower_bound 135 | 136 | if atom in CLattice._int_size: 137 | return IntType( 138 | CLattice.INT_SIZES[CLattice._int_size.index(atom)] // 8, True 139 | ) 140 | 141 | if 
atom in CLattice._uint_size: 142 | return IntType( 143 | CLattice.INT_SIZES[CLattice._uint_size.index(atom)] // 8, False 144 | ) 145 | 146 | default = ArrayType(CharType(1), byte_size) 147 | 148 | return { 149 | CLattice._int: IntType(byte_size, True), 150 | CLattice._uint: IntType(byte_size, False), 151 | CLattice._void: VoidType(), 152 | CLattice._char: CharType(byte_size), 153 | CLattice._float: FloatType(4), 154 | CLattice._bool: BoolType(byte_size), 155 | CLattice._double: FloatType(8), 156 | }.get(atom, default) 157 | -------------------------------------------------------------------------------- /src/dummylattice.py: -------------------------------------------------------------------------------- 1 | # Retypd - machine code type inference 2 | # Copyright (C) 2021 GrammaTech, Inc. 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | # 17 | # This project is sponsored by the Office of Naval Research, One Liberty 18 | # Center, 875 N. Randolph Street, Arlington, VA 22203 under contract # 19 | # N68335-17-C-0700. The content of the information does not necessarily 20 | # reflect the position or policy of the Government and no official 21 | # endorsement should be inferred. 22 | 23 | """An abstract lattice type for atomic types (e.g., primitives). Also includes a small 24 | implementation for reference. 
25 | """ 26 | 27 | from typing import FrozenSet, Any 28 | from .schema import DerivedTypeVariable, Lattice, LatticeCTypes 29 | from .c_types import IntType, PointerType, CharType, ArrayType 30 | 31 | 32 | class DummyLattice(Lattice[DerivedTypeVariable]): 33 | _int = DerivedTypeVariable("int") 34 | _success = DerivedTypeVariable("#SuccessZ") 35 | _fd = DerivedTypeVariable("#FileDescriptor") 36 | _str = DerivedTypeVariable("str") 37 | _top = DerivedTypeVariable("┬") 38 | _bottom = DerivedTypeVariable("┴") 39 | _internal = frozenset({_int, _fd, _success, _str}) 40 | _endcaps = frozenset({_top, _bottom}) 41 | 42 | def __init__(self) -> None: 43 | pass 44 | 45 | @property 46 | def atomic_types(self) -> FrozenSet[DerivedTypeVariable]: 47 | return DummyLattice._internal | DummyLattice._endcaps 48 | 49 | @property 50 | def internal_types(self) -> FrozenSet[DerivedTypeVariable]: 51 | return DummyLattice._internal 52 | 53 | @property 54 | def top(self) -> DerivedTypeVariable: 55 | return DummyLattice._top 56 | 57 | @property 58 | def bottom(self) -> DerivedTypeVariable: 59 | return DummyLattice._bottom 60 | 61 | def meet( 62 | self, t: DerivedTypeVariable, v: DerivedTypeVariable 63 | ) -> DerivedTypeVariable: 64 | if t == v: 65 | return t 66 | # idempotence 67 | if t == DummyLattice._top: 68 | return v 69 | if v == DummyLattice._top: 70 | return t 71 | types = {t, v} 72 | # dominance 73 | if DummyLattice._bottom in types: 74 | return DummyLattice._bottom 75 | # the two types are not equal and neither is TOP or BOTTOM, so if either is STR then the two 76 | # are incomparable 77 | if DummyLattice._str in types: 78 | return DummyLattice._bottom 79 | # the remaining cases are integral types. If one is INT, they are comparable 80 | if DummyLattice._int in types: 81 | types -= {DummyLattice._int} 82 | return next(iter(types)) 83 | # the only remaining case is SUCCESS and FILE_DESCRIPTOR, which are not comparable 84 | return DummyLattice._bottom 85 | 86 | def join( 87 | self, t: DerivedTypeVariable, v: DerivedTypeVariable 88 | ) -> DerivedTypeVariable: 89 | if t == v: 90 | return t 91 | # idempotence 92 | if t == DummyLattice._bottom: 93 | return v 94 | if v == DummyLattice._bottom: 95 | return t 96 | types = {t, v} 97 | # dominance 98 | if DummyLattice._top in types: 99 | return DummyLattice._top 100 | # the two types are not equal and neither is TOP or BOTTOM, so if either is STR then the two 101 | # are incomparable 102 | if DummyLattice._str in types: 103 | return DummyLattice._top 104 | # the remaining cases are integral types. In all three combinations of two, the least upper 105 | # bound is INT. 
106 | return DummyLattice._int 107 | 108 | 109 | class DummyLatticeCTypes(LatticeCTypes): 110 | def atom_to_ctype(self, atom_lower: Any, atom_upper: Any, byte_size: int): 111 | best = atom_lower if atom_lower != DummyLattice._bottom else atom_upper 112 | return { 113 | DummyLattice._int: IntType(byte_size, True), 114 | DummyLattice._success: IntType(byte_size, True), 115 | DummyLattice._fd: IntType(byte_size, False), 116 | DummyLattice._str: PointerType(CharType(1), byte_size), 117 | }.get(best, ArrayType(IntType(1, False), byte_size)) 118 | -------------------------------------------------------------------------------- /src/fast_enfa.py: -------------------------------------------------------------------------------- 1 | from pyformlang.finite_automaton import ( 2 | EpsilonNFA, 3 | State, 4 | DeterministicFiniteAutomaton, 5 | ) 6 | from typing import Iterable 7 | 8 | 9 | def to_single_state(l_states: Iterable[State]) -> State: 10 | """ 11 | Merge a list of states 12 | """ 13 | values = [str(state.value) if state else "TRASH" for state in l_states] 14 | return State(";".join(values)) 15 | 16 | 17 | class FastENFA(EpsilonNFA): 18 | def _to_deterministic_internal( 19 | self, eclose: bool 20 | ) -> DeterministicFiniteAutomaton: 21 | """ 22 | Transforms the epsilon-nfa into a dfa 23 | 24 | NOTE: This a a modified version of the original EpsilonNFA._to_deterministic_internal 25 | This has a few small changes: 26 | - Refactored the call to `add_final_state` to be less redundant 27 | - Added some additional checks to `all_trans` computation to filter out invalid items 28 | """ 29 | dfa = DeterministicFiniteAutomaton() 30 | # Add Eclose 31 | if eclose: 32 | start_eclose = self.eclose_iterable(self._start_state) 33 | else: 34 | start_eclose = self._start_state 35 | 36 | start_state = to_single_state(start_eclose) 37 | 38 | dfa.add_start_state(start_state) 39 | to_process = [start_eclose] 40 | processed = {start_state} 41 | 42 | while to_process: 43 | current = to_process.pop() 44 | s_from = to_single_state(current) 45 | 46 | if any(state in self._final_states for state in current): 47 | dfa.add_final_state(s_from) 48 | 49 | for symb in self._input_symbols: 50 | all_trans = [ 51 | self._transition_function._transitions[x][symb] 52 | for x in current 53 | if ( 54 | x in self._transition_function._transitions 55 | and symb in self._transition_function._transitions[x] 56 | ) 57 | ] 58 | states = set() 59 | for trans in all_trans: 60 | states = states.union(trans) 61 | if not states: 62 | continue 63 | # Eclose added 64 | if eclose: 65 | states = self.eclose_iterable(states) 66 | state_merged = to_single_state(states) 67 | dfa.add_transition(s_from, symb, state_merged) 68 | if state_merged not in processed: 69 | processed.add(state_merged) 70 | to_process.append(states) 71 | 72 | return dfa 73 | -------------------------------------------------------------------------------- /src/graph.py: -------------------------------------------------------------------------------- 1 | # Retypd - machine code type inference 2 | # Copyright (C) 2021 GrammaTech, Inc. 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 
8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | # 17 | # This project is sponsored by the Office of Naval Research, One Liberty 18 | # Center, 875 N. Randolph Street, Arlington, VA 22203 under contract # 19 | # N68335-17-C-0700. The content of the information does not necessarily 20 | # reflect the position or policy of the Government and no official 21 | # endorsement should be inferred. 22 | 23 | from __future__ import annotations 24 | from collections import defaultdict 25 | from enum import Enum, unique 26 | from typing import AbstractSet, Any, Dict, Optional, Set, Tuple 27 | from .schema import ( 28 | AccessPathLabel, 29 | ConstraintSet, 30 | DerivedTypeVariable, 31 | LoadLabel, 32 | StoreLabel, 33 | Variance, 34 | ) 35 | import networkx 36 | import os 37 | 38 | 39 | class EdgeLabel: 40 | """A forget or recall label in the graph. Instances should never be mutated.""" 41 | 42 | @unique 43 | class Kind(Enum): 44 | FORGET = 1 45 | RECALL = 2 46 | 47 | def __init__(self, capability: AccessPathLabel, kind: Kind) -> None: 48 | self.capability = capability 49 | self.kind = kind 50 | if self.kind == EdgeLabel.Kind.FORGET: 51 | type_str = "forget" 52 | else: 53 | type_str = "recall" 54 | self._str = f"{type_str} {self.capability}" 55 | self._hash = hash(self.capability) ^ hash(self.kind) 56 | 57 | def __eq__(self, other: EdgeLabel) -> bool: 58 | return ( 59 | isinstance(other, EdgeLabel) 60 | and self.capability == other.capability 61 | and self.kind == other.kind 62 | ) 63 | 64 | def __lt__(self, other: EdgeLabel) -> bool: 65 | if not isinstance(other, EdgeLabel): 66 | raise ValueError(f"Cannot compare EdgeLabel to {type(other)}") 67 | return self._str < other._str 68 | 69 | def __hash__(self) -> int: 70 | return self._hash 71 | 72 | def __str__(self) -> str: 73 | return self._str 74 | 75 | def __repr__(self) -> str: 76 | return self._str 77 | 78 | 79 | @unique 80 | class SideMark(Enum): 81 | """ 82 | Marking of interesting graph nodes to avoid non-elementary proofs. 83 | See Definition D.2 and Note 1 in section D.1 of the paper. 84 | """ 85 | 86 | NO = 0 87 | LEFT = 1 88 | RIGHT = 2 89 | 90 | 91 | class Node: 92 | """A node in the graph of constraints. Node objects are immutable. 93 | 94 | Forgotten is a flag used to differentiate between two subgraphs later in the algorithm. See 95 | :py:method:`Solver._recall_forget_split` for details. 
96 | """ 97 | 98 | @unique 99 | class Forgotten(Enum): 100 | PRE_FORGET = 0 101 | POST_FORGET = 1 102 | 103 | def __init__( 104 | self, 105 | base: DerivedTypeVariable, 106 | suffix_variance: Variance, 107 | side_mark: SideMark = SideMark.NO, 108 | forgotten: Forgotten = Forgotten.PRE_FORGET, 109 | ) -> None: 110 | self.base = base 111 | self.suffix_variance = suffix_variance 112 | self.side_mark = side_mark 113 | if side_mark == SideMark.LEFT: 114 | side_mark_str = "L:" 115 | elif side_mark == SideMark.RIGHT: 116 | side_mark_str = "R:" 117 | else: 118 | side_mark_str = "" 119 | if suffix_variance == Variance.COVARIANT: 120 | variance = ".⊕" 121 | else: 122 | variance = ".⊖" 123 | self._forgotten = forgotten 124 | if forgotten == Node.Forgotten.POST_FORGET: 125 | self._str = "F:" + side_mark_str + str(self.base) + variance 126 | else: 127 | self._str = side_mark_str + str(self.base) + variance 128 | self._hash = hash( 129 | (self.base, self.suffix_variance, self.side_mark, self._forgotten) 130 | ) 131 | 132 | def __eq__(self, other: Any) -> bool: 133 | if not isinstance(other, Node): 134 | return False 135 | 136 | if self._hash != other._hash: 137 | return False 138 | 139 | return self._str == other._str 140 | 141 | def __lt__(self, other: Node) -> bool: 142 | if not isinstance(other, Node): 143 | raise ValueError( 144 | f"Cannot compare objects of type Node and {type(other)} " 145 | ) 146 | return self._hash < other._hash 147 | 148 | def __hash__(self) -> int: 149 | return self._hash 150 | 151 | def forget_once( 152 | self, 153 | ) -> Tuple[Optional[AccessPathLabel], Optional[Node]]: 154 | """ "Forget" the last element in the access path, creating a new Node. The new Node has 155 | variance that reflects this change. 156 | """ 157 | if self.base.path: 158 | prefix_path = list(self.base.path) 159 | last = prefix_path.pop() 160 | prefix = DerivedTypeVariable(self.base.base, prefix_path) 161 | return ( 162 | last, 163 | Node( 164 | prefix, 165 | Variance.combine(last.variance(), self.suffix_variance), 166 | self.side_mark, 167 | ), 168 | ) 169 | return (None, None) 170 | 171 | def recall(self, label: AccessPathLabel) -> Node: 172 | """ "Recall" label, creating a new Node. The new Node has variance that reflects this 173 | change. 174 | """ 175 | path = list(self.base.path) 176 | path.append(label) 177 | variance = Variance.combine(self.suffix_variance, label.variance()) 178 | return Node( 179 | DerivedTypeVariable(self.base.base, path), 180 | variance, 181 | self.side_mark, 182 | ) 183 | 184 | def __str__(self) -> str: 185 | return self._str 186 | 187 | def __repr__(self) -> str: 188 | return self._str 189 | 190 | def split_recall_forget(self) -> Node: 191 | """Get a duplicate of self for use in the post-recall subgraph.""" 192 | return Node( 193 | self.base, 194 | self.suffix_variance, 195 | self.side_mark, 196 | Node.Forgotten.POST_FORGET, 197 | ) 198 | 199 | def inverse(self, keep_same_mark: bool = False) -> Node: 200 | """ 201 | Get a Node identical to this one but with inverted variance and mark. 202 | If keep_same_mark is true, the side mark is not inverted. 
203 | """ 204 | if keep_same_mark: 205 | new_side_mark = self.side_mark 206 | else: 207 | new_side_mark = SideMark.NO 208 | if self.side_mark == SideMark.LEFT: 209 | new_side_mark = SideMark.RIGHT 210 | elif self.side_mark == SideMark.RIGHT: 211 | new_side_mark = SideMark.LEFT 212 | return Node( 213 | self.base, 214 | Variance.invert(self.suffix_variance), 215 | new_side_mark, 216 | self._forgotten, 217 | ) 218 | 219 | 220 | class ConstraintGraph: 221 | """Represents the constraint graph in the slides. Essentially the same as the transducer from 222 | Appendix D. Edge weights use the formulation from the paper. 223 | """ 224 | 225 | def __init__( 226 | self, 227 | constraints: ConstraintSet, 228 | interesting_vars: Set[DerivedTypeVariable], 229 | keep_graph_before_split: bool = False, 230 | ) -> None: 231 | self.graph = networkx.DiGraph() 232 | for constraint in constraints.subtype: 233 | self.add_edges(constraint.left, constraint.right, interesting_vars) 234 | self.saturate() 235 | self._remove_self_loops() 236 | if keep_graph_before_split: 237 | self.graph_before_split = self.graph.copy() 238 | self._recall_forget_split() 239 | 240 | # Regular language: RECALL*FORGET* (i.e., FORGET cannot precede RECALL) 241 | def _recall_forget_split(self) -> None: 242 | """The algorithm, after saturation, only admits paths such that recall edges all precede 243 | the first forget edge (if there is such an edge). To enforce this, we modify the graph by 244 | splitting each node and the unlabeled and forget edges (but not recall edges!). Forget edges 245 | in the original graph are changed to point to the 'forgotten' duplicate of their original 246 | target. As a result, no recall edges are reachable after traversing a single forget edge. 247 | """ 248 | for head, tail in list(self.graph.edges): 249 | atts = self.graph[head][tail] 250 | label = atts.get("label") 251 | if label and label.kind == EdgeLabel.Kind.RECALL: 252 | continue 253 | forget_head = head.split_recall_forget() 254 | forget_tail = tail.split_recall_forget() 255 | if label and label.kind == EdgeLabel.Kind.FORGET: 256 | self.graph.remove_edge(head, tail) 257 | self.graph.add_edge(head, forget_tail, **atts) 258 | self.graph.add_edge(forget_head, forget_tail, **atts) 259 | 260 | def add_edge(self, head: Node, tail: Node, **atts) -> bool: 261 | """Add an edge to the graph. The optional atts dict should include, if anything, a mapping 262 | from the string 'label' to an EdgeLabel object. 263 | """ 264 | if head not in self.graph or tail not in self.graph[head]: 265 | self.graph.add_edge(head, tail, **atts) 266 | return True 267 | return False 268 | 269 | def add_edges( 270 | self, 271 | sub: DerivedTypeVariable, 272 | sup: DerivedTypeVariable, 273 | interesting_vars: Set[DerivedTypeVariable], 274 | **atts, 275 | ) -> bool: 276 | """Add an edge to the underlying graph. Also add its reverse with reversed variance. 277 | Each constraint, becomes two pushdown rules in the paper. 278 | In each case, we add recall edges only to the left-hand term of the rule 279 | and forget edges to the right-hand side. 
280 | """ 281 | changed = False 282 | left = ( 283 | SideMark.LEFT if sub.base_var in interesting_vars else SideMark.NO 284 | ) 285 | right = ( 286 | SideMark.RIGHT if sup.base_var in interesting_vars else SideMark.NO 287 | ) 288 | forward_from = Node(sub, Variance.COVARIANT, left) 289 | forward_to = Node(sup, Variance.COVARIANT, right) 290 | changed = self.add_edge(forward_from, forward_to, **atts) or changed 291 | self.add_recalls(forward_from) 292 | self.add_forgets(forward_to) 293 | backward_from = forward_to.inverse() 294 | backward_to = forward_from.inverse() 295 | changed = self.add_edge(backward_from, backward_to, **atts) or changed 296 | self.add_recalls(backward_from) 297 | self.add_forgets(backward_to) 298 | return changed 299 | 300 | def add_recalls(self, node: Node) -> None: 301 | """ 302 | Recall edges are added for the left-hand side of constraints 303 | """ 304 | (capability, prefix) = node.forget_once() 305 | while prefix: 306 | self.add_edge( 307 | prefix, 308 | node, 309 | label=EdgeLabel(capability, EdgeLabel.Kind.RECALL), 310 | ) 311 | node = prefix 312 | (capability, prefix) = node.forget_once() 313 | 314 | def add_forgets(self, node: Node) -> None: 315 | """ 316 | Forget edges are added for the right-hand side of constraints 317 | """ 318 | (capability, prefix) = node.forget_once() 319 | while prefix: 320 | self.add_edge( 321 | node, 322 | prefix, 323 | label=EdgeLabel(capability, EdgeLabel.Kind.FORGET), 324 | ) 325 | node = prefix 326 | (capability, prefix) = node.forget_once() 327 | 328 | def saturate(self) -> None: 329 | """Add "shortcut" edges, per algorithm D.2 in the paper.""" 330 | changed = False 331 | reaching_R: Dict[ 332 | Node, Set[Tuple[AccessPathLabel, Node]] 333 | ] = defaultdict(set) 334 | 335 | def add_forgets( 336 | dest: Node, forgets: Set[Tuple[AccessPathLabel, Node]] 337 | ): 338 | nonlocal changed 339 | if dest not in reaching_R or not (forgets <= reaching_R[dest]): 340 | changed = True 341 | reaching_R[dest].update(forgets) 342 | 343 | def add_edge(origin: Node, dest: Node): 344 | nonlocal changed 345 | changed = self.add_edge(origin, dest) or changed 346 | 347 | def is_contravariant(node: Node) -> bool: 348 | return node.suffix_variance == Variance.CONTRAVARIANT 349 | 350 | for head_x, tail_y in self.graph.edges: 351 | label = self.graph[head_x][tail_y].get("label") 352 | if label and label.kind == EdgeLabel.Kind.FORGET: 353 | add_forgets(tail_y, {(label.capability, head_x)}) 354 | while changed: 355 | changed = False 356 | for head_x, tail_y in self.graph.edges: 357 | if not self.graph[head_x][tail_y].get("label"): 358 | add_forgets(tail_y, reaching_R[head_x]) 359 | existing_edges = list(self.graph.edges) 360 | for head_x, tail_y in existing_edges: 361 | label = self.graph[head_x][tail_y].get("label") 362 | if label and label.kind == EdgeLabel.Kind.RECALL: 363 | capability_l = label.capability 364 | for (label, origin_z) in reaching_R[head_x]: 365 | if label == capability_l: 366 | add_edge(origin_z, tail_y) 367 | 368 | contravariant_vars = filter(is_contravariant, self.graph.nodes) 369 | for x in contravariant_vars: 370 | for (capability_l, origin_z) in reaching_R[x]: 371 | label = None 372 | if isinstance(capability_l, StoreLabel): 373 | label = LoadLabel.instance() 374 | if isinstance(capability_l, LoadLabel): 375 | label = StoreLabel.instance() 376 | if label: 377 | add_forgets( 378 | x.inverse(keep_same_mark=True), {(label, origin_z)} 379 | ) 380 | 381 | def _remove_self_loops(self) -> None: 382 | """Loops from a node directly to itself 
are not useful, so it's useful to remove them.""" 383 | self.graph.remove_edges_from( 384 | {(node, node) for node in self.graph.nodes} 385 | ) 386 | 387 | @classmethod 388 | def from_constraints( 389 | cls, 390 | constraints: ConstraintSet, 391 | interesting_vars: AbstractSet[DerivedTypeVariable], 392 | ) -> networkx.DiGraph: 393 | return cls(constraints, interesting_vars).graph 394 | 395 | @staticmethod 396 | def edge_to_str(graph, edge: Tuple[Node, Node]) -> str: 397 | """A helper for __str__ that formats an edge""" 398 | width = 2 + max(map(lambda v: len(str(v)), graph.nodes)) 399 | (sub, sup) = edge 400 | label = graph[sub][sup].get("label") 401 | edge_str = f"{str(sub):<{width}}→ {str(sup):<{width}}" 402 | if label: 403 | return edge_str + f" ({label})" 404 | else: 405 | return edge_str 406 | 407 | @staticmethod 408 | def graph_to_str(graph: networkx.DiGraph) -> str: 409 | nt = os.linesep + "\t" 410 | edge_to_str = lambda edge: ConstraintGraph.edge_to_str(graph, edge) 411 | return f"{nt.join(map(edge_to_str, graph.edges))}" 412 | 413 | def __str__(self) -> str: 414 | nt = os.linesep + "\t" 415 | return ( 416 | f"ConstraintGraph:{nt}{ConstraintGraph.graph_to_str(self.graph)}" 417 | ) 418 | 419 | 420 | def remove_unreachable_states( 421 | graph: networkx.DiGraph, start_nodes: Set[Node], end_nodes: Set[Node] 422 | ) -> Tuple[networkx.DiGraph, Set[Node], Set[Node]]: 423 | """ 424 | Remove states that not reachable from start_nodes or do not reach end_nodes. 425 | This can speed up path exploration since we do not have to search 426 | paths through nodes that do not reach interesting destinations. 427 | """ 428 | if len(graph) == 0 or len(start_nodes) == 0 or len(end_nodes) == 0: 429 | return graph, set(), set() 430 | 431 | reachable_nodes = set( 432 | networkx.multi_source_dijkstra_path_length(graph, start_nodes).keys() 433 | ) 434 | rev_reachable_nodes = set( 435 | networkx.multi_source_dijkstra_path_length( 436 | graph.reverse(copy=False), end_nodes 437 | ).keys() 438 | ) 439 | keep = reachable_nodes & rev_reachable_nodes 440 | keep_start = start_nodes & keep 441 | keep_end = end_nodes & keep 442 | return graph.subgraph(keep), keep_start, keep_end 443 | -------------------------------------------------------------------------------- /src/graph_solver.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import List, Optional, Set, Any 3 | 4 | from .fast_enfa import FastENFA 5 | from .pathexpr import RExp, scc_decompose_path_seq, solve_paths_from 6 | from .graph import ( 7 | EdgeLabel, 8 | Node, 9 | remove_unreachable_states, 10 | ) 11 | from .schema import ( 12 | ConstraintSet, 13 | SubtypeConstraint, 14 | Variance, 15 | ) 16 | import abc 17 | import networkx 18 | from dataclasses import dataclass 19 | from pyformlang.finite_automaton import ( 20 | EpsilonNFA, 21 | State, 22 | Symbol, 23 | Epsilon, 24 | ) 25 | 26 | 27 | @dataclass 28 | class GraphSolverConfig: 29 | # Maximum path length when converts the constraint graph into output constraints 30 | max_path_length: int = 2**64 31 | # Maximum number of paths to explore per type variable root when generating output constraints 32 | max_paths_per_root: int = 2**64 33 | # Maximum paths total to explore per SCC 34 | max_total_paths: int = 2**64 35 | # Restrict graph to reachable nodes from and to endpoints 36 | # before doing the path exploration. 
37 | restrict_graph_to_reachable: bool = True 38 | 39 | 40 | def _maybe_constraint( 41 | origin: Node, 42 | dest: Node, 43 | string: List[EdgeLabel], 44 | ) -> Optional[SubtypeConstraint]: 45 | """Generate constraints by adding the forgets in string to origin and the recalls in string 46 | to dest. If both of the generated vertices are covariant (the empty string's variance is 47 | covariant, so only covariant vertices can represent a type_scheme type variable without an 48 | elided portion of its path) and if the two variables are not equal, emit a constraint. 49 | """ 50 | lhs = origin 51 | rhs = dest 52 | forgets = [] 53 | recalls = [] 54 | for label in string: 55 | if label.kind == EdgeLabel.Kind.FORGET: 56 | forgets.append(label.capability) 57 | else: 58 | recalls.append(label.capability) 59 | 60 | for recall in recalls: 61 | lhs = lhs.recall(recall) 62 | for forget in reversed(forgets): 63 | rhs = rhs.recall(forget) 64 | 65 | if ( 66 | lhs.suffix_variance == Variance.COVARIANT 67 | and rhs.suffix_variance == Variance.COVARIANT 68 | ): 69 | lhs_var = lhs.base 70 | rhs_var = rhs.base 71 | if lhs_var != rhs_var: 72 | return SubtypeConstraint(lhs_var, rhs_var) 73 | return None 74 | 75 | 76 | class GraphSolver(abc.ABC): 77 | """ 78 | Given a graph of constraints, solve this to a new set of constraints, which 79 | should be smaller than the original set of constraints 80 | """ 81 | 82 | def __init__(self, config: GraphSolverConfig): 83 | self.config = config 84 | 85 | def _generate_constraints_from_to_internal( 86 | self, 87 | graph: networkx.DiGraph, 88 | start_nodes: Set[Node], 89 | end_nodes: Set[Node], 90 | ) -> ConstraintSet: 91 | raise NotImplementedError() 92 | 93 | def generate_constraints_from_to( 94 | self, 95 | graph: networkx.DiGraph, 96 | start_nodes: Set[Node], 97 | end_nodes: Set[Node], 98 | ) -> ConstraintSet: 99 | """ 100 | Generate a set of final constraints from a set of start_nodes to a set of 101 | end_nodes based on the given graph. 102 | Use path expressions or naive exploration depending on the 103 | Solver's configuration. 104 | """ 105 | if self.config.restrict_graph_to_reachable: 106 | graph, start_nodes, end_nodes = remove_unreachable_states( 107 | graph, start_nodes, end_nodes 108 | ) 109 | 110 | return self._generate_constraints_from_to_internal( 111 | graph, start_nodes, end_nodes 112 | ) 113 | 114 | 115 | class DFAGraphSolver(GraphSolver): 116 | START = State("$$START$$") 117 | FINAL = State("$$FINAL$$") 118 | 119 | @classmethod 120 | def _graph_to_enfa( 121 | cls, 122 | graph: networkx.DiGraph, 123 | start_nodes: Set[Node], 124 | end_nodes: Set[Node], 125 | ) -> EpsilonNFA: 126 | """ 127 | Generate an ε-NFA from graph. 128 | """ 129 | enfa = FastENFA() 130 | 131 | for (from_node, to_node, label) in graph.edges(data="label"): 132 | if label is None: 133 | sym = Epsilon() 134 | else: 135 | sym = Symbol(label) 136 | 137 | enfa.add_transition(State(from_node), sym, State(to_node)) 138 | 139 | enfa.add_start_state(cls.START) 140 | enfa.add_final_state(cls.FINAL) 141 | 142 | for start in start_nodes: 143 | enfa.add_transition(cls.START, Symbol(start), State(start)) 144 | 145 | for end in end_nodes: 146 | enfa.add_transition(State(end), Symbol(end), cls.FINAL) 147 | 148 | return enfa 149 | 150 | def _generate_constraints_from_to_internal( 151 | self, 152 | graph: networkx.DiGraph, 153 | start_nodes: Set[Node], 154 | end_nodes: Set[Node], 155 | ) -> ConstraintSet: 156 | """ 157 | Treat the graph as a ε-NFA, then convert to a DFA and subsequent minimal 158 | DFA. 
Compute path labels between start/ends over minimized DFA. 159 | """ 160 | enfa = self._graph_to_enfa(graph, start_nodes, end_nodes) 161 | mdfa = enfa.minimize() 162 | dfa_g = mdfa.to_networkx() 163 | 164 | constraints = ConstraintSet() 165 | 166 | for final_state in mdfa.final_states: 167 | for path in networkx.all_simple_edge_paths( 168 | dfa_g, mdfa.start_state, final_state 169 | ): 170 | path_labels = [ 171 | dfa_g.get_edge_data(s, e)[index]["label"] 172 | for s, e, index in path 173 | ] 174 | start_node = path_labels[0] 175 | end_node = path_labels[-1] 176 | 177 | constraint = _maybe_constraint( 178 | start_node, end_node, path_labels[1:-1] 179 | ) 180 | 181 | if constraint: 182 | constraints.add(constraint) 183 | 184 | return constraints 185 | 186 | 187 | class PathExprGraphSolver(GraphSolver): 188 | @staticmethod 189 | def cross_concatenation( 190 | prefix_list: List[List[Any]], postfix_list: List[List[Any]] 191 | ) -> List[List[Any]]: 192 | """ 193 | Compute the cross product concatenation of two lists of lists. 194 | """ 195 | combined = [] 196 | for prefix in prefix_list: 197 | for postfix in postfix_list: 198 | combined.append(prefix + postfix) 199 | return combined 200 | 201 | @classmethod 202 | def enumerate_non_looping_paths( 203 | cls, path_expr: RExp 204 | ) -> List[List[EdgeLabel]]: 205 | """ 206 | Given a path expression, return a list of all the paths 207 | that do not involve loops. 208 | """ 209 | if path_expr.label == RExp.Label.NULL: 210 | return [] 211 | elif path_expr.label == RExp.Label.EMPTY: 212 | return [[]] 213 | elif path_expr.label == RExp.Label.NODE: 214 | return [[path_expr.data]] 215 | # ignore looping paths 216 | elif path_expr.label == RExp.Label.STAR: 217 | return [[]] 218 | elif path_expr.label == RExp.Label.DOT: 219 | paths = [[]] 220 | for child in path_expr.children: 221 | paths = cls.cross_concatenation( 222 | paths, cls.enumerate_non_looping_paths(child) 223 | ) 224 | return paths 225 | elif path_expr.label == RExp.Label.OR: 226 | paths = [] 227 | for child in path_expr.children: 228 | paths.extend(cls.enumerate_non_looping_paths(child)) 229 | return paths 230 | else: 231 | assert False 232 | 233 | def _generate_constraints_from_to_internal( 234 | self, 235 | graph: networkx.DiGraph, 236 | start_nodes: Set[Node], 237 | end_nodes: Set[Node], 238 | ) -> ConstraintSet: 239 | """ 240 | Generate constraints based on the computation of path expressions. 241 | Compute path expressions for each pair of start and end nodes. 242 | For each path expression, enumerate non-looping paths. 243 | """ 244 | numbering, path_seq = scc_decompose_path_seq(graph, "label") 245 | constraints = ConstraintSet() 246 | for start_node in start_nodes: 247 | path_exprs = solve_paths_from(path_seq, numbering[start_node]) 248 | for end_node in end_nodes: 249 | indices = (numbering[start_node], numbering[end_node]) 250 | path_expr = path_exprs[indices] 251 | for path in self.enumerate_non_looping_paths(path_expr): 252 | constraint = _maybe_constraint(start_node, end_node, path) 253 | if constraint: 254 | constraints.add(constraint) 255 | return constraints 256 | 257 | 258 | class NaiveGraphSolver(GraphSolver): 259 | def _generate_constraints_from_to_internal( 260 | self, 261 | graph: networkx.DiGraph, 262 | start_nodes: Set[Node], 263 | end_nodes: Set[Node], 264 | ) -> ConstraintSet: 265 | """ 266 | Generate constraints based on the naive exploration of the graph. 
267 | """ 268 | constraints = ConstraintSet() 269 | npaths = 0 270 | # On large procedures, the graph this is exploring can be quite large (hundreds of nodes, 271 | # thousands of edges). This can result in an insane number of paths - most of which do not 272 | # result in a constraint, and most of the ones that do result in constraints are redundant. 273 | def explore( 274 | current_node: Node, 275 | path: List[Node] = [], 276 | string: List[EdgeLabel] = [], 277 | ) -> None: 278 | """Find all non-empty paths that begin at start_nodes and end at end_nodes. Return 279 | the list of labels encountered along the way as well as the current_node and destination. 280 | """ 281 | nonlocal max_paths_per_root 282 | nonlocal npaths 283 | if len(path) > self.config.max_path_length: 284 | return 285 | if npaths > max_paths_per_root: 286 | return 287 | if path and current_node in end_nodes: 288 | constraint = _maybe_constraint(path[0], current_node, string) 289 | 290 | if constraint: 291 | constraints.add(constraint) 292 | npaths += 1 293 | return 294 | if current_node in path: 295 | npaths += 1 296 | return 297 | 298 | path = list(path) 299 | path.append(current_node) 300 | if current_node in graph: 301 | for succ in graph[current_node]: 302 | label = graph[current_node][succ].get("label") 303 | new_string = list(string) 304 | if label: 305 | new_string.append(label) 306 | explore(succ, path, new_string) 307 | 308 | # We evenly distribute the maximum number of paths that we are willing to explore 309 | # across all origin nodes here. 310 | max_paths_per_root = int( 311 | min( 312 | self.config.max_paths_per_root, 313 | self.config.max_total_paths / float(len(start_nodes) + 1), 314 | ) 315 | ) 316 | for origin in start_nodes: 317 | npaths = 0 318 | explore(origin) 319 | return constraints 320 | -------------------------------------------------------------------------------- /src/loggable.py: -------------------------------------------------------------------------------- 1 | # Retypd - machine code type inference 2 | # Copyright (C) 2021 GrammaTech, Inc. 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | # 17 | # This project is sponsored by the Office of Naval Research, One Liberty 18 | # Center, 875 N. Randolph Street, Arlington, VA 22203 under contract # 19 | # N68335-17-C-0700. The content of the information does not necessarily 20 | # reflect the position or policy of the Government and no official 21 | # endorsement should be inferred. 
22 | 23 | from enum import Enum 24 | import tqdm 25 | 26 | 27 | class LogLevel(int, Enum): 28 | QUIET = 0 29 | INFO = 1 30 | DEBUG = 2 31 | 32 | 33 | # Unfortunable, the python logging class is a bit flawed and overly complex for what we need 34 | # When you use info/debug you can use %s/%d/etc formatting ala logging to lazy evaluate 35 | class Loggable: 36 | def __init__(self, verbose: LogLevel = LogLevel.QUIET): 37 | self.verbose = verbose 38 | 39 | def info(self, *args): 40 | if self.verbose >= LogLevel.INFO: 41 | print(str(args[0]) % tuple(args[1:])) 42 | 43 | def debug(self, *args): 44 | if self.verbose >= LogLevel.DEBUG: 45 | print(str(args[0]) % tuple(args[1:])) 46 | 47 | 48 | def show_progress(verbose, iterable): 49 | if verbose: 50 | return tqdm.tqdm(iterable) 51 | return iterable 52 | -------------------------------------------------------------------------------- /src/parser.py: -------------------------------------------------------------------------------- 1 | # Retypd - machine code type inference 2 | # Copyright (C) 2021 GrammaTech, Inc. 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | # 17 | # This project is sponsored by the Office of Naval Research, One Liberty 18 | # Center, 875 N. Randolph Street, Arlington, VA 22203 under contract # 19 | # N68335-17-C-0700. The content of the information does not necessarily 20 | # reflect the position or policy of the Government and no official 21 | # endorsement should be inferred. 22 | 23 | """Parsing helpers, mostly for unit testing. 24 | """ 25 | 26 | import re 27 | from .schema import ( 28 | AccessPathLabel, 29 | DerefLabel, 30 | DerivedTypeVariable, 31 | InLabel, 32 | LoadLabel, 33 | OutLabel, 34 | StoreLabel, 35 | SubtypeConstraint, 36 | Variance, 37 | ConstraintSet, 38 | ) 39 | from .graph import EdgeLabel, Node 40 | from typing import Dict, Tuple, List 41 | 42 | 43 | class SchemaParser: 44 | """Static helper functions for Schema tests. Since this parsing code is unlikely to be useful in 45 | the code itself, it is included here. 46 | """ 47 | 48 | subtype_pattern = re.compile(r"(\S*) (?:⊑|<=) (\S*)") 49 | in_pattern = re.compile("in_([0-9]+)") 50 | deref_pattern = re.compile( 51 | "σ([0-9]+)@(-?[0-9]+)(\*\[(([0-9]+)|nullterm|nobound)\])?" 52 | ) 53 | node_pattern = re.compile(r"(\S+)\.([⊕⊖])") 54 | edge_pattern = re.compile( 55 | r"(\S+)\s+(?:→|->)\s+(\S+)(\s+\((forget|recall) (\S*)\))?" 56 | ) 57 | whitespace_pattern = re.compile(r"\s") 58 | 59 | @staticmethod 60 | def parse_label(label: str) -> AccessPathLabel: 61 | """Parse an AccessPathLabel. 
Raises ValueError if it is improperly formatted.""" 62 | if label == "load": 63 | return LoadLabel.instance() 64 | if label == "store": 65 | return StoreLabel.instance() 66 | if label == "out": 67 | return OutLabel.instance() 68 | in_match = SchemaParser.in_pattern.match(label) 69 | if in_match: 70 | return InLabel(int(in_match.group(1))) 71 | deref_match = SchemaParser.deref_pattern.match(label) 72 | if deref_match: 73 | count = deref_match.group(4) 74 | if count == "nullterm": 75 | count = DerefLabel.COUNT_NULLTERM 76 | elif count == "nobound": 77 | count = DerefLabel.COUNT_NOBOUND 78 | elif count is not None: 79 | count = int(count) 80 | else: 81 | count = 1 82 | return DerefLabel( 83 | int(deref_match.group(1)), int(deref_match.group(2)), count 84 | ) 85 | raise ValueError 86 | 87 | @staticmethod 88 | def parse_variable(var: str) -> DerivedTypeVariable: 89 | """Parse a DerivedTypeVariable. Raises ValueError if the string contains whitespace.""" 90 | if SchemaParser.whitespace_pattern.match(var): 91 | raise ValueError 92 | components = var.split(".") 93 | path = [SchemaParser.parse_label(label) for label in components[1:]] 94 | return DerivedTypeVariable(components[0], path) 95 | 96 | @staticmethod 97 | def parse_variables(vars: List[str]) -> List[DerivedTypeVariable]: 98 | """ 99 | Parse a list of DerivedTypeVariable. Raises ValueError 100 | if any of the strings contains a whitespace. 101 | """ 102 | return [SchemaParser.parse_variable(var) for var in vars] 103 | 104 | @staticmethod 105 | def parse_constraint(constraint: str) -> SubtypeConstraint: 106 | """Parse a SubtypeConstraint. Raises a ValueError if constraint does not match 107 | SchemaParser.subtype_pattern. 108 | """ 109 | subtype_match = SchemaParser.subtype_pattern.match(constraint) 110 | if subtype_match: 111 | return SubtypeConstraint( 112 | SchemaParser.parse_variable(subtype_match.group(1)), 113 | SchemaParser.parse_variable(subtype_match.group(2)), 114 | ) 115 | raise ValueError 116 | 117 | @staticmethod 118 | def parse_constraint_set(constraints: List[str]) -> ConstraintSet: 119 | """ 120 | Parse a list of constraints into a ConstraintSet 121 | """ 122 | cs = ConstraintSet() 123 | for c in constraints: 124 | cs.add(SchemaParser.parse_constraint(c)) 125 | return cs 126 | 127 | @staticmethod 128 | def parse_node(node: str) -> Node: 129 | """Parse a Node. Raise a ValueError if it does not match SchemaParser.node_pattern.""" 130 | node_match = SchemaParser.node_pattern.match(node) 131 | if node_match: 132 | var = SchemaParser.parse_variable(node_match.group(1)) 133 | if node_match.group(2) == "⊕": 134 | variance = Variance.COVARIANT 135 | elif node_match.group(2) == "⊖": 136 | variance = Variance.CONTRAVARIANT 137 | else: 138 | raise ValueError 139 | return Node(var, variance) 140 | raise ValueError 141 | 142 | @staticmethod 143 | def parse_edge(edge: str) -> Tuple[Node, Node, Dict[str, EdgeLabel]]: 144 | """Parse an edge in the graph, which consists of two nodes and an arrow, with an optional 145 | edge label. 
146 | """ 147 | edge_match = SchemaParser.edge_pattern.match(edge) 148 | if edge_match: 149 | sub = SchemaParser.parse_node(edge_match.group(1)) 150 | sup = SchemaParser.parse_node(edge_match.group(2)) 151 | atts = {} 152 | if edge_match.group(3): 153 | capability = SchemaParser.parse_label(edge_match.group(5)) 154 | if edge_match.group(4) == "forget": 155 | kind = EdgeLabel.Kind.FORGET 156 | elif edge_match.group(4) == "recall": 157 | kind = EdgeLabel.Kind.RECALL 158 | else: 159 | raise ValueError 160 | atts["label"] = EdgeLabel(capability, kind) 161 | return (sub, sup, atts) 162 | raise ValueError 163 | -------------------------------------------------------------------------------- /src/pathexpr.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from collections import defaultdict 3 | from typing import Any, Dict, List, Tuple 4 | import networkx 5 | 6 | 7 | # Randomly generated hash constant for hash combination 8 | HC_K = 0x9DDFEA08EB382D69 9 | # 64-bit integer wrap around 10 | HC_M = (2**64) - 1 11 | 12 | 13 | def hash_combine(lhs: int, rhs: int) -> int: 14 | """Combine two hashes using a simple algorithm from the original Retypd implementation""" 15 | a = rhs 16 | 17 | a = ((a ^ lhs) * HC_K) & HC_M 18 | a ^= a >> 47 19 | 20 | a = ((a ^ lhs) * HC_K) & HC_M 21 | a ^= a >> 47 22 | 23 | return a 24 | 25 | 26 | class RExp: 27 | """Regular expression class with some helper methods and simplification""" 28 | 29 | class Label: 30 | """Hard-coded label numbers, which each were randomly generated so as 31 | to not need to be hashed, instead their value operates as their hash 32 | """ 33 | 34 | NULL = 0xFA1118ECFC2E5C78 35 | EMPTY = 0xEA759E93D6E9FF37 36 | NODE = 0x091F738A11133080 37 | DOT = 0xAF201303AC824465 38 | OR = 0xBDF5EA44D36816CB 39 | STAR = 0xC2B42C47A64FC2AD 40 | 41 | def __init__(self, label: int, data=None, children=[]): 42 | self.label = label 43 | self.data: Any = data 44 | self.children: Tuple[RExp] = tuple(children) 45 | self.hash = hash_combine(self.label, hash(self.data)) 46 | 47 | for child in self.children: 48 | self.hash = hash_combine(self.hash, child.hash) 49 | 50 | def __hash__(self) -> int: 51 | return self.hash 52 | 53 | def __and__(self, rhs: RExp) -> RExp: 54 | return RExp(self.Label.DOT, children=(self, rhs)) 55 | 56 | def __or__(self, rhs: RExp) -> RExp: 57 | return RExp(self.Label.OR, children=(self, rhs)) 58 | 59 | def star(self): 60 | return RExp(self.Label.STAR, children=(self,)) 61 | 62 | def __eq__(self, other: RExp) -> bool: 63 | if self.hash != other.hash: 64 | return False 65 | if self.label != other.label: 66 | return False 67 | if self.label in (self.Label.NULL, self.Label.EMPTY): 68 | return True 69 | elif self.label in (self.Label.DOT, self.Label.OR, self.Label.STAR): 70 | return all(a == b for a, b in zip(self.children, other.children)) 71 | elif self.label == self.Label.NODE: 72 | return self.data == other.data 73 | else: 74 | raise NotImplementedError() 75 | 76 | def __lt__(self, other: RExp) -> bool: 77 | """ 78 | Compare two regular expressions. 
79 | """ 80 | if not isinstance(other, RExp): 81 | raise ValueError(f"Cannot compare RExp to {type(other)}") 82 | if self is other: 83 | return False 84 | if self.hash != other.hash: 85 | return self.hash < other.hash 86 | if self.label != other.label: 87 | return self.label < other.label 88 | # Same label 89 | if self.label in (self.Label.NULL, self.Label.EMPTY): 90 | return False 91 | elif self.label in (self.Label.DOT, self.Label.OR, self.Label.STAR): 92 | if len(self.children) != len(other.children): 93 | return len(self.children) < len(other.children) 94 | else: 95 | for a, b in zip(self.children, other.children): 96 | if a < b: 97 | return True 98 | elif a > b: 99 | return False 100 | # all equal 101 | return False 102 | elif self.label == self.Label.NODE: 103 | return self.data < other.data 104 | 105 | raise NotImplementedError() 106 | 107 | @classmethod 108 | def null(cls) -> RExp: 109 | return RExp(cls.Label.NULL) 110 | 111 | @classmethod 112 | def empty(cls) -> RExp: 113 | return RExp(cls.Label.EMPTY) 114 | 115 | @classmethod 116 | def node(cls, data) -> RExp: 117 | return RExp(cls.Label.NODE, data=data) 118 | 119 | @classmethod 120 | def from_graph_edge( 121 | cls, graph: networkx.DiGraph, src: Any, dest: Any, data: str 122 | ) -> RExp: 123 | """Generate a regular expression from a graph node. For no label on 124 | data we assume an empty string, otherwise a node labeled by that label 125 | """ 126 | attrs = graph[src][dest] 127 | if data not in attrs: 128 | return cls.empty() 129 | else: 130 | return cls.node(attrs[data]) 131 | 132 | def simplify(self) -> RExp: 133 | """Regular expression simplification procedure, page 9 134 | 135 | This simplification procedure has been expanded to 136 | deal with many other sources of redundancy. 137 | The simplification is not recursive, it only simplifies 138 | the two top levels of the RExp, deeper levels have 139 | been already simplified. 
140 | 141 | """ 142 | children = self.children 143 | 144 | if self.label == self.Label.OR: 145 | # use a set to avoid duplicates 146 | new_children = set() 147 | for child in children: 148 | # flatten nested OR 149 | if child.label == RExp.Label.OR: 150 | new_children |= set(child.children) 151 | else: 152 | if not child.is_null: 153 | new_children.add(child) 154 | if len(new_children) == 0: 155 | return RExp.null() 156 | elif len(new_children) == 1: 157 | return new_children.pop() 158 | else: 159 | return RExp(RExp.Label.OR, children=sorted(new_children)) 160 | elif self.label == self.Label.DOT: 161 | if any(child.is_null for child in children): 162 | return RExp.null() 163 | new_children = [] 164 | for child in children: 165 | # flatten nested DOT 166 | if child.label == RExp.Label.DOT: 167 | new_children.extend(child.children) 168 | else: 169 | if not child.is_empty: 170 | new_children.append(child) 171 | if len(new_children) == 0: 172 | return RExp.empty() 173 | elif len(new_children) == 1: 174 | return new_children.pop() 175 | return RExp(RExp.Label.DOT, children=new_children) 176 | elif self.label == self.Label.STAR: 177 | if children[0].is_null or children[0].is_empty: 178 | return RExp.empty() 179 | return self 180 | 181 | @property 182 | def is_null(self) -> bool: 183 | return self.label == self.Label.NULL 184 | 185 | @property 186 | def is_empty(self) -> bool: 187 | return self.label == self.Label.EMPTY 188 | 189 | @property 190 | def is_node(self) -> bool: 191 | return self.label == self.Label.NODE 192 | 193 | def __repr__(self) -> str: 194 | if self.label == self.Label.OR: 195 | return "(" + " U ".join(map(repr, self.children)) + ")" 196 | 197 | elif self.label == self.Label.DOT: 198 | return "(" + " . ".join(map(repr, self.children)) + ")" 199 | elif self.label == self.Label.STAR: 200 | return f"{self.children[0]}*" 201 | elif self.label == self.Label.EMPTY: 202 | return "Λ" 203 | elif self.label == self.Label.NULL: 204 | return "∅" 205 | elif self.label == self.Label.NODE: 206 | return f"{self.data}" 207 | else: 208 | raise NotImplementedError() 209 | 210 | 211 | def eliminate( 212 | graph: networkx.DiGraph, data: str, min_num: int, max_num: int 213 | ) -> Dict[Tuple[int, int], RExp]: 214 | """ELIMINATE procedure of Tarjan's path expression algorithm, page 13""" 215 | # Initialize 216 | P = defaultdict(lambda: RExp.null()) 217 | 218 | # consider edges in the subgraph defined by the range 219 | for h, t in graph.edges(range(min_num, max_num)): 220 | if t < min_num or t >= max_num: 221 | continue 222 | edge = RExp.from_graph_edge(graph, h, t, data) 223 | P[h, t] = (P[h, t] | edge).simplify() 224 | 225 | # Loop 226 | for v in range(min_num, max_num): 227 | P[v, v] = P[v, v].star().simplify() 228 | 229 | for u in range(v + 1, max_num): 230 | if P[u, v].is_null: 231 | continue 232 | 233 | P[u, v] = (P[u, v] & P[v, v]).simplify() 234 | 235 | for w in range(v + 1, max_num): 236 | if P[v, w].is_null: 237 | continue 238 | 239 | P[u, w] = (P[u, w] | (P[u, v] & P[v, w]).simplify()).simplify() 240 | 241 | return P 242 | 243 | 244 | PathSeq = List[Tuple[Tuple[int, int], RExp]] 245 | 246 | 247 | def compute_path_sequence( 248 | P: Dict[Tuple[int, int], RExp], 249 | min_num: int, 250 | max_num: int, 251 | ) -> PathSeq: 252 | """Compute path sequence from the ELIMINATE procedure, per Theorem 4 on 253 | page 14 254 | """ 255 | 256 | # Compute ascending and descending path sequences that are in the queried 257 | # range for this path sequence 258 | valid_range = range(min_num, max_num) 259 | 
ascending = [] 260 | descending = [] 261 | 262 | for indices, expr in P.items(): 263 | start, end = indices 264 | 265 | if start not in valid_range or end not in valid_range: 266 | continue 267 | 268 | if expr.is_null: 269 | continue 270 | # no need to include empty self-paths 271 | if expr.is_empty and start == end: 272 | continue 273 | 274 | if start <= end: 275 | ascending.append((indices, expr)) 276 | else: 277 | descending.append((indices, expr)) 278 | 279 | # Sort by the starting node 280 | output = sorted(ascending, key=lambda pair: pair[0][0]) + sorted( 281 | descending, key=lambda pair: pair[0][0], reverse=True 282 | ) 283 | 284 | return output 285 | 286 | 287 | def solve_paths_from( 288 | path_seq: PathSeq, source: int 289 | ) -> Dict[Tuple[int, int], RExp]: 290 | """Solve path expressions from a source given a path sequence for a 291 | numbered graph, per procedure SOLVE on page 9 of Tarjan 292 | """ 293 | P = defaultdict(lambda: RExp.null()) 294 | P[source, source] = RExp.empty() 295 | 296 | for (v_i, w_i), P_i in path_seq: 297 | if v_i == w_i: 298 | P[source, v_i] = (P[source, v_i] & P_i).simplify() 299 | else: 300 | P[source, w_i] = ( 301 | P[source, w_i] | (P[source, v_i] & P_i).simplify() 302 | ).simplify() 303 | 304 | return P 305 | 306 | 307 | def from_numeral_graph(numeral_graph: networkx.DiGraph, number: int) -> Any: 308 | """Translate from numeral graph to the original node in the graph""" 309 | return numeral_graph.nodes[number]["original"] 310 | 311 | 312 | GraphNumbering = Dict[Any, int] 313 | 314 | 315 | def topological_numbering( 316 | graph: networkx.DiGraph, 317 | ) -> Tuple[GraphNumbering, networkx.DiGraph]: 318 | """Generate a numeral graph from the topological sort of the DAG""" 319 | nodes = list(networkx.topological_sort(graph)) 320 | numbering = {node: num for num, node in enumerate(nodes)} 321 | rev_numbering = dict(enumerate(nodes)) 322 | numeral_graph = networkx.relabel_nodes(graph, numbering) 323 | networkx.set_node_attributes(numeral_graph, rev_numbering, name="original") 324 | return numbering, numeral_graph 325 | 326 | 327 | def dag_path_seq(graph: networkx.DiGraph, data: str) -> PathSeq: 328 | """Per Theorem 5, Page 14 of the paper, generate path sequences for a 329 | directed acyclic graph in a more efficient manner. 330 | """ 331 | # Sort edges by increasing source node 332 | edges = sorted(graph.edges(), key=lambda x: x[0]) 333 | return [ 334 | ((h, t), RExp.from_graph_edge(graph, h, t, data)) for h, t in edges 335 | ] 336 | 337 | 338 | def scc_decompose_path_seq( 339 | graph: networkx.DiGraph, data: str 340 | ) -> Tuple[GraphNumbering, PathSeq]: 341 | """Per Theorem 6, Page 14 of the paper, generate path sequences for a graph 342 | that has been decomposed into strongly connected components. 
343 | """ 344 | # Generate the graph of SCCs 345 | component_graph = networkx.condensation(graph) 346 | 347 | scc_numberings = {} 348 | graph_numbering = {} 349 | curr_number = 0 350 | 351 | # Generate a numbering for each SCC in increasing topological order to 352 | # maintain that any edge G_i -> G_j whare are in SCCs S_i and S_j 353 | # respectively, if theres an edge in the condensation from S_i -> S_j then 354 | # G_i < G_j 355 | for component in networkx.topological_sort(component_graph): 356 | # Generate numbering for this SCC, we do so in a sorted order as 357 | # NetworkX returns a set whose non-deterministic ordering can return 358 | # inconsistent results 359 | scc = component_graph.nodes[component]["members"] 360 | start_number = curr_number 361 | 362 | for elem in sorted(scc): 363 | graph_numbering[elem] = curr_number 364 | curr_number += 1 365 | 366 | # Update whole-graph numbering, and keep note of SCC ranges 367 | scc_numberings[component] = (start_number, curr_number) 368 | 369 | # Do the actual relabeling of the graph to the numeral graph 370 | number_graph = networkx.relabel_nodes(graph, graph_numbering) 371 | rev_numbering = {v: k for k, v in graph_numbering.items()} 372 | networkx.set_node_attributes(number_graph, rev_numbering, name="original") 373 | 374 | scc_seqs: List[Tuple[int, PathSeq]] = [] 375 | 376 | # Do ELIMINATE for every SCC, by using a slice of the graph from the 377 | # min/max of the given SCC's numbering 378 | for component in networkx.topological_sort(component_graph): 379 | min_num, max_num = scc_numberings[component] 380 | P = eliminate(number_graph, data, min_num, max_num) 381 | seqs = compute_path_sequence(P, min_num, max_num) 382 | scc_seqs.append((component, seqs)) 383 | 384 | output: PathSeq = [] 385 | 386 | for component, seqs in scc_seqs: 387 | # Add the intra-SCC path sequence nodes 388 | output += seqs 389 | 390 | # Add inter-SCC path sequence nodes 391 | scc = component_graph.nodes[component]["members"] 392 | 393 | for in_edge, out_edge in graph.out_edges(scc): 394 | if out_edge not in scc: 395 | in_num = graph_numbering[in_edge] 396 | out_num = graph_numbering[out_edge] 397 | output.append( 398 | ( 399 | (in_num, out_num), 400 | RExp.from_graph_edge(graph, in_edge, out_edge, data), 401 | ) 402 | ) 403 | 404 | return graph_numbering, output 405 | 406 | 407 | def path_expression_between( 408 | graph: networkx.DiGraph, 409 | data: str, 410 | source: Any, 411 | sink: Any, 412 | decompose=True, 413 | ): 414 | """Per Lemma 1 on page 9, handle output of SOLVE procedure""" 415 | # First, compute path sequences of the graph 416 | if not decompose: 417 | # Generate numberings for the nodes 418 | number_graph = networkx.convert_node_labels_to_integers( 419 | graph, label_attribute="original" 420 | ) 421 | 422 | numbering = { 423 | original: number 424 | for number, original in number_graph.nodes(data="original") 425 | } 426 | 427 | N = len(number_graph.nodes) 428 | P = eliminate(number_graph, data, 0, N) 429 | seqs = compute_path_sequence(P, 0, N) 430 | elif networkx.is_directed_acyclic_graph(graph): 431 | # Fast path for DAGs 432 | numbering, number_graph = topological_numbering(graph) 433 | seqs = dag_path_seq(number_graph, data) 434 | else: 435 | numbering, seqs = scc_decompose_path_seq(graph, data) 436 | 437 | # Solve all paths for source, and output the one for (source, sink) 438 | paths = solve_paths_from(seqs, numbering[source]) 439 | return paths[(numbering[source], numbering[sink])] 440 | 
-------------------------------------------------------------------------------- /src/schema.py: -------------------------------------------------------------------------------- 1 | # Retypd - machine code type inference 2 | # Copyright (C) 2021 GrammaTech, Inc. 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | # 17 | # This project is sponsored by the Office of Naval Research, One Liberty 18 | # Center, 875 N. Randolph Street, Arlington, VA 22203 under contract # 19 | # N68335-17-C-0700. The content of the information does not necessarily 20 | # reflect the position or policy of the Government and no official 21 | # endorsement should be inferred. 22 | 23 | """Data types for an implementation of retypd analysis. 24 | """ 25 | from __future__ import annotations 26 | from abc import ABC 27 | from enum import Enum, unique 28 | from functools import reduce 29 | from typing import ( 30 | Any, 31 | Dict, 32 | FrozenSet, 33 | Generic, 34 | Iterable, 35 | Iterator, 36 | List, 37 | Optional, 38 | Sequence, 39 | Set, 40 | Tuple, 41 | TypeVar, 42 | Union, 43 | ) 44 | import os 45 | import networkx 46 | 47 | 48 | @unique 49 | class Variance(Enum): 50 | """Represents a capability's variance (or that of some sequence of capabilities).""" 51 | 52 | CONTRAVARIANT = 0 53 | COVARIANT = 1 54 | 55 | @staticmethod 56 | def invert(variance: Variance) -> Variance: 57 | if variance == Variance.CONTRAVARIANT: 58 | return Variance.COVARIANT 59 | return Variance.CONTRAVARIANT 60 | 61 | @staticmethod 62 | def combine(lhs: Variance, rhs: Variance) -> Variance: 63 | if lhs == rhs: 64 | return Variance.COVARIANT 65 | return Variance.CONTRAVARIANT 66 | 67 | 68 | class AccessPathLabel(ABC): 69 | """Abstract class for capabilities that can be part of a path. See Table 1. 70 | 71 | All :py:class:`AccessPathLabel` objects are comparable to each other; objects are ordered by 72 | their classes (in an arbitrary order defined by the string representation of their type), then 73 | by values specific to their subclass. So objects of class A always precede objects of class B 74 | and objects of class A are ordered with respect to each other by :py:method:`_less_than`. 75 | """ 76 | 77 | def __lt__(self, other: AccessPathLabel) -> bool: 78 | if not isinstance(other, AccessPathLabel): 79 | raise TypeError(f"Cannot compare {self} and {other}") 80 | s_type = str(type(self)) 81 | o_type = str(type(other)) 82 | if s_type == o_type: 83 | return self._less_than(other) 84 | return s_type < o_type 85 | 86 | def _less_than(self, _other) -> bool: 87 | """Compare two objects of the same exact type. Return True if self is less than other; true 88 | otherwise. Several of the subclasses are singletons, so we return False unless there is a 89 | need for an overriding implementation. 
90 | """ 91 | return False 92 | 93 | def variance(self) -> Variance: 94 | """Determines if the access path label is covariant or contravariant, per Table 1.""" 95 | return Variance.COVARIANT 96 | 97 | 98 | class LoadLabel(AccessPathLabel): 99 | """A singleton representing the load (read) capability.""" 100 | 101 | _instance = None 102 | 103 | def __init__(self) -> None: 104 | raise ValueError("Can't instantiate; call instance() instead") 105 | 106 | @classmethod 107 | def instance(cls): 108 | if cls._instance is None: 109 | cls._instance = cls.__new__(cls) 110 | return cls._instance 111 | 112 | def __eq__(self, other: Any) -> bool: 113 | return self is other 114 | 115 | def __hash__(self) -> int: 116 | return 0 117 | 118 | def __str__(self) -> str: 119 | return "load" 120 | 121 | 122 | class StoreLabel(AccessPathLabel): 123 | """A singleton representing the store (write) capability.""" 124 | 125 | _instance = None 126 | 127 | def __init__(self) -> None: 128 | raise ValueError("Can't instantiate; call instance() instead") 129 | 130 | @classmethod 131 | def instance(cls): 132 | if cls._instance is None: 133 | cls._instance = cls.__new__(cls) 134 | return cls._instance 135 | 136 | def __eq__(self, other: Any) -> bool: 137 | return self is other 138 | 139 | def __hash__(self) -> int: 140 | return 1 141 | 142 | def variance(self) -> Variance: 143 | return Variance.CONTRAVARIANT 144 | 145 | def __str__(self) -> str: 146 | return "store" 147 | 148 | 149 | class InLabel(AccessPathLabel): 150 | """Represents a parameter to a function, specified by an index (e.g., the first argument might 151 | use index 0, the second might use index 1, and so on). N.B.: this is a capability and is not 152 | tied to any particular function. 153 | """ 154 | 155 | def __init__(self, index: int) -> None: 156 | self.index = index 157 | 158 | def __eq__(self, other: Any) -> bool: 159 | return isinstance(other, InLabel) and self.index == other.index 160 | 161 | def _less_than(self, other: InLabel) -> bool: 162 | return self.index < other.index 163 | 164 | def __hash__(self) -> int: 165 | return hash(self.index) 166 | 167 | def variance(self) -> Variance: 168 | return Variance.CONTRAVARIANT 169 | 170 | def __str__(self) -> str: 171 | return f"in_{self.index}" 172 | 173 | 174 | class OutLabel(AccessPathLabel): 175 | """Represents a return from a function. This class is a singleton.""" 176 | 177 | _instance = None 178 | 179 | def __init__(self) -> None: 180 | raise ValueError("Can't instantiate; call instance() instead") 181 | 182 | @classmethod 183 | def instance(cls): 184 | if cls._instance is None: 185 | cls._instance = cls.__new__(cls) 186 | return cls._instance 187 | 188 | def __eq__(self, other: Any) -> bool: 189 | return self is other 190 | 191 | def __hash__(self) -> int: 192 | return 2 193 | 194 | def __str__(self) -> str: 195 | return "out" 196 | 197 | 198 | class DerefLabel(AccessPathLabel): 199 | """Represents a dereference in an access path. Specifies a size (the number of bytes read or 200 | written) and an offset (the number of bytes from the base) and an optional count (for array- 201 | like accesses that do size*count accesses in a loop). 
202 |     """
203 | 
204 |     # An unknown number of elements
205 |     COUNT_NOBOUND = -1
206 |     # A null-terminated string
207 |     COUNT_NULLTERM = -2
208 | 
209 |     def __init__(self, size: int, offset: int, count: int = 1) -> None:
210 |         self.size = size
211 |         self.offset = offset
212 |         self.count = count
213 | 
214 |     def __eq__(self, other: Any) -> bool:
215 |         return (
216 |             isinstance(other, DerefLabel)
217 |             and self.size == other.size
218 |             and self.offset == other.offset
219 |             and self.count == other.count
220 |         )
221 | 
222 |     def _less_than(self, other: DerefLabel) -> bool:
223 |         return (self.offset, self.size, self.count) < (
224 |             other.offset,
225 |             other.size,
226 |             other.count,
227 |         )
228 | 
229 |     def __hash__(self) -> int:
230 |         return hash((self.offset, self.size, self.count))
231 | 
232 |     def __str__(self) -> str:
233 |         srep = f"σ{self.size}@{self.offset}"
234 |         if self.count > 1:
235 |             srep += f"*[{self.count}]"
236 |         elif self.count == self.COUNT_NOBOUND:
237 |             srep += "*[nobound]"
238 |         elif self.count == self.COUNT_NULLTERM:
239 |             srep += "*[nullterm]"
240 |         return srep
241 | 
242 | 
243 | class DerivedTypeVariable:
244 |     """A _derived_ type variable, per Definition 3.1. Immutable (by convention)."""
245 | 
246 |     def __init__(
247 |         self, type_var: str, path: Optional[Sequence[AccessPathLabel]] = None
248 |     ) -> None:
249 |         self._base = type_var
250 |         if path is None:
251 |             self._path: Sequence[AccessPathLabel] = ()
252 |         else:
253 |             self._path = tuple(path)
254 |         # Precomputing the hash is a big performance boost (since we are immutable)
255 |         self._hash = hash((self._base, self._path))
256 | 
257 |     @property
258 |     def base(self):
259 |         return self._base
260 | 
261 |     @property
262 |     def path(self):
263 |         return self._path
264 | 
265 |     # We weakly "enforce" immutability
266 |     @base.setter
267 |     def base(self, value):
268 |         raise NotImplementedError("Read-only property")
269 | 
270 |     @path.setter
271 |     def path(self, value):
272 |         raise NotImplementedError("Read-only property")
273 | 
274 |     def format(self, separator: str = ".") -> str:
275 |         return separator.join([str(self._base)] + list(map(str, self._path)))
276 | 
277 |     def __eq__(self, other: Any) -> bool:
278 |         return (
279 |             isinstance(other, DerivedTypeVariable)
280 |             and self._base == other.base
281 |             and self._path == other.path
282 |         )
283 | 
284 |     def __lt__(self, other: DerivedTypeVariable) -> bool:
285 |         if self._base == other.base:
286 |             if len(self._path) != len(other.path):
287 |                 return len(self._path) < len(other.path)
288 |             return list(self._path) < list(other.path)
289 |         return self._base < other.base
290 | 
291 |     def __hash__(self) -> int:
292 |         return self._hash
293 | 
294 |     @property
295 |     def largest_prefix(self) -> Optional[DerivedTypeVariable]:
296 |         """Return the prefix obtained by removing the last item from the type variable's path. If
297 |         there is no path, return None.
298 |         """
299 |         if self._path:
300 |             return DerivedTypeVariable(self._base, self._path[:-1])
301 |         return None
302 | 
303 |     def all_prefixes(self) -> Set[DerivedTypeVariable]:
304 |         """Return all prefixes of self, including self."""
305 |         var = self
306 |         result: Set[DerivedTypeVariable] = set()
307 |         while var:
308 |             result.add(var)
309 |             var = var.largest_prefix
310 |         return result
311 | 
312 |     def get_suffix(
313 |         self, other: DerivedTypeVariable
314 |     ) -> Optional[Sequence[AccessPathLabel]]:
315 |         """If self is a prefix of other, return the suffix of other's path that is not part of self.
316 |         Otherwise, return None.
317 |         """
318 |         if self._base != other.base:
319 |             return None
320 |         if len(self._path) > len(other.path):
321 |             return None
322 |         for s_item, o_item in zip(self._path, other.path):
323 |             if s_item != o_item:
324 |                 return None
325 |         return other.path[len(self._path) :]
326 | 
327 |     @property
328 |     def tail(self) -> Optional[AccessPathLabel]:
329 |         """Retrieve the last item in the access path, if any. Return None if
330 |         the path is empty.
331 |         """
332 |         if self._path:
333 |             return self._path[-1]
334 |         return None
335 | 
336 |     def add_suffix(self, suffix: AccessPathLabel) -> DerivedTypeVariable:
337 |         """Create a new :py:class:`DerivedTypeVariable` identical to `self` (which is
338 |         unchanged) but with suffix appended to its path.
339 |         """
340 |         path: List[AccessPathLabel] = list(self._path)
341 |         path.append(suffix)
342 |         return DerivedTypeVariable(self._base, path)
343 | 
344 |     def extend(self, suffix: Iterable[AccessPathLabel]) -> DerivedTypeVariable:
345 |         path: List[AccessPathLabel] = list(self._path)
346 |         path.extend(suffix)
347 |         return DerivedTypeVariable(self._base, path)
348 | 
349 |     @property
350 |     def base_var(self) -> DerivedTypeVariable:
351 |         return DerivedTypeVariable(self._base)
352 | 
353 |     @property
354 |     def path_variance(self) -> Variance:
355 |         """Determine the variance of the access path."""
356 |         variances = map(lambda label: label.variance(), self._path)
357 |         return reduce(Variance.combine, variances, Variance.COVARIANT)
358 | 
359 |     def __str__(self) -> str:
360 |         return self.format()
361 | 
362 |     def __repr__(self) -> str:
363 |         return self.format("$")
364 | 
365 | 
366 | class SubtypeConstraint:
367 |     """A type constraint of the form left ⊑ right (see Definition 3.3)"""
368 | 
369 |     def __init__(
370 |         self, left: DerivedTypeVariable, right: DerivedTypeVariable
371 |     ) -> None:
372 |         self.left = left
373 |         self.right = right
374 | 
375 |     def __eq__(self, other: Any) -> bool:
376 |         return (
377 |             isinstance(other, SubtypeConstraint)
378 |             and self.left == other.left
379 |             and self.right == other.right
380 |         )
381 | 
382 |     def __lt__(self, other: SubtypeConstraint) -> bool:
383 |         if self.left == other.left:
384 |             return self.right < other.right
385 |         return self.left < other.left
386 | 
387 |     def __hash__(self) -> int:
388 |         return hash(self.left) ^ hash(self.right)
389 | 
390 |     def __str__(self) -> str:
391 |         return f"{self.left} ⊑ {self.right}"
392 | 
393 |     def __repr__(self) -> str:
394 |         return str(self)
395 | 
396 | 
397 | class ConstraintSet:
398 |     """A (partitioned) set of type constraints"""
399 | 
400 |     def __init__(
401 |         self, subtype: Optional[Iterable[SubtypeConstraint]] = None
402 |     ) -> None:
403 |         if subtype:
404 |             self.subtype = set(subtype)
405 |         else:
406 |             self.subtype = set()
407 | 
408 |     def add(self, constraint: SubtypeConstraint) -> bool:
409 |         if constraint in self.subtype:
410 |             return False
411 |         self.subtype.add(constraint)
412 |         return True
413 | 
414 |     def __eq__(self, other: Any) -> bool:
415 |         return (
416 |             isinstance(other, ConstraintSet) and self.subtype == other.subtype
417 |         )
418 | 
419 |     def all_dtvs(self) -> Set[DerivedTypeVariable]:
420 |         """
421 |         Return all the derived type variables in a
422 |         constraint set.
423 |         """
424 |         dtvs = set()
425 |         for c in self:
426 |             dtvs.add(c.left)
427 |             dtvs.add(c.right)
428 |         return dtvs
429 | 
430 |     def all_tvs(self) -> Set[DerivedTypeVariable]:
431 |         """
432 |         Return all type variables in a constraint set.
433 |         """
434 |         tvs = set()
435 |         for c in self:
436 |             tvs.add(c.left.base_var)
437 |             tvs.add(c.right.base_var)
438 |         return tvs
439 | 
440 |     def __or__(self, other: ConstraintSet) -> ConstraintSet:
441 |         return ConstraintSet(self.subtype | other.subtype)
442 | 
443 |     def __str__(self) -> str:
444 |         nt = os.linesep + "\t"
445 |         return f"ConstraintSet:{nt}{nt.join(map(str,self.subtype))}"
446 | 
447 |     def __repr__(self) -> str:
448 |         return f"ConstraintSet({repr(self.subtype)})"
449 | 
450 |     def __iter__(self) -> Iterator[SubtypeConstraint]:
451 |         return iter(self.subtype)
452 | 
453 |     def __len__(self) -> int:
454 |         return len(self.subtype)
455 | 
456 |     def apply_mapping(
457 |         self, var_mapping: Dict[DerivedTypeVariable, DerivedTypeVariable]
458 |     ) -> ConstraintSet:
459 |         """
460 |         Return an equivalent constraint set in which DTVs have been substituted
461 |         based on the provided `var_mapping`.
462 |         """
463 | 
464 |         def apply_mapping_to_dtv(
465 |             dtv: DerivedTypeVariable,
466 |         ) -> DerivedTypeVariable:
467 |             suffix = None
468 |             for type_var in var_mapping:
469 |                 suffix = type_var.get_suffix(dtv)
470 |                 if suffix is not None:
471 |                     base = var_mapping[type_var]
472 |                     break
473 |             return base.extend(suffix) if suffix is not None else dtv
474 | 
475 |         mapped_cs = ConstraintSet()
476 |         for cs in self:
477 |             new_left = apply_mapping_to_dtv(cs.left)
478 |             new_right = apply_mapping_to_dtv(cs.right)
479 |             mapped_cs.add(SubtypeConstraint(new_left, new_right))
480 |         return mapped_cs
481 | 
482 | 
483 | T = TypeVar("T")
484 | 
485 | 
486 | class Lattice(ABC, Generic[T]):
487 |     @property
488 |     def atomic_types(self) -> FrozenSet[T]:
489 |         pass
490 | 
491 |     @property
492 |     def internal_types(self) -> FrozenSet[T]:
493 |         pass
494 | 
495 |     @property
496 |     def top(self) -> T:
497 |         pass
498 | 
499 |     @property
500 |     def bottom(self) -> T:
501 |         pass
502 | 
503 |     def meet(self, t: T, v: T) -> T:
504 |         pass
505 | 
506 |     def join(self, t: T, v: T) -> T:
507 |         pass
508 | 
509 | 
510 | class LatticeCTypes:
511 |     """
512 |     Class for converting a Lattice type to a CType.
513 |     """
514 | 
515 |     def atom_to_ctype(self, atom_lower: Any, atom_upper: Any, byte_size: int):
516 |         raise NotImplementedError("Child class must implement this")
517 | 
518 | 
519 | MaybeVar = Union[DerivedTypeVariable, str]
520 | 
521 | 
522 | def maybe_to_var(mv: MaybeVar) -> DerivedTypeVariable:
523 |     if isinstance(mv, str):
524 |         return DerivedTypeVariable(mv)
525 |     return mv
526 | 
527 | 
528 | Key = TypeVar("Key")
529 | Value = TypeVar("Value")
530 | MaybeDict = Union[Dict[Key, Value], Iterable[Tuple[Key, Value]]]
531 | 
532 | 
533 | def maybe_to_bindings(
534 |     md: MaybeDict[Key, Value]
535 | ) -> Iterable[Tuple[Key, Value]]:
536 |     if isinstance(md, dict):
537 |         return md.items()
538 |     return md
539 | 
540 | 
541 | class Program:
542 |     """An entire binary. Contains a set of global variables, a mapping from procedures to sets of
543 |     constraints, and a call graph.
544 | """ 545 | 546 | def __init__( 547 | self, 548 | types: Lattice[DerivedTypeVariable], 549 | global_vars: Iterable[MaybeVar], 550 | proc_constraints: MaybeDict[MaybeVar, ConstraintSet], 551 | callgraph: Union[ 552 | MaybeDict[MaybeVar, Iterable[MaybeVar]], networkx.DiGraph 553 | ], 554 | ) -> None: 555 | self.types = types 556 | self.global_vars = {maybe_to_var(glob) for glob in global_vars} 557 | self.proc_constraints: Dict[DerivedTypeVariable, ConstraintSet] = {} 558 | if isinstance(callgraph, networkx.DiGraph): 559 | self.callgraph = callgraph 560 | else: # Dict or Iterable[Tuple] 561 | self.callgraph = networkx.DiGraph() 562 | for caller, callees in maybe_to_bindings(callgraph): 563 | caller_var = maybe_to_var(caller) 564 | self.callgraph.add_node(caller_var) 565 | for callee in callees: 566 | self.callgraph.add_edge(caller_var, maybe_to_var(callee)) 567 | for name, constraints in maybe_to_bindings(proc_constraints): 568 | var = maybe_to_var(name) 569 | if var in self.proc_constraints: 570 | raise ValueError(f"Procedure doubly bound: {name}") 571 | self.proc_constraints[var] = Program.specialize_locals( 572 | var, 573 | constraints, 574 | self.procs_and_globals | self.types.atomic_types, 575 | ) 576 | 577 | @staticmethod 578 | def specialize_locals( 579 | base: DerivedTypeVariable, 580 | constraints: ConstraintSet, 581 | procs_and_global_vars: Set[DerivedTypeVariable], 582 | ) -> ConstraintSet: 583 | """ 584 | Specialize temporary variables to a specific function 585 | """ 586 | 587 | def fix_dtv(dtv: DerivedTypeVariable) -> DerivedTypeVariable: 588 | if DerivedTypeVariable(dtv.base) not in procs_and_global_vars: 589 | return DerivedTypeVariable(f"{base}${dtv.base}", dtv.path) 590 | else: 591 | return dtv 592 | 593 | output_cs = ConstraintSet() 594 | 595 | for constraint in constraints: 596 | output_cs.add( 597 | SubtypeConstraint( 598 | fix_dtv(constraint.left), fix_dtv(constraint.right) 599 | ) 600 | ) 601 | 602 | return output_cs 603 | 604 | @property 605 | def procs_and_globals(self): 606 | """ 607 | The set of procedures and global variables in a program. 608 | """ 609 | return self.global_vars | self.procs 610 | 611 | @property 612 | def procs(self): 613 | """ 614 | The set of procedures in the program. 
615 |         """
616 |         return set(self.callgraph.nodes())
617 | 
618 | 
619 | class FreshVarFactory:
620 |     """
621 |     A class that produces DTVs with unique names
622 |     """
623 | 
624 |     FRESH_VAR_PREFIX = "τ$"
625 | 
626 |     def __init__(self) -> None:
627 |         self.fresh_var_counter = 0
628 | 
629 |     def fresh_var(self) -> DerivedTypeVariable:
630 |         fresh_var = DerivedTypeVariable(
631 |             f"{FreshVarFactory.FRESH_VAR_PREFIX}{self.fresh_var_counter}"
632 |         )
633 |         self.fresh_var_counter += 1
634 |         return fresh_var
635 | 
636 |     @staticmethod
637 |     def is_anonymous_variable(dtv: DerivedTypeVariable) -> bool:
638 |         return dtv.base.startswith(FreshVarFactory.FRESH_VAR_PREFIX)
639 | 
640 | 
641 | class RetypdError(Exception):
642 |     """
643 |     A retypd specific exception
644 |     """
645 | 
646 |     pass
647 | 
-------------------------------------------------------------------------------- /src/sketches.py: --------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from .schema import (
3 |     DerivedTypeVariable,
4 |     Lattice,
5 |     FreshVarFactory,
6 |     ConstraintSet,
7 |     SubtypeConstraint,
8 |     Variance,
9 |     RetypdError,
10 | )
11 | from .loggable import Loggable, LogLevel
12 | import os
13 | import networkx
14 | from typing import Set, Union, Optional, Tuple, Dict
15 | 
16 | 
17 | class SketchNode:
18 |     """
19 |     A node in a sketch graph. The node is associated with a DTV and it
20 |     captures its upper and lower bound in the type lattice.
21 |     A sketch node might be referenced by `LabelNode` in recursive types.
22 |     If that is the case, the sketch node can represent the primitive
23 |     type of infinitely many DTVs, e.g.:
24 |     f.in_0, f.in_0.load.σ4@0, f.in_0.load.σ4@0.load.σ4@0, ...
25 |     """
26 | 
27 |     def __init__(
28 |         self,
29 |         dtv: DerivedTypeVariable,
30 |         lower_bound: DerivedTypeVariable,
31 |         upper_bound: DerivedTypeVariable,
32 |     ) -> None:
33 |         self._dtv = dtv
34 |         self.lower_bound = lower_bound
35 |         self.upper_bound = upper_bound
36 |         self._hash = hash(self._dtv)
37 | 
38 |     @property
39 |     def dtv(self):
40 |         """
41 |         The main DTV represented by the sketch node.
42 |         """
43 |         return self._dtv
44 | 
45 |     @dtv.setter
46 |     def dtv(self, value):
47 |         raise NotImplementedError("Read-only property")
48 | 
49 |     # the atomic type of a DTV is an annotation, not part of its identity
50 |     def __eq__(self, other) -> bool:
51 |         if isinstance(other, SketchNode):
52 |             return self.dtv == other.dtv
53 |         return False
54 | 
55 |     def __hash__(self) -> int:
56 |         return self._hash
57 | 
58 |     def __str__(self) -> str:
59 |         return f"({self.lower_bound} <= {self.dtv} <= {self.upper_bound})"
60 | 
61 |     def __repr__(self) -> str:
62 |         return f"SketchNode({self})"
63 | 
64 | 
65 | class LabelNode:
66 |     """
67 |     LabelNodes are used to capture cycles in sketches
68 |     (recursive types). A LabelNode has a target that
69 |     is a DTV, which uniquely identifies the SketchNode
70 |     that it points to.
71 |     There can be multiple LabelNodes pointing to the
72 |     same sketch node, since a type can have multiple
73 |     recursive references (e.g., 'previous' and
74 |     'next' references in a doubly linked list).
75 | """ 76 | 77 | counter = 0 78 | 79 | def __init__(self, target: DerivedTypeVariable) -> None: 80 | self.target = target 81 | self.id = LabelNode.counter 82 | LabelNode.counter += 1 83 | self._hash = hash((self.target, self.id)) 84 | 85 | def __eq__(self, other) -> bool: 86 | if isinstance(other, LabelNode): 87 | return self.id == other.id and self.target == other.target 88 | return False 89 | 90 | def __hash__(self) -> int: 91 | return self._hash 92 | 93 | def __str__(self) -> str: 94 | return f"{self.target}.label_{self.id}" 95 | 96 | def __repr__(self) -> str: 97 | return str(self) 98 | 99 | 100 | SkNode = Union[SketchNode, LabelNode] 101 | 102 | 103 | class Sketch(Loggable): 104 | """ 105 | The sketch of a type variable. 106 | """ 107 | 108 | def __init__( 109 | self, 110 | root: DerivedTypeVariable, 111 | types: Lattice[DerivedTypeVariable], 112 | verbose: LogLevel = LogLevel.QUIET, 113 | ) -> None: 114 | super(Sketch, self).__init__(verbose) 115 | # We maintain the invariant that if a node is in `lookup` then it should also be in 116 | # `sketches` as a node (even if there are no edges) 117 | self.sketches = networkx.DiGraph() 118 | self._lookup: Dict[DerivedTypeVariable, SketchNode] = {} 119 | self.types = types 120 | self.root = self.make_node(root) 121 | 122 | def lookup(self, dtv: DerivedTypeVariable) -> Optional[SketchNode]: 123 | """ 124 | Return the sketch node corresponding to the path 125 | represented in the given dtv 126 | """ 127 | if dtv in self._lookup: 128 | return self._lookup[dtv] 129 | # if it is not in the dictionary we traverse the graph 130 | beg = dtv.base_var 131 | curr_node = self._lookup.get(beg) 132 | if curr_node is None: 133 | return None 134 | for access_path in dtv.path: 135 | succs = [ 136 | dest 137 | for (_, dest, label) in self.sketches.out_edges( 138 | curr_node, data="label" 139 | ) 140 | if label == access_path 141 | ] 142 | if len(succs) == 0: 143 | return None 144 | elif len(succs) > 1: 145 | raise ValueError( 146 | f"{curr_node} has multiple successors in sketches" 147 | ) 148 | curr_node = succs[0] 149 | if isinstance(curr_node, LabelNode): 150 | curr_node = self._lookup[curr_node.target] 151 | return curr_node 152 | 153 | def _add_node(self, node: SketchNode) -> None: 154 | """ 155 | Add node to the sketch graph 156 | """ 157 | self._lookup[node.dtv] = node 158 | self.sketches.add_node(node) 159 | 160 | def make_node(self, variable: DerivedTypeVariable) -> SketchNode: 161 | """Make a node from a DTV. Compute its atom from its access path.""" 162 | result = SketchNode(variable, self.types.bottom, self.types.top) 163 | self._add_node(result) 164 | return result 165 | 166 | def add_edge(self, head: SketchNode, tail: SkNode, label: str) -> None: 167 | """ 168 | Add edge labeled with `label` in the sketch graph between `head` 169 | and `tail`. 170 | """ 171 | # don't emit duplicate edges 172 | if (head, tail) not in self.sketches.edges: 173 | self.sketches.add_edge(head, tail, label=label) 174 | else: 175 | if label != self.sketches.edges[head, tail]["label"]: 176 | raise RetypdError( 177 | f"Failed to add edge {label} between {head} and {tail}." 178 | f" Label {self.sketches.edges[head, tail]['label']} exists" 179 | ) 180 | 181 | def instantiate_sketch( 182 | self, 183 | proc: DerivedTypeVariable, 184 | fresh_var_factory: FreshVarFactory, 185 | only_capabilities: bool = False, 186 | ) -> ConstraintSet: 187 | """ 188 | Encode all the capability and primitive type information present in the sketch. 
189 |         """
190 |         all_constraints = ConstraintSet()
191 |         for node in self.sketches.nodes:
192 |             if isinstance(node, SketchNode) and node.dtv.base_var == proc:
193 |                 constraints = []
194 |                 if not only_capabilities:
195 |                     if node.lower_bound != self.types.bottom:
196 |                         constraints.append(
197 |                             SubtypeConstraint(node.lower_bound, node.dtv)
198 |                         )
199 |                     if node.upper_bound != self.types.top:
200 |                         constraints.append(
201 |                             SubtypeConstraint(node.dtv, node.upper_bound)
202 |                         )
203 | 
204 |                 # if the node is a leaf, capture the capability using fake variables
205 |                 # this could be avoided if we support capability constraints (Var x.l) in
206 |                 # addition to subtype constraints
207 |                 if (
208 |                     len(constraints) == 0
209 |                     and self.sketches.out_degree(node) == 0
210 |                 ):
211 |                     fresh_var = fresh_var_factory.fresh_var()
212 |                     if node.dtv.path_variance == Variance.CONTRAVARIANT:
213 |                         constraints.append(
214 |                             SubtypeConstraint(node.dtv, fresh_var)
215 |                         )
216 |                     else:
217 |                         constraints.append(
218 |                             SubtypeConstraint(fresh_var, node.dtv)
219 |                         )
220 | 
221 |                 all_constraints |= ConstraintSet(constraints)
222 |         return all_constraints
223 | 
224 |     def add_constraints(self, constraints: ConstraintSet) -> None:
225 |         """Extend the set of sketches with the new set of constraints."""
226 | 
227 |         for constraint in constraints:
228 |             left = constraint.left
229 |             right = constraint.right
230 |             if (
231 |                 left in self.types.internal_types
232 |                 and right not in self.types.internal_types
233 |             ):
234 |                 right_node = self.lookup(right)
235 |                 if right_node is None:
236 |                     raise RetypdError(
237 |                         f"Sketch node corresponding to {right} does not exist"
238 |                     )
239 |                 self.debug("JOIN: %s, %s", right_node, left)
240 |                 right_node.lower_bound = self.types.join(
241 |                     right_node.lower_bound, left
242 |                 )
243 |                 self.debug(" --> %s", right_node)
244 |             elif (
245 |                 right in self.types.internal_types
246 |                 and left not in self.types.internal_types
247 |             ):
248 |                 left_node = self.lookup(left)
249 |                 if left_node is None:
250 |                     raise RetypdError(
251 |                         f"Sketch node corresponding to {left} does not exist"
252 |                     )
253 |                 self.debug("MEET: %s, %s", left_node, right)
254 |                 left_node.upper_bound = self.types.meet(
255 |                     left_node.upper_bound, right
256 |                 )
257 |                 self.debug(" --> %s", left_node)
258 | 
259 |     def remove_subtree(self, node: SkNode) -> None:
260 |         """
261 |         Remove the subtree with root node from the sketch.
262 | """ 263 | worklist = [node] 264 | while len(worklist) > 0: 265 | node = worklist.pop() 266 | worklist.extend(self.sketches.successors(node)) 267 | self.sketches.remove_node(node) 268 | if isinstance(node, SketchNode): 269 | if node.dtv in self._lookup: 270 | del self._lookup[node.dtv] 271 | 272 | def meet(self, other: Sketch) -> None: 273 | """ 274 | Compute in-place meet of self and another sketch 275 | """ 276 | if self.root.dtv != other.root.dtv: 277 | raise RetypdError( 278 | "Cannot compute a meet of two sketches with different root" 279 | ) 280 | 281 | worklist = [(self.root, other.root)] 282 | met_nodes = set() 283 | while len(worklist) > 0: 284 | curr_node, other_node = worklist.pop() 285 | # Avoid infinite loop in case of label nodes 286 | if (curr_node, other_node) in met_nodes: 287 | continue 288 | met_nodes.add((curr_node, other_node)) 289 | 290 | # Deal with primitive type 291 | curr_node.lower_bound = self.types.join( 292 | curr_node.lower_bound, other_node.lower_bound 293 | ) 294 | curr_node.upper_bound = self.types.meet( 295 | curr_node.upper_bound, other_node.upper_bound 296 | ) 297 | # Meet of successors: language union 298 | curr_succs = { 299 | label: succ 300 | for _, succ, label in self.sketches.out_edges( 301 | curr_node, data="label" 302 | ) 303 | } 304 | for _, other_succ, label in other.sketches.out_edges( 305 | other_node, data="label" 306 | ): 307 | if label not in curr_succs: 308 | # create new node 309 | if isinstance(other_succ, SketchNode): 310 | curr_succ = self.make_node(other_succ.dtv) 311 | curr_succ.upper_bound = other_succ.upper_bound 312 | curr_succ.lower_bound = other_succ.lower_bound 313 | else: # LabelNode 314 | curr_succ = LabelNode(other_succ.target) 315 | self.add_edge(curr_node, curr_succ, label) 316 | else: 317 | curr_succ = curr_succs[label] 318 | # follow label nodes 319 | if isinstance(curr_succ, LabelNode): 320 | curr_succ = self.lookup(curr_succ.target) 321 | if isinstance(other_succ, LabelNode): 322 | other_succ = other.lookup(other_succ.target) 323 | worklist.append((curr_succ, other_succ)) 324 | 325 | def join(self, other: Sketch) -> None: 326 | """ 327 | Compute in-place join of self and another sketch 328 | """ 329 | 330 | if self.root.dtv != other.root.dtv: 331 | raise RetypdError( 332 | "Cannot compute a join of two sketches with different root" 333 | ) 334 | worklist = [(self.root, other.root)] 335 | while len(worklist) > 0: 336 | curr_node, other_node = worklist.pop() 337 | # Deal with primitive type 338 | curr_node.lower_bound = self.types.meet( 339 | curr_node.lower_bound, other_node.lower_bound 340 | ) 341 | curr_node.upper_bound = self.types.join( 342 | curr_node.upper_bound, other_node.upper_bound 343 | ) 344 | 345 | # Join successors: Language intersection 346 | other_succs = { 347 | label: succ 348 | for _, succ, label in other.sketches.out_edges( 349 | other_node, data="label" 350 | ) 351 | } 352 | for _, curr_succ, label in list( 353 | self.sketches.out_edges(curr_node, data="label") 354 | ): 355 | if label not in other_succs: 356 | self.remove_subtree(curr_succ) 357 | else: 358 | other_succ = other_succs[label] 359 | if isinstance(curr_succ, SketchNode) and isinstance( 360 | other_succ, SketchNode 361 | ): 362 | worklist.append((curr_succ, other_succ)) 363 | # TODO what to do with LabelNodes? 
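    # Editor's note (illustration only, not repository code): the duality of
    # `meet` (capability union) and `join` (capability intersection) above can
    # be seen with the `sketch_from_dict` helper defined in
    # test/test_sketches.py. Joining a recursive sketch with a non-recursive
    # one keeps only the common capabilities, so the recursive branch is
    # dropped:
    #
    #     sk1 = sketch_from_dict("f", {
    #         "f": {"in_0": "f.in_0"},
    #         "f.in_0": {"load": "f.in_0.load"},
    #         "f.in_0.load": {"σ8@0": "f.in_0",  # back-edge: recursion
    #                         "σ8@4": "f.in_0.load.σ8@4"},
    #     })
    #     sk2 = sketch_from_dict("f", {
    #         "f": {"in_0": "f.in_0"},
    #         "f.in_0": {"load": "f.in_0.load"},
    #         "f.in_0.load": {"σ8@4": "f.in_0.load.σ8@4"},
    #     })
    #     sk1.join(sk2)  # intersection: the σ8@0 branch is removed
    #     assert sk1.lookup(
    #         SchemaParser.parse_variable("f.in_0.load.σ8@0")
    #     ) is None
    #
    # `meet`, conversely, unions the capability languages and would keep both
    # branches.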
364 | 
365 |     def to_dot(self, dtv: DerivedTypeVariable) -> str:
366 |         nt = f"{os.linesep}\t"
367 |         graph_str = f"digraph {dtv} {{"
368 |         start = self._lookup[dtv]
369 |         edges_str = ""
370 |         # emit edges and identify nodes
371 |         nodes = {start}
372 |         seen: Set[Tuple[SketchNode, SkNode]] = {(start, start)}
373 |         frontier = {
374 |             (start, succ) for succ in self.sketches.successors(start)
375 |         } - seen
376 |         while frontier:
377 |             new_frontier: Set[Tuple[SketchNode, SkNode]] = set()
378 |             for pred, succ in frontier:
379 |                 edges_str += nt
380 |                 nodes.add(succ)
381 |                 new_frontier |= {
382 |                     (succ, s_s) for s_s in self.sketches.successors(succ)
383 |                 }
384 |                 edges_str += f'"{pred}" -> "{succ}"'
385 |                 edges_str += (
386 |                     f' [label="{self.sketches[pred][succ]["label"]}"];'
387 |                 )
388 |             frontier = new_frontier - seen
389 |         # emit nodes
390 |         for node in nodes:
391 |             if isinstance(node, SketchNode):
392 |                 if node.dtv == dtv:
393 |                     graph_str += nt
394 |                     graph_str += f'"{node}" [label="{node.dtv}"];'
395 |                 elif node.dtv.base_var == dtv:
396 |                     graph_str += nt
397 |                     graph_str += f'"{node}" [label="{node.lower_bound}..{node.upper_bound}"];'
398 |             elif node.target.base_var == dtv:
399 |                 graph_str += nt
400 |                 graph_str += f'"{node}" [label="{node.target}", shape=house];'
401 |         graph_str += edges_str
402 |         graph_str += f"{os.linesep}}}"
403 |         return graph_str
404 | 
405 |     def __str__(self) -> str:
406 |         if self._lookup:
407 |             nt = f"{os.linesep}\t"
408 | 
409 |             def format(k: DerivedTypeVariable) -> str:
410 |                 return str(self._lookup[k])
411 | 
412 |             return f"nodes:{nt}{nt.join(map(format, self._lookup.keys()))})"
413 |         return "no sketches"
414 | 
-------------------------------------------------------------------------------- /src/version.py: --------------------------------------------------------------------------------
1 | # Retypd - machine code type inference
2 | # Copyright (C) 2021 GrammaTech, Inc.
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
16 | #
17 | # This project is sponsored by the Office of Naval Research, One Liberty
18 | # Center, 875 N. Randolph Street, Arlington, VA 22203 under contract #
19 | # N68335-17-C-0700. The content of the information does not necessarily
20 | # reflect the position or policy of the Government and no official
21 | # endorsement should be inferred.
22 | 23 | __packagename__ = "retypd" 24 | __version__ = "0.3" 25 | -------------------------------------------------------------------------------- /test/test_graph.py: -------------------------------------------------------------------------------- 1 | """Simple unit tests from the paper and slides that only look at the final result (sketches) 2 | """ 3 | 4 | import pytest 5 | from retypd import ( 6 | ConstraintSet, 7 | SchemaParser, 8 | ) 9 | 10 | from retypd.schema import DerefLabel, InLabel, LoadLabel, OutLabel, Variance 11 | from retypd.graph import ConstraintGraph, EdgeLabel, SideMark, Node 12 | 13 | VERBOSE_TESTS = False 14 | 15 | 16 | @pytest.mark.commit 17 | def test_simple(): 18 | """ 19 | Check that the constraint graph from one constraint has the expected elements. 20 | A constraint graph from one constraint has two paths that allow us to reconstruct 21 | the constraint, one in covariant version and one in contravariant. 22 | """ 23 | cs = ConstraintSet( 24 | [SchemaParser.parse_constraint("f.in_0 <= A.load.σ4@0")] 25 | ) 26 | graph = ConstraintGraph( 27 | cs, 28 | {SchemaParser.parse_variable("f")}, 29 | keep_graph_before_split=True, 30 | ).graph_before_split 31 | f_co = Node( 32 | SchemaParser.parse_variable("f"), Variance.COVARIANT, SideMark.RIGHT 33 | ) 34 | fin0_co = Node( 35 | SchemaParser.parse_variable("f.in_0"), 36 | Variance.COVARIANT, 37 | SideMark.LEFT, 38 | ) 39 | a_load_0_co = Node( 40 | SchemaParser.parse_variable("A.load.σ4@0"), Variance.COVARIANT 41 | ) 42 | a_load_co = Node(SchemaParser.parse_variable("A.load"), Variance.COVARIANT) 43 | a_co = Node(SchemaParser.parse_variable("A"), Variance.COVARIANT) 44 | 45 | f_cn = Node( 46 | SchemaParser.parse_variable("f"), 47 | Variance.CONTRAVARIANT, 48 | SideMark.LEFT, 49 | ) 50 | fin0_cn = Node( 51 | SchemaParser.parse_variable("f.in_0"), 52 | Variance.CONTRAVARIANT, 53 | SideMark.RIGHT, 54 | ) 55 | a_load_0_cn = Node( 56 | SchemaParser.parse_variable("A.load.σ4@0"), Variance.CONTRAVARIANT 57 | ) 58 | a_load_cn = Node( 59 | SchemaParser.parse_variable("A.load"), Variance.CONTRAVARIANT 60 | ) 61 | a_cn = Node(SchemaParser.parse_variable("A"), Variance.CONTRAVARIANT) 62 | assert { 63 | f_co, 64 | fin0_co, 65 | a_load_0_co, 66 | a_load_co, 67 | a_co, 68 | f_cn, 69 | fin0_cn, 70 | a_load_0_cn, 71 | a_load_cn, 72 | a_cn, 73 | } == set(graph.nodes) 74 | forget = EdgeLabel.Kind.FORGET 75 | recall = EdgeLabel.Kind.RECALL 76 | edges = { 77 | # one path from "f" to "A" 78 | (f_cn, fin0_co, EdgeLabel(InLabel(0), recall)), 79 | (fin0_co, a_load_0_co, None), 80 | (a_load_0_co, a_load_co, EdgeLabel(DerefLabel(4, 0), forget)), 81 | (a_load_co, a_co, EdgeLabel(LoadLabel.instance(), forget)), 82 | # the second path from "A" to "f" 83 | (a_cn, a_load_cn, EdgeLabel(LoadLabel.instance(), recall)), 84 | (a_load_cn, a_load_0_cn, EdgeLabel(DerefLabel(4, 0), recall)), 85 | (a_load_0_cn, fin0_cn, None), 86 | (fin0_cn, f_co, EdgeLabel(InLabel(0), forget)), 87 | } 88 | assert edges == set(graph.edges(data="label")) 89 | 90 | 91 | @pytest.mark.commit 92 | def test_two_constraints(): 93 | """ 94 | A constraint graph from two related constraints has two paths 95 | (a covariant and contravariant version) that allow us to conclude 96 | that A.out <= C. 
97 | """ 98 | constraints = ["A <= B", "B.out <= C"] 99 | cs = ConstraintSet(map(SchemaParser.parse_constraint, constraints)) 100 | graph = ConstraintGraph( 101 | cs, 102 | {SchemaParser.parse_variable("A"), SchemaParser.parse_variable("C")}, 103 | keep_graph_before_split=True, 104 | ).graph_before_split 105 | b_co = Node(SchemaParser.parse_variable("B"), Variance.COVARIANT) 106 | b_out_co = Node(SchemaParser.parse_variable("B.out"), Variance.COVARIANT) 107 | a_co = Node( 108 | SchemaParser.parse_variable("A"), Variance.COVARIANT, SideMark.LEFT 109 | ) 110 | c_co = Node( 111 | SchemaParser.parse_variable("C"), Variance.COVARIANT, SideMark.RIGHT 112 | ) 113 | b_cn = Node(SchemaParser.parse_variable("B"), Variance.CONTRAVARIANT) 114 | b_out_cn = Node( 115 | SchemaParser.parse_variable("B.out"), Variance.CONTRAVARIANT 116 | ) 117 | a_cn = Node( 118 | SchemaParser.parse_variable("A"), 119 | Variance.CONTRAVARIANT, 120 | SideMark.RIGHT, 121 | ) 122 | c_cn = Node( 123 | SchemaParser.parse_variable("C"), 124 | Variance.CONTRAVARIANT, 125 | SideMark.LEFT, 126 | ) 127 | forget = EdgeLabel.Kind.FORGET 128 | recall = EdgeLabel.Kind.RECALL 129 | edges = { 130 | # path from A to C 131 | (a_co, b_co, None), 132 | (b_co, b_out_co, EdgeLabel(OutLabel.instance(), recall)), 133 | (b_out_co, c_co, None), 134 | # path from C to A 135 | (c_cn, b_out_cn, None), 136 | (b_out_cn, b_cn, EdgeLabel(OutLabel.instance(), forget)), 137 | (b_cn, a_cn, None), 138 | } 139 | assert edges == set(graph.edges(data="label")) 140 | -------------------------------------------------------------------------------- /test/test_internals.py: -------------------------------------------------------------------------------- 1 | """Simple unit tests from the paper and slides that rely on looking under-the-hood 2 | at generated constraints. 3 | """ 4 | 5 | 6 | import pytest 7 | from retypd import ( 8 | CLattice, 9 | ConstraintSet, 10 | CTypeGenerator, 11 | DerivedTypeVariable, 12 | DummyLattice, 13 | SchemaParser, 14 | Solver, 15 | SolverConfig, 16 | DerefLabel, 17 | Program, 18 | CType, 19 | CLatticeCTypes, 20 | BoolType, 21 | CharType, 22 | FunctionType, 23 | FloatType, 24 | IntType, 25 | PointerType, 26 | VoidType, 27 | ) 28 | from test_endtoend import parse_cs_set, parse_cs, parse_var 29 | 30 | 31 | @pytest.mark.commit 32 | def test_parse_label(): 33 | l = SchemaParser.parse_label("σ8@0") 34 | assert (l.size, l.offset, l.count) == (8, 0, 1) 35 | l = SchemaParser.parse_label("σ0@10000") 36 | assert (l.size, l.offset, l.count) == (0, 10000, 1) 37 | l = SchemaParser.parse_label("σ4@-32") 38 | assert (l.size, l.offset, l.count) == (4, -32, 1) 39 | l = SchemaParser.parse_label("σ2@32*[1000]") 40 | assert (l.size, l.offset, l.count) == (2, 32, 1000) 41 | l = SchemaParser.parse_label("σ2@32*[nobound]") 42 | assert (l.size, l.offset, l.count) == (2, 32, DerefLabel.COUNT_NOBOUND) 43 | 44 | l = SchemaParser.parse_label("σ2@32*[nullterm]") 45 | assert (l.size, l.offset, l.count) == (2, 32, DerefLabel.COUNT_NULLTERM) 46 | with pytest.raises(ValueError): 47 | l = SchemaParser.parse_label("σ-9@100") 48 | 49 | 50 | @pytest.mark.commit 51 | def test_simple_constraints(): 52 | """A simple test from the paper (the right side of Figure 4 on p. 6). This one has no 53 | recursive data structures; as such, the fixed point would suffice. However, we compute type 54 | constraints in the same way as in the presence of recursion. 
55 | """ 56 | constraints = parse_cs_set( 57 | ["p ⊑ q", "x ⊑ q.store.σ4@0", "p.load.σ4@0 ⊑ y"] 58 | ) 59 | f, x, y = SchemaParser.parse_variables(["f", "x", "y"]) 60 | lattice = DummyLattice() 61 | solver = Solver(Program(CLattice(), {}, {}, {})) 62 | generated = solver._generate_type_scheme( 63 | constraints, {x, y}, lattice.internal_types 64 | ) 65 | assert parse_cs("x ⊑ y") in generated 66 | 67 | 68 | @pytest.mark.commit 69 | def test_other_simple_constraints(): 70 | """Another simple test from the paper (the program modeled in Figure 14 on p. 26).""" 71 | constraints = parse_cs_set( 72 | ["y <= p", "p <= x", "A <= x.store", "y.load <= B"] 73 | ) 74 | f, A, B = SchemaParser.parse_variables(["f", "A", "B"]) 75 | lattice = DummyLattice() 76 | solver = Solver(Program(CLattice(), {}, {}, {})) 77 | generated = solver._generate_type_scheme( 78 | constraints, {A, B}, lattice.internal_types 79 | ) 80 | assert parse_cs("A ⊑ B") in generated 81 | 82 | 83 | @pytest.mark.commit 84 | def test_forgets(): 85 | """A simple test to check that paths that include "forgotten" labels reconstruct access 86 | paths in the correct order. 87 | """ 88 | l, F = SchemaParser.parse_variables(["l", "F"]) 89 | constraints = ConstraintSet() 90 | constraint = parse_cs("l ⊑ F.in_1.load.σ8@0") 91 | constraints.add(constraint) 92 | lattice = DummyLattice() 93 | solver = Solver(Program(CLattice(), {}, {}, {})) 94 | generated = solver._generate_type_scheme( 95 | constraints, {l, F}, lattice.internal_types 96 | ) 97 | assert constraint in generated 98 | 99 | 100 | @pytest.mark.parametrize( 101 | ("lhs", "rhs", "expected"), 102 | [ 103 | ("float", "uint", "┬"), 104 | ("uint", "int", "uint"), 105 | ("char", "int", "int"), 106 | ("int64", "int", "int"), 107 | ("uint32", "uint64", "uint"), 108 | ], 109 | ) 110 | @pytest.mark.commit 111 | def test_join(lhs: str, rhs: str, expected: str): 112 | """Test C-lattice join operations against known values""" 113 | lhs_dtv = DerivedTypeVariable(lhs) 114 | rhs_dtv = DerivedTypeVariable(rhs) 115 | equal_dtv = DerivedTypeVariable(expected) 116 | assert CLattice().join(lhs_dtv, rhs_dtv) == equal_dtv 117 | 118 | 119 | @pytest.mark.parametrize( 120 | ("name", "ctype", "size"), 121 | [ 122 | ("int", IntType(4, True), 4), 123 | ("int8", IntType(1, True), None), 124 | ("int16", IntType(2, True), None), 125 | ("int32", IntType(4, True), None), 126 | ("int64", IntType(8, True), None), 127 | ("uint", IntType(4, False), 4), 128 | ("uint8", IntType(1, False), None), 129 | ("uint16", IntType(2, False), None), 130 | ("uint32", IntType(4, False), None), 131 | ("uint64", IntType(8, False), None), 132 | ("void", VoidType(), None), 133 | ("char", CharType(1), 1), 134 | ("bool", BoolType(1), 1), 135 | ("float", FloatType(4), None), 136 | ("double", FloatType(8), None), 137 | ], 138 | ) 139 | @pytest.mark.commit 140 | def test_atom_to_ctype(name: str, ctype: CType, size: int): 141 | """Test C-lattice are converted to C-types correctly""" 142 | atom = DerivedTypeVariable(name) 143 | lattice = CLatticeCTypes() 144 | ctype_lhs = lattice.atom_to_ctype(atom, CLattice._top, size) 145 | ctype_rhs = lattice.atom_to_ctype(CLattice._bottom, atom, size) 146 | assert str(ctype) == str(ctype_lhs) 147 | assert str(ctype) == str(ctype_rhs) 148 | 149 | 150 | @pytest.mark.commit 151 | def test_infers_all_inputs(): 152 | """Test that we infer all the inputs for a function""" 153 | F = parse_var("F") 154 | constraints = ConstraintSet() 155 | constraints.add(parse_cs("F.in_2.in_1 ⊑ int")) 156 | lattice = DummyLattice() 157 | 
solver = Solver(Program(CLattice(), {F}, {F: constraints}, {F: {}})) 158 | _, sketches = solver() 159 | 160 | gen = CTypeGenerator(sketches, lattice, CLatticeCTypes(), 4, 4, verbose=2) 161 | dtv2type = gen() 162 | assert isinstance(dtv2type[F], FunctionType) 163 | assert dtv2type[F].params[0] is not None 164 | assert dtv2type[F].params[1] is not None 165 | assert isinstance(dtv2type[F].params[2], PointerType) 166 | assert isinstance(dtv2type[F].params[2].target_type, FunctionType) 167 | assert dtv2type[F].params[2].target_type.params[0] is not None 168 | assert isinstance(dtv2type[F].params[2].target_type.params[1], IntType) 169 | assert dtv2type[F].params[2].target_type.params[1] is not None 170 | 171 | 172 | @pytest.mark.commit 173 | def test_top_down(): 174 | """ 175 | Test that top-down propagation can propagate information correctly 176 | """ 177 | config = SolverConfig(top_down_propagation=True) 178 | F, G = SchemaParser.parse_variables(["F", "G"]) 179 | constraints = {F: ConstraintSet(), G: ConstraintSet()} 180 | constraints[F].add(parse_cs("int ⊑ F.in_1.load.σ8@0")) 181 | constraints[G].add(parse_cs("int ⊑ x.load.σ8@8")) 182 | constraints[G].add(parse_cs("x ⊑ F.in_1")) 183 | constraints[G].add(parse_cs("F.in_1 ⊑ G.in_1")) 184 | 185 | solver = Solver( 186 | Program(CLattice(), {}, constraints, {G: {F}}), config=config 187 | ) 188 | _, sketches = solver() 189 | assert sketches[F].lookup( 190 | parse_var("F.in_1.load.σ8@8") 191 | ).lower_bound == DerivedTypeVariable("int") 192 | 193 | 194 | @pytest.mark.commit 195 | def test_top_down_two_levels(): 196 | """ 197 | Validate that top-down propagation can handle two levels of indirection 198 | """ 199 | config = SolverConfig(top_down_propagation=True) 200 | F, G, H = SchemaParser.parse_variables(["F", "G", "H"]) 201 | constraints = {F: ConstraintSet(), G: ConstraintSet(), H: ConstraintSet()} 202 | constraints[F].add(parse_cs("F.out ⊑ int")) 203 | constraints[G].add(parse_cs("G.in_1 ⊑ F.in_1")) 204 | constraints[H].add(parse_cs("int ⊑ G.in_1")) 205 | 206 | solver = Solver( 207 | Program(CLattice(), {}, constraints, {G: {F}, H: {G}}), 208 | config=config, 209 | ) 210 | _, sketches = solver() 211 | assert sketches[G].lookup( 212 | parse_var("G.in_1") 213 | ).lower_bound == DerivedTypeVariable("int") 214 | assert sketches[F].lookup( 215 | parse_var("F.in_1") 216 | ).lower_bound == DerivedTypeVariable("int") 217 | 218 | 219 | @pytest.mark.commit 220 | def test_top_down_three_levels(): 221 | """ 222 | Validate that top-down propagation can handle three levels of indirection 223 | """ 224 | config = SolverConfig(top_down_propagation=True) 225 | F, G, H, I = SchemaParser.parse_variables(["F", "G", "H", "I"]) 226 | constraints = { 227 | F: ConstraintSet(), 228 | G: ConstraintSet(), 229 | H: ConstraintSet(), 230 | I: ConstraintSet(), 231 | } 232 | constraints[F].add(parse_cs("F.out ⊑ int")) 233 | constraints[G].add(parse_cs("G.in_1 ⊑ F.in_1")) 234 | constraints[G].add(parse_cs("G.in_2 ⊑ F.in_2")) 235 | constraints[H].add(parse_cs("H.in_1 ⊑ G.in_1")) 236 | constraints[I].add(parse_cs("int ⊑ H.in_1")) 237 | constraints[I].add(parse_cs("int ⊑ G.in_2")) 238 | constraints[I].add(parse_cs("int ⊑ G.in_1")) 239 | 240 | solver = Solver( 241 | Program(CLattice(), {}, constraints, {G: {F}, H: {G}, I: {H, G}}), 242 | config=config, 243 | ) 244 | _, sketches = solver() 245 | assert sketches[F].lookup( 246 | parse_var("F.in_1") 247 | ).lower_bound == DerivedTypeVariable("int") 248 | assert sketches[F].lookup( 249 | parse_var("F.in_2") 250 | ).lower_bound != 
DerivedTypeVariable("int")
251 |     # Both calls constrain G.in_1 to int
252 |     assert sketches[G].lookup(
253 |         parse_var("G.in_1")
254 |     ).lower_bound == DerivedTypeVariable("int")
255 |     # Only one call constrains G.in_2 to int
256 |     assert sketches[G].lookup(
257 |         parse_var("G.in_2")
258 |     ).lower_bound != DerivedTypeVariable("int")
259 |     assert sketches[H].lookup(
260 |         parse_var("H.in_1")
261 |     ).lower_bound == DerivedTypeVariable("int")
262 | 
263 | 
264 | @pytest.mark.commit
265 | def test_top_down_merge_sketches():
266 |     """
267 |     Test that when lattice types are met at an input, they are merged over the lattice correctly
268 |     """
269 |     config = SolverConfig(top_down_propagation=True)
270 |     F, G, H = SchemaParser.parse_variables(["F", "G", "H"])
271 |     constraints = {
272 |         F: ConstraintSet(),
273 |         G: ConstraintSet(),
274 |         H: ConstraintSet(),
275 |     }
276 | 
277 |     constraints[F].add(parse_cs("F.out ⊑ int"))
278 | 
279 |     constraints[G].add(parse_cs("A ⊑ F.in_1"))
280 |     constraints[G].add(parse_cs("int ⊑ A"))
281 | 
282 |     constraints[H].add(parse_cs("A ⊑ F.in_1"))
283 |     constraints[H].add(parse_cs("char ⊑ A"))
284 | 
285 |     solver = Solver(
286 |         Program(CLattice(), {}, constraints, {G: {F}, H: {F}}),
287 |         config=config,
288 |     )
289 |     _, sketches = solver()
290 | 
291 |     assert sketches[F].lookup(
292 |         parse_var("F.in_1")
293 |     ).lower_bound == DerivedTypeVariable("char")
294 | 
295 | 
296 | @pytest.mark.commit
297 | def test_top_down_merge_sketch_languages():
298 |     """
299 |     Test that only the common capabilities are kept for top-down propagation.
300 |     """
301 |     config = SolverConfig(top_down_propagation=True)
302 |     F, G, H = SchemaParser.parse_variables(["F", "G", "H"])
303 |     constraints = {
304 |         F: ConstraintSet(),
305 |         G: ConstraintSet(),
306 |         H: ConstraintSet(),
307 |     }
308 | 
309 |     constraints[F].add(parse_cs("F.out ⊑ int"))
310 | 
311 |     constraints[G].add(parse_cs("A ⊑ F.in_1"))
312 |     constraints[G].add(parse_cs("int ⊑ A.load.σ8@0"))
313 |     constraints[G].add(parse_cs("int ⊑ A.load.σ8@4"))
314 |     constraints[G].add(parse_cs("F.out ⊑ B"))
315 |     constraints[G].add(parse_cs("B.load.σ8@4 ⊑ int16"))
316 | 
317 |     constraints[H].add(parse_cs("A ⊑ F.in_1"))
318 |     constraints[H].add(parse_cs("int ⊑ A.load.σ8@0"))
319 |     constraints[H].add(parse_cs("int ⊑ A.load.σ8@8"))
320 |     constraints[H].add(parse_cs("F.out ⊑ B"))
321 |     constraints[H].add(parse_cs("B.load.σ8@4 ⊑ int32"))
322 | 
323 |     solver = Solver(
324 |         Program(CLattice(), {}, constraints, {G: {F}, H: {F}}),
325 |         config=config,
326 |     )
327 |     _, sketches = solver()
328 | 
329 |     assert sketches[F].lookup(
330 |         parse_var("F.in_1.load.σ8@0")
331 |     ).lower_bound == DerivedTypeVariable("int")
332 |     assert sketches[F].lookup(parse_var("F.in_1.load.σ8@4")) is None
333 |     assert sketches[F].lookup(parse_var("F.in_1.load.σ8@8")) is None
334 |     assert sketches[F].lookup(
335 |         parse_var("F.out.load.σ8@4")
336 |     ).upper_bound == DerivedTypeVariable("int")
337 | 
338 | 
339 | @pytest.mark.commit
340 | def test_top_down_merge_incompatible_sketches():
341 |     """
342 |     Test that when conflicting lattice types are met at an input, they are merged over the lattice
343 |     to top types (i.e.
this function can accept the join of the two) 344 | """ 345 | config = SolverConfig(top_down_propagation=True) 346 | F, G, H = SchemaParser.parse_variables(["F", "G", "H"]) 347 | constraints = { 348 | F: ConstraintSet(), 349 | G: ConstraintSet(), 350 | H: ConstraintSet(), 351 | } 352 | 353 | constraints[F].add(parse_cs("F.out ⊑ int")) 354 | 355 | constraints[G].add(parse_cs("A ⊑ F.in_1")) 356 | constraints[G].add(parse_cs("int ⊑ A.store.σ8@0")) 357 | 358 | constraints[H].add(parse_cs("A ⊑ F.in_1")) 359 | constraints[H].add(parse_cs("double ⊑ A.store.σ8@0")) 360 | 361 | solver = Solver( 362 | Program(CLattice(), {}, constraints, {G: {F}, H: {F}}), 363 | config=config, 364 | ) 365 | _, sketches = solver() 366 | 367 | assert sketches[F].lookup( 368 | parse_var("F.in_1.store.σ8@0") 369 | ).upper_bound == DerivedTypeVariable("┬") 370 | 371 | 372 | @pytest.mark.commit 373 | def test_overlapping_var_in_scc(): 374 | """Validate that variables overlapping in SCCs aren't aliased""" 375 | config = SolverConfig() 376 | F, G = SchemaParser.parse_variables(["F", "G"]) 377 | constraints = { 378 | F: ConstraintSet(), 379 | G: ConstraintSet(), 380 | } 381 | 382 | constraints[F].add(parse_cs("F.in_1 ⊑ A")) 383 | constraints[F].add(parse_cs("A ⊑ int")) 384 | 385 | constraints[G].add(parse_cs("G.in_1 ⊑ A")) 386 | constraints[G].add(parse_cs("A ⊑ double")) 387 | 388 | solver = Solver( 389 | Program(CLattice(), {}, constraints, {G: {F}, F: {G}}), 390 | config=config, 391 | ) 392 | _, sketches = solver() 393 | 394 | assert sketches[F].lookup( 395 | parse_var("F.in_1") 396 | ).upper_bound == DerivedTypeVariable("int") 397 | assert sketches[G].lookup( 398 | parse_var("G.in_1") 399 | ).upper_bound == DerivedTypeVariable("double") 400 | 401 | 402 | @pytest.mark.commit 403 | def test_sketches_not_overlapping(): 404 | """ 405 | Validate that during top-down propagation, sketches are re-inferred from scratch and that when 406 | primitive constraints populate those sketches, all sketch nodes are present. This was derived 407 | from libpng originally. Crashes were non-deterministic. 
408 | """ 409 | ( 410 | EmptyFunc, 411 | RootFunc, 412 | CrashFunc, 413 | MiddleFunc, 414 | SCCFunc1, 415 | SCCFunc2, 416 | MiddleFunc2, 417 | ) = SchemaParser.parse_variables( 418 | [ 419 | "EmptyFunc", 420 | "RootFunc", 421 | "CrashFunc", 422 | "MiddleFunc", 423 | "SCCFunc1", 424 | "SCCFunc2", 425 | "MiddleFunc2", 426 | ] 427 | ) 428 | constraints = { 429 | EmptyFunc: ConstraintSet(), 430 | RootFunc: ConstraintSet(), 431 | CrashFunc: ConstraintSet(), 432 | MiddleFunc: ConstraintSet(), 433 | SCCFunc1: ConstraintSet(), 434 | SCCFunc2: ConstraintSet(), 435 | MiddleFunc2: ConstraintSet(), 436 | } 437 | constraints[RootFunc].add(parse_cs("CrashFunc.out ⊑ MiddleFunc.in_2")) 438 | constraints[RootFunc].add( 439 | parse_cs("CrashFunc.out ⊑ MiddleFunc.in_2.store.σ8@32") 440 | ) 441 | constraints[SCCFunc1].add(parse_cs("A ⊑ MiddleFunc.in_0")) 442 | constraints[SCCFunc2].add(parse_cs("A ⊑ MiddleFunc2.in_1")) 443 | constraints[MiddleFunc2].add(parse_cs("A ⊑ MiddleFunc2.in_3.store.σ4@80")) 444 | constraints[MiddleFunc2].add( 445 | parse_cs("MiddleFunc2.in_3 ⊑ MiddleFunc.in_2") 446 | ) 447 | constraints[MiddleFunc].add(parse_cs("MiddleFunc.in_0 ⊑ int")) 448 | constraints[MiddleFunc].add(parse_cs("int ⊑ MiddleFunc.in_0")) 449 | constraints[MiddleFunc].add(parse_cs("MiddleFunc.in_2.load.σ8@32 ⊑ A")) 450 | constraints[MiddleFunc].add( 451 | parse_cs("CrashFunc.out ⊑ MiddleFunc.in_2.store") 452 | ) 453 | 454 | callgraph = { 455 | EmptyFunc: [], 456 | RootFunc: [CrashFunc, MiddleFunc], 457 | CrashFunc: [], 458 | MiddleFunc: [CrashFunc], 459 | SCCFunc1: [CrashFunc, SCCFunc2, MiddleFunc, EmptyFunc], 460 | SCCFunc2: [SCCFunc1, MiddleFunc2], 461 | MiddleFunc2: [MiddleFunc], 462 | } 463 | 464 | program = Program(CLattice(), {}, constraints, callgraph) 465 | config = SolverConfig(graph_solver="dfa", top_down_propagation=True) 466 | solver = Solver(program, config) 467 | 468 | gen_cs, sketches = solver() 469 | -------------------------------------------------------------------------------- /test/test_pathexpr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from retypd.pathexpr import path_expression_between 3 | import networkx 4 | 5 | 6 | @pytest.mark.parametrize( 7 | ("edges", "path_exprs"), 8 | [ 9 | ( 10 | [ 11 | ("a", "b", "A"), 12 | ("b", "c", "B"), 13 | ("c", "d", "C"), 14 | ], 15 | [ 16 | ("a", "b", "A"), 17 | ("a", "c", "(A . B)"), 18 | ("a", "d", "(A . B . C)"), 19 | ("b", "c", "B"), 20 | ("b", "d", "(B . C)"), 21 | ], 22 | ), 23 | ( 24 | [ 25 | ("a", "b", "A"), 26 | ("b", "b", "B"), 27 | ("b", "c", "C"), 28 | ], 29 | [ 30 | ("a", "b", "(A . B*)"), 31 | ("a", "c", "(A . B* . C)"), 32 | ], 33 | ), 34 | ( 35 | [ 36 | ("a", "b", "A"), 37 | ("b", "b", "B"), 38 | ("b", "c", "C"), 39 | ("a", "c", "D"), 40 | ], 41 | [ 42 | ("a", "b", "(A . B*)"), 43 | ("a", "c", {"(D U (A . B* . C))", "((A . B* . C) U D)"}), 44 | ], 45 | ), 46 | ( 47 | [ 48 | ("a", "b", "A"), 49 | ("b", "b", "B"), 50 | ("b", "c", "C"), 51 | ("a", "d", "D"), 52 | ("d", "d", "E"), 53 | ("d", "c", "F"), 54 | ], 55 | [ 56 | ("a", "b", "(A . B*)"), 57 | ("a", "d", "(D . E*)"), 58 | ( 59 | "a", 60 | "c", 61 | { 62 | "((A . B* . C) U (D . E* . F))", 63 | "((D . E* . F) U (A . B* . C))", 64 | }, 65 | ), 66 | ], 67 | ), 68 | ( 69 | [ 70 | ("a", "b", "A"), 71 | ("b", "c", "B"), 72 | ("c", "a", "C"), 73 | ], 74 | [ 75 | ( 76 | "a", 77 | "b", 78 | { 79 | "(A U (A . B . (C . A . B)* . C . A))", 80 | "((A . B . (C . A . B)* . C . A) U A)", 81 | }, 82 | ), 83 | ("b", "c", "(B . (C . A . 
B)*)"),
84 |                 ("a", "c", "(A . B . (C . A . B)*)"),
85 |             ],
86 |         ),
87 |         (
88 |             [("a", "b", "A"), ("b", "c", None)],
89 |             [("a", "b", "A"), ("a", "c", "A")],
90 |         ),
91 |         (
92 |             [
93 |                 ("a", "b", "A"),
94 |                 ("b", "c", "B"),
95 |                 ("c", "a", None),
96 |             ],
97 |             [
98 |                 (
99 |                     "a",
100 |                     "b",
101 |                     {
102 |                         "(A U (A . B . (A . B)* . A))",
103 |                         "((A . B . (A . B)* . A) U A)",
104 |                     },
105 |                 ),
106 |                 ("a", "c", "(A . B . (A . B)*)"),
107 |             ],
108 |         ),
109 |         (
110 |             [
111 |                 ("a", "b", "A"),
112 |                 ("b", "b", None),
113 |                 ("b", "c", "B"),
114 |             ],
115 |             [("a", "b", "A"), ("a", "c", "(A . B)")],
116 |         ),
117 |     ],
118 |     ids=[
119 |         "simple",
120 |         "self_cycle",
121 |         "self_cycle_with_shortcut",
122 |         "multiple_cycles",
123 |         "whole_graph_cycle",
124 |         "empty_label",
125 |         "empty_loop",
126 |         "empty_self_loop",
127 |     ],
128 | )
129 | @pytest.mark.parametrize("decompose", [True, False])
130 | @pytest.mark.commit
131 | def test_path_expr(edges, path_exprs, decompose):
132 |     """Generate a unit test for a given set of edges and expected path
133 |     expressions
134 |     """
135 |     graph = networkx.DiGraph()
136 | 
137 |     for (src, dest, label) in edges:
138 |         if label is not None:
139 |             graph.add_edge(src, dest, label=label)
140 |         else:
141 |             graph.add_edge(src, dest)
142 | 
143 |     for (source, dest, expr) in path_exprs:
144 |         generated = path_expression_between(
145 |             graph, "label", source, dest, decompose
146 |         )
147 |         if isinstance(expr, str):
148 |             assert str(generated) == expr
149 |         elif isinstance(expr, set):
150 |             assert str(generated) in expr
151 |         else:
152 |             raise NotImplementedError()
153 | 
-------------------------------------------------------------------------------- /test/test_regressions.py: --------------------------------------------------------------------------------
1 | from test_endtoend import (
2 |     compute_sketches,
3 |     all_solver_configs,
4 |     parse_var,
5 |     parse_cs,
6 | )
7 | import pytest
8 | from retypd import CLattice
9 | 
10 | 
11 | @all_solver_configs
12 | @pytest.mark.commit
13 | def test_regression3(config):
14 |     """
15 |     This test checks that infer_shapes
16 |     correctly unifies 'a.load' and 'a.store'
17 |     for every DTV 'a'. This is a special case
18 |     of combining the S-Refl and S-Pointer
19 |     type rules.
20 | """ 21 | constraints = { 22 | "prvListTasksWithinSingleList": [ 23 | "v_566 ⊑ v_994", 24 | "v_181 ⊑ int32", 25 | "v_451.load.σ4@4 ⊑ v_86", 26 | "v_287 ⊑ int32", 27 | "v_281.load.σ4@12 ⊑ v_995", 28 | "v_572 ⊑ v_368", 29 | "v_211.load.σ4@4 ⊑ v_217", 30 | "v_807 ⊑ int32", 31 | "v_451.load.σ4@4 ⊑ v_211", 32 | "v_451.load.σ4@0 ⊑ v_69", 33 | "int32 ⊑ v_354", 34 | "v_991 ⊑ v_451.store.σ4@4", 35 | "bool ⊑ v_132", 36 | "bool ⊑ v_333", 37 | "v_175.load.σ4@12 ⊑ v_992", 38 | "v_971.load.σ4@4 ⊑ v_262", 39 | "v_217 ⊑ v_993", 40 | "bool ⊑ v_238", 41 | "v_92 ⊑ int32", 42 | "v_992 ⊑ v_181", 43 | "v_569 ⊑ v_991", 44 | "v_185 ⊑ v_184", 45 | "v_451.load.σ4@4 ⊑ v_281", 46 | "v_953.load.σ4@4 ⊑ v_156", 47 | "v_287 ⊑ vTaskGetInfo.in_0", 48 | "v_111 ⊑ int32", 49 | "v_293 ⊑ vTaskGetInfo.in_1", 50 | "v_80 ⊑ int32", 51 | "v_994 ⊑ v_451.store.σ4@4", 52 | "bool ⊑ v_78", 53 | "prvListTasksWithinSingleList.in_0 ⊑ v_450", 54 | "v_993 ⊑ v_451.store.σ4@4", 55 | "v_92 ⊑ v_989", 56 | "v_217 ⊑ int32", 57 | "v_989 ⊑ v_451.store.σ4@4", 58 | "v_69 ⊑ int32", 59 | "v_995 ⊑ v_287", 60 | "v_575 ⊑ int32", 61 | "prvListTasksWithinSingleList.in_1 ⊑ v_451", 62 | "v_451.load.σ4@4 ⊑ v_175", 63 | "prvListTasksWithinSingleList.in_2 ⊑ v_452", 64 | "v_86.load.σ4@4 ⊑ v_92", 65 | "v_990 ⊑ v_111", 66 | "v_368 ⊑ prvListTasksWithinSingleList.out", 67 | ] 68 | } 69 | callgraph = {"prvListTasksWithinSingleList": []} 70 | lattice = CLattice() 71 | (gen_cs, sketches) = compute_sketches( 72 | constraints, 73 | callgraph, 74 | lattice=lattice, 75 | config=config, 76 | ) 77 | sk = sketches[parse_var("prvListTasksWithinSingleList")] 78 | # The sketch has a cycle thanks to the equivalence between 79 | # prvListTasksWithinSingleList.in_1.store.σ4@4 and 80 | # prvListTasksWithinSingleList.in_1.load.σ4@4 81 | assert sk.lookup( 82 | parse_var("prvListTasksWithinSingleList.in_1.load.σ4@4.load.σ4@4") 83 | ) is sk.lookup(parse_var("prvListTasksWithinSingleList.in_1.load.σ4@4")) 84 | assert sk.lookup( 85 | parse_var("prvListTasksWithinSingleList.in_1.store.σ4@4.load.σ4@4") 86 | ) is sk.lookup(parse_var("prvListTasksWithinSingleList.in_1.store.σ4@4")) 87 | 88 | 89 | @pytest.mark.commit 90 | @all_solver_configs 91 | def test_regression4(config): 92 | constraints = { 93 | "b": [ 94 | "RSP_1735 ⊑ int", 95 | "RBX_1732 ⊑ int", 96 | "RAX_1719.load.σ4@0 ⊑ RDX_1723", 97 | "int ⊑ RSP_1711", 98 | "RBP_1707.load.σ8@-24 ⊑ RAX_1725", 99 | "RAX_1729 ⊑ int", 100 | "RSP_1710 ⊑ int", 101 | "b.in_0 ⊑ RDI_1715", 102 | "int ⊑ RSP_1742", 103 | "RBP_1707.load.σ8@-24 ⊑ RAX_1719", 104 | "RDI_1715 ⊑ RBP_1707.store.σ8@-24", 105 | "RAX_1740 ⊑ b.out", 106 | "RAX_1725.load.σ4@4 ⊑ RAX_1729", 107 | "int ⊑ RAX_1740", 108 | ], 109 | "a": [ 110 | "RAX_1771 ⊑ a.out", 111 | "stack_1757 ⊑ RAX_1761", 112 | "RSP_1749 ⊑ int", 113 | "RDI_1775 ⊑ b.in_0", 114 | "RDI_1757 ⊑ stack_1757", 115 | "int ⊑ RSP_1753", 116 | "a.in_0 ⊑ b.in_0", 117 | "int ⊑ RAX_1761.store.σ4@0", 118 | "a.in_0 ⊑ RDI_1757", 119 | "stack_1757 ⊑ RAX_1771", 120 | ], 121 | "main": [ 122 | "RAX_1808 ⊑ main.out", 123 | "RSP_1785 ⊑ int", 124 | "int ⊑ RSP_1789", 125 | "RDX_1820 ⊑ uint", 126 | "RAX_1802 ⊑ stack_1802", 127 | "stack_1802 ⊑ RDX_1820", 128 | "uint ⊑ RDX_1824", 129 | "RDI_1812 ⊑ a.in_0", 130 | ], 131 | } 132 | callgraph = {"a": {"b"}, "main": {"a", "FUN_570"}, "b": {"FUN_580"}} 133 | (gen_const, sketches) = compute_sketches( 134 | constraints, callgraph, CLattice(), config 135 | ) 136 | assert ( 137 | sketches[parse_var("b")].lookup(parse_var("b.in_0.load.σ4@0")) 138 | is not None 139 | ) 140 | assert ( 141 | 
sketches[parse_var("b")].lookup(parse_var("b.in_0.load.σ4@4")) 142 | is not None 143 | ) 144 | 145 | assert ( 146 | sketches[parse_var("a")].lookup(parse_var("a.in_0.load.σ4@0")) 147 | is not None 148 | ) 149 | assert ( 150 | sketches[parse_var("a")].lookup(parse_var("a.in_0.load.σ4@4")) 151 | is not None 152 | ) 153 | 154 | assert parse_cs("a.in_0.load.σ4@4 ⊑ int") in gen_const[parse_var("a")] 155 | assert parse_cs("int ⊑ a.in_0.store.σ4@0") in gen_const[parse_var("a")] 156 | 157 | 158 | @pytest.mark.commit 159 | @all_solver_configs 160 | def test_case(config): 161 | """ 162 | Test case to ensure that we do not consider paths 163 | across lattice types, not even in the saturation algorithm. 164 | """ 165 | constraints = { 166 | "F": [ 167 | "F.in_0 ⊑ A.store.σ8@0", 168 | "int64 ⊑ A", 169 | "B.load.σ8@0 ⊑ C", 170 | "int64 ⊑ B", 171 | "C.load.σ1@0*[nullterm] ⊑ int", 172 | ] 173 | } 174 | callgraph = {"F": []} 175 | lattice = CLattice() 176 | (gen_cs, sketches) = compute_sketches( 177 | constraints, 178 | callgraph, 179 | lattice=lattice, 180 | config=config, 181 | ) 182 | 183 | F = parse_var("F") 184 | assert len(gen_cs[F]) == 0 185 | -------------------------------------------------------------------------------- /test/test_sketches.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from retypd import ( 3 | Sketch, 4 | DerivedTypeVariable, 5 | CLattice, 6 | ) 7 | from retypd.sketches import LabelNode 8 | from retypd.parser import SchemaParser 9 | from typing import Dict 10 | 11 | 12 | def sketch_from_dict(root: str, tree: Dict[str, Dict[str, str]]): 13 | """ 14 | Utility to build sketches from string dictionaries. 15 | """ 16 | sk = Sketch(DerivedTypeVariable(root), CLattice()) 17 | # transform to DTV 18 | dtv_tree = {} 19 | for k, succs in tree.items(): 20 | dtv_succs = {} 21 | for label, succ in succs.items(): 22 | dtv_succs[ 23 | SchemaParser.parse_label(label) 24 | ] = SchemaParser.parse_variable(succ) 25 | dtv_tree[SchemaParser.parse_variable(k)] = dtv_succs 26 | 27 | for node in dtv_tree: 28 | sk.make_node(node) 29 | for dtv, succs in dtv_tree.items(): 30 | for label, succ in succs.items(): 31 | tail = dtv.get_suffix(succ) 32 | node = sk.lookup(dtv) 33 | if tail is not None: 34 | succ_node = sk.lookup(succ) 35 | if succ_node is None: 36 | succ_node = sk.make_node(succ) 37 | sk.add_edge(node, succ_node, label) 38 | else: 39 | label_node = LabelNode(succ) 40 | sk.add_edge(node, label_node, label) 41 | return sk 42 | 43 | 44 | @pytest.mark.commit 45 | def test_join_sketch(): 46 | """ 47 | Test that joining recursive and non-recursive 48 | results in non-recursive sketch. 49 | """ 50 | sk1 = sketch_from_dict( 51 | "f", 52 | { 53 | "f": {"in_0": "f.in_0"}, 54 | "f.in_0": {"load": "f.in_0.load"}, 55 | "f.in_0.load": { 56 | "σ8@0": "f.in_0", 57 | "σ8@4": "f.in_0.load.σ8@4", 58 | }, 59 | }, 60 | ) 61 | sk2 = sketch_from_dict( 62 | "f", 63 | { 64 | "f": {"in_0": "f.in_0"}, 65 | "f.in_0": {"load": "f.in_0.load"}, 66 | "f.in_0.load": { 67 | "σ8@4": "f.in_0.load.σ8@4", 68 | }, 69 | }, 70 | ) 71 | sk1.join(sk2) 72 | assert sk1.lookup(SchemaParser.parse_variable("f.in_0.load.σ8@0")) is None 73 | --------------------------------------------------------------------------------
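Editor's note: to tie the modules above together, here is a condensed end-to-end sketch (illustration only, not repository code; it mirrors the `Solver`/`Program` usage in test/test_internals.py and test/test_regressions.py):

from retypd import CLattice, ConstraintSet, Program, SchemaParser, Solver

F = SchemaParser.parse_variable("F")
constraints = ConstraintSet()
constraints.add(SchemaParser.parse_constraint("F.in_0 <= int"))

# Program(type lattice, global variables, per-procedure constraints, call graph)
program = Program(CLattice(), {}, {F: constraints}, {F: []})
gen_cs, sketches = Solver(program)()

# The inferred sketch for F bounds each reachable capability of F in the
# lattice; here F.in_0 acquires the upper bound "int"
print(sketches[F])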