├── .gitignore ├── LICENSE ├── README.md ├── code_ast ├── __init__.py ├── ast.py ├── config.py ├── parsers.py ├── transformer.py └── visitor.py ├── pyproject.toml ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── __init__.py └── test_code_ast.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Cedric Richter 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Code AST 2 | > Fast structural analysis of any programming language in Python 3 | 4 | Programming Language Processing (PLP) brings the capabilities of modern NLP systems to the world of programming languages. 5 | To build high-performance PLP systems, existing methods often take advantage of the fully defined nature of programming languages. In particular, the syntactic structure can be exploited to gain knowledge about programs. 6 | 7 | **code.ast** provides easy access to the syntactic structure of a program. By relying on [tree-sitter](https://github.com/tree-sitter) as the back end, the parser supports fast parsing of a variety of programming languages. 8 | 9 | The goal of code.ast is to combine the efficiency and variety of languages supported by tree-sitter with the convenience of more native parsers (like [libcst](https://github.com/Instagram/LibCST)). 10 | 11 | To achieve this, code.ast adds the following features: 12 | 1. **Auto-loading:** Compilation of source code parsers for any language supported by tree-sitter with a single keyword, 13 | 2. **Visitors:** Search the concrete syntax tree produced by tree-sitter quickly, 14 | 3. **Transformers:** Transform source code easily by transforming the syntax structure 15 | 16 | *Note* that tree-sitter produces a concrete syntax tree and we currently parse 17 | the CST as is. Future versions of code.ast might include options to simplify the CST 18 | to an AST. 19 | 20 | ## Installation 21 | The package is tested under Python 3.
It can be installed via: 22 | ```bash 23 | pip install code-ast 24 | ``` 25 | 26 | Note: Install `tree_sitter_languages` to utilize pre-compiled languages via: 27 | ```bash 28 | pip install tree_sitter_languages 29 | ``` 30 | If `tree_sitter_languages` is not installed, `code_ast` will try 31 | to download and compile the selected language from scratch. 32 | 33 | ## Quick start 34 | code.ast can parse nearly any program code in a few lines of code: 35 | ```python 36 | import code_ast 37 | 38 | # Python 39 | code_ast.ast( 40 | ''' 41 | def my_func(): 42 | print("Hello World") 43 | ''', 44 | lang = "python") 45 | 46 | # Output: 47 | # PythonCodeAST [0, 0] - [4, 4] 48 | # module [1, 8] - [3, 4] 49 | # function_definition [1, 8] - [2, 32] 50 | # identifier [1, 12] - [1, 19] 51 | # parameters [1, 19] - [1, 21] 52 | # block [2, 12] - [2, 32] 53 | # expression_statement [2, 12] - [2, 32] 54 | # call [2, 12] - [2, 32] 55 | # identifier [2, 12] - [2, 17] 56 | # argument_list [2, 17] - [2, 32] 57 | # string [2, 18] - [2, 31] 58 | 59 | # Java 60 | code_ast.ast( 61 | ''' 62 | public class HelloWorld { 63 | public static void main(String[] args){ 64 | System.out.println("Hello World"); 65 | } 66 | } 67 | ''', 68 | lang = "java") 69 | 70 | # Output: 71 | # JavaCodeAST [0, 0] - [7, 4] 72 | # program [1, 0] - [6, 4] 73 | # class_declaration [1, 0] - [5, 1] 74 | # modifiers [1, 0] - [1, 6] 75 | # identifier [1, 13] - [1, 23] 76 | # class_body [1, 24] - [5, 1] 77 | # method_declaration [2, 8] - [4, 9] 78 | # ...
79 | 80 | 81 | ``` 82 | 83 | ## Visitors 84 | code.ast implements the visitor pattern to quickly traverse the CST structure: 85 | ```python 86 | import code_ast 87 | from code_ast import ASTVisitor 88 | 89 | code = ''' 90 | def f(x, y): 91 | return x + y 92 | ''' 93 | 94 | # Count the number of identifiers 95 | class IdentifierCounter(ASTVisitor): 96 | 97 | def __init__(self): 98 | self.count = 0 99 | 100 | def visit_identifier(self, node): 101 | self.count += 1 102 | 103 | # Parse the AST and then visit it with our visitor 104 | source_ast = code_ast.ast(code, lang = "python") 105 | 106 | count_visitor = IdentifierCounter() 107 | source_ast.visit(count_visitor) 108 | 109 | count_visitor.count 110 | # Output: 5 111 | 112 | ``` 113 | 114 | ## Transformers 115 | Transformers provide an easy way to transform source code. For example, in the following, we want to mirror each binary addition: 116 | ```python 117 | import code_ast 118 | from code_ast import ASTTransformer, FormattedUpdate, TreeUpdate 119 | 120 | code = ''' 121 | def f(x, y): 122 | return x + y + 0.5 123 | ''' 124 | 125 | # Mirror binary operator on leave 126 | class MirrorAddTransformer(ASTTransformer): 127 | def leave_binary_operator(self, node): 128 | if node.children[1].type == "+": 129 | return FormattedUpdate( 130 | " %s + %s", 131 | [ 132 | TreeUpdate(node.children[2]), 133 | TreeUpdate(node.children[0]) 134 | ] 135 | ) 136 | 137 | # Parse the AST and then visit it with our visitor 138 | source_ast = code_ast.ast(code, lang = "python") 139 | 140 | mirror_transformer = MirrorAddTransformer() 141 | 142 | # Mirror transformer are initialized by running them as visitors 143 | source_ast.visit(mirror_transformer) 144 | 145 | # Transformer provide a minimal AST edit 146 | mirror_transformer.edit() 147 | # Output: 148 | # module [2, 0] - [5, 0] 149 | # function_definition [2, 0] - [3, 22] 150 | # block [3, 4] - [3, 22] 151 | # return_statement [3, 4] - [3, 22] 152 | # binary_operator -> FormattedUpdate 
[3, 11] - [3, 22] 153 | # binary_operator -> FormattedUpdate [3, 11] - [3, 16] 154 | 155 | # And it can be used to directly transform the code 156 | mirror_transformer.code() 157 | # Output: 158 | # def f(x, y): 159 | # return 0.5 + y + x 160 | 161 | ``` 162 | 163 | ## Project Info 164 | The goal of this project is to provide developer in the 165 | programming language processing community with easy 166 | access to syntax parsing. This is currently developed as a helper library for internal research projects. Therefore, it will only be updated 167 | as needed. 168 | 169 | Feel free to open an issue if anything unexpected 170 | happens. 171 | 172 | Distributed under the MIT license. See ``LICENSE`` for more information. 173 | 174 | We thank the developer of [tree-sitter](https://tree-sitter.github.io/tree-sitter/) library. Without tree-sitter this project would not be possible. 175 | -------------------------------------------------------------------------------- /code_ast/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import logging as logger 3 | 4 | from .config import ParserConfig 5 | 6 | from .ast import SourceCodeAST 7 | 8 | from .parsers import ( 9 | ASTParser, 10 | match_span 11 | ) 12 | 13 | from .visitor import ( 14 | ASTVisitor, 15 | VisitorComposition 16 | ) 17 | 18 | from .transformer import ( 19 | ASTTransformer, 20 | FormattedUpdate, 21 | TextUpdate, 22 | NodeUpdate 23 | ) 24 | 25 | 26 | # Main function -------------------------------- 27 | 28 | def ast(source_code, lang = "guess", **kwargs): 29 | """ 30 | Parses the AST of source code of most programming languages quickly. 31 | 32 | Parameters 33 | ---------- 34 | source_code : str 35 | Source code to parsed as a string. Also 36 | supports parsing of incomplete source code 37 | snippets (by deactivating the syntax checker; see syntax_error) 38 | 39 | lang : [python, java, javascript, ...] 
40 | String identifier of the programming language 41 | to be parsed. Supported are most programming languages 42 | including python, java and javascript (see README) 43 | Default: guess (Guesses language / Not supported currently throws error currently) 44 | 45 | syntax_error : [raise, warn, ignore] 46 | Reaction to syntax error in code snippet. 47 | raise: raises a Syntax Error 48 | warn: prints a warning to console 49 | ignore: Ignores syntax errors. Helpful for parsing code snippets. 50 | Default: raise 51 | 52 | Returns 53 | ------- 54 | Root 55 | root of AST tree as parsed by tree-sitter 56 | 57 | """ 58 | 59 | if len(source_code.strip()) == 0: raise ValueError("The code string is empty. Cannot tokenize anything empty: %s" % source_code) 60 | 61 | # If lang == guess, automatically determine the language 62 | if lang == "guess": lang = _lang_detect(source_code) 63 | 64 | logger.debug("Parses source code with parser for %s" % lang) 65 | 66 | # Setup config 67 | config = ParserConfig(lang, **kwargs) 68 | 69 | # Parse source tree 70 | parser = ASTParser(config.lang) 71 | tree, code = parser.parse(source_code) 72 | 73 | # Check for errors if necessary 74 | check_tree_for_errors(tree, mode = config.syntax_error) 75 | 76 | return SourceCodeAST(config, tree, code) 77 | 78 | 79 | # Lang detect -------------------------------------- 80 | 81 | 82 | def _lang_detect(source_code): 83 | """Guesses the source code type using pygments""" 84 | raise NotImplementedError( 85 | "Guessing the language automatically is currently not implemented. 
Please specify a language with the lang keyword\n code_tokenize.tokenize(code, lang = your_lang)" 86 | ) 87 | 88 | # Detect error -------------------------------- 89 | 90 | class ErrorVisitor(ASTVisitor): 91 | 92 | def __init__(self, error_mode): 93 | self.error_mode = error_mode 94 | 95 | def visit_ERROR(self, node): 96 | 97 | if self.error_mode == "raise": 98 | raise_syntax_error(node) 99 | return 100 | 101 | if self.error_mode == "warn": 102 | warn_syntax_error(node) 103 | return 104 | 105 | 106 | def check_tree_for_errors(tree, mode = "raise"): 107 | if mode == "ignore": return 108 | 109 | # Check for errors 110 | ErrorVisitor(mode)(tree) 111 | 112 | 113 | # Error handling ----------------------------------------------------------- 114 | 115 | def _construct_error_msg(node): 116 | 117 | start_line, start_char = node.start_point 118 | end_line, end_char = node.end_point 119 | 120 | position = "?" 121 | if start_line == end_line: 122 | position = "in line %d [pos. %d - %d]" % (start_line, start_char, end_char) 123 | else: 124 | position = "inbetween line %d (start: %d) to line %d (end: %d)" % (start_line, start_char, end_line, end_char) 125 | 126 | return "Problem while parsing given code snipet. 
Error occured %s" % position 127 | 128 | 129 | def warn_syntax_error(node): 130 | logger.warn(_construct_error_msg(node)) 131 | 132 | 133 | def raise_syntax_error(node): 134 | raise SyntaxError(_construct_error_msg(node)) -------------------------------------------------------------------------------- /code_ast/ast.py: -------------------------------------------------------------------------------- 1 | from .parsers import match_span 2 | 3 | class SourceCodeAST: 4 | 5 | def __init__(self, config, source_tree, source_lines): 6 | self.config = config 7 | self.source_tree = source_tree 8 | self.source_lines = source_lines 9 | 10 | def root_node(self): 11 | return self.source_tree.root_node 12 | 13 | def match(self, source_node): 14 | return match_span(source_node, self.source_lines) 15 | 16 | # Visit tree ---------------------------------------------------------------- 17 | 18 | def visit(self, visitor): 19 | 20 | try: 21 | visitor.from_code_lines(self.source_lines) 22 | except AttributeError: 23 | # Is not a transformer 24 | pass 25 | 26 | visitor(self.source_tree) 27 | 28 | # Repr ---------------------------------------------------------------- 29 | 30 | def code(self): 31 | return "\n".join(self.source_lines) 32 | 33 | def __repr__(self): 34 | 35 | lang = self.config.lang 36 | lang_name = "".join((lang_part[0].upper() + lang_part[1:] for lang_part in lang.split("-"))) 37 | 38 | ast_repr = ast_to_str(self.source_tree, indent = 1) 39 | 40 | return f"{lang_name}CodeAST [0, 0] - [{len(self.source_lines)}, {len(self.source_lines[-1])}]\n{ast_repr}" 41 | 42 | 43 | # AST to readable ---------------------------------------------------------------- 44 | 45 | LEAVE_WHITELIST = {"identifier", "integer", "float"} 46 | 47 | def _serialize_node(node): 48 | return f"{node.type} [{node.start_point[0]}, {node.start_point[1]}] - [{node.end_point[0]}, {node.end_point[1]}]" 49 | 50 | def ast_to_str(tree, indent = 0): 51 | ast_lines = [] 52 | root_node = tree.root_node 53 | cursor = 
root_node.walk() 54 | 55 | has_next = True 56 | 57 | while has_next: 58 | current_node = cursor.node 59 | 60 | if current_node.child_count > 0 or current_node.type in LEAVE_WHITELIST: 61 | ast_lines.append(" "*indent + _serialize_node(current_node)) 62 | 63 | # Step 1: Try to go to next child if we continue the subtree 64 | if cursor.goto_first_child(): 65 | indent += 1 66 | has_next = True 67 | else: 68 | has_next = False 69 | 70 | # Step 2: Try to go to next sibling 71 | if not has_next: 72 | has_next = cursor.goto_next_sibling() 73 | 74 | # Step 3: Go up until sibling exists 75 | while not has_next and cursor.goto_parent(): 76 | indent -= 1 77 | has_next = cursor.goto_next_sibling() 78 | 79 | return "\n".join(ast_lines) -------------------------------------------------------------------------------- /code_ast/config.py: -------------------------------------------------------------------------------- 1 | 2 | class ParserConfig: 3 | """Helper object to translate arguments of ast to config object""" 4 | 5 | def __init__(self, lang, **kwargs): 6 | self.lang = lang 7 | self.syntax_error = "raise" # Options: raise, warn, ignore 8 | 9 | # A list of all statement node defined in the language 10 | self.statement_types = [ 11 | "*_statement", "*_definition", "*_declaration" 12 | ] 13 | 14 | self.update(kwargs) 15 | 16 | 17 | def update(self, kwargs): 18 | for k, v in kwargs.items(): 19 | 20 | if k not in self.__dict__: 21 | raise TypeError("TypeError: tokenize() got an unexpected keyword argument '%s'" % k) 22 | 23 | self.__dict__[k] = v 24 | 25 | def __repr__(self): 26 | 27 | elements = [] 28 | for k, v in self.__dict__.items(): 29 | if v is not None: 30 | elements.append("%s=%s" % (k, v)) 31 | 32 | return "Config(%s)" % ", ".join(elements) 33 | 34 | -------------------------------------------------------------------------------- /code_ast/parsers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lean wrapper around the 
tree-sitter Python API 3 | 4 | Main features: 5 | - Parses arbitrary code as string and bytes 6 | - Autoloading / Compiling of AST parsers 7 | 8 | """ 9 | import os 10 | from tree_sitter import Language, Parser 11 | 12 | import logging as logger 13 | 14 | # For autoloading 15 | import requests 16 | from git import Repo 17 | 18 | try: 19 | from tree_sitter_languages import get_language, get_parser 20 | except ImportError: 21 | get_language, get_parser = None, None 22 | 23 | 24 | # Automatic loading of Tree-Sitter parsers -------------------------------- 25 | 26 | def load_language(lang): 27 | """ 28 | Loads a language specification object necessary for tree-sitter. 29 | 30 | Language specifications are loaded from remote or a local cache. 31 | If language specification is not contained in cache, the function 32 | clones the respective git project and then builds the language specification 33 | via tree-sitter. 34 | We employ the same language identifier as tree-sitter and 35 | lang is translated to a remote repository 36 | (https://github.com/tree-sitter/tree-sitter-[lang]). 37 | 38 | Parameters 39 | ---------- 40 | lang : [python, java, javascript, ...] 41 | language identifier specific to tree-sitter. 42 | As soon as there is a repository with the same language identifier 43 | the language is supported by this function. 44 | 45 | Returns 46 | ------- 47 | Language 48 | language specification object 49 | 50 | """ 51 | 52 | if get_language is not None: 53 | try: 54 | return get_language(lang) 55 | except Exception as e: 56 | logger.exception("No pre-compiled language for %s exists. Start compiling." 
% lang) 57 | 58 | cache_path = _path_to_local() 59 | 60 | compiled_lang_path = os.path.join(cache_path, "%s-lang.so" % lang) 61 | source_lang_path = os.path.join(cache_path, "tree-sitter-%s" % lang) 62 | 63 | if os.path.isfile(compiled_lang_path): 64 | return Language(compiled_lang_path, _lang_to_fnname(lang)) 65 | 66 | if os.path.exists(source_lang_path) and os.path.isdir(source_lang_path): 67 | logger.warning("Compiling language for %s" % lang) 68 | _compile_lang(source_lang_path, compiled_lang_path) 69 | return load_language(lang) 70 | 71 | logger.warning("Autoloading AST parser for %s: Start download from Github." % lang) 72 | _clone_parse_def_from_github(lang, source_lang_path) 73 | return load_language(lang) 74 | 75 | # Parser --------------------------------------------------------------- 76 | 77 | class ASTParser: 78 | """ 79 | Wrapper for tree-sitter AST parser 80 | 81 | Supports autocompiling the language specification needed 82 | for parsing (see load_language) 83 | 84 | """ 85 | 86 | def __init__(self, lang): 87 | """ 88 | Autoload language specification and parser 89 | 90 | Parameters 91 | ---------- 92 | lang : [python, java, javascript, ...] 93 | Language identifier specific to tree-sitter. 
94 | Same as for load_language 95 | 96 | """ 97 | 98 | self.lang_id = lang 99 | self.lang = load_language(lang) 100 | 101 | if get_parser is not None: 102 | self.parser = get_parser(self.lang_id) 103 | else: 104 | self.parser = Parser() 105 | self.parser.set_language(self.lang) 106 | 107 | def parse_bytes(self, data): 108 | """ 109 | Parses source code as bytes into AST 110 | 111 | Parameters 112 | ---------- 113 | data : bytes 114 | Source code as a stream of bytes 115 | 116 | Returns 117 | ------- 118 | tree-sitter syntax tree 119 | 120 | """ 121 | return self.parser.parse(data) 122 | 123 | def parse(self, source_code): 124 | """ 125 | Parses source code into AST 126 | 127 | Parameters 128 | ---------- 129 | source_code : str 130 | Source code as a string 131 | 132 | Returns 133 | ------- 134 | tree-sitter syntax tree 135 | tree-sitter object representing the syntax tree 136 | 137 | source_lines 138 | a list of code lines for reference 139 | 140 | """ 141 | source_lines = source_code.splitlines() 142 | source_bytes = source_code.encode("utf-8") 143 | 144 | return self.parse_bytes(source_bytes), source_lines 145 | 146 | 147 | # Utils ------------------------------------------------ 148 | 149 | def match_span(source_tree, source_lines): 150 | """ 151 | Greps the source text represented by the given source tree from the original code 152 | 153 | Parameters 154 | ---------- 155 | source_tree : tree-sitter node object 156 | Root of the AST which should be used to match the code 157 | 158 | source_lines : list[str] 159 | Source code as a list of source lines 160 | 161 | Returns 162 | ------- 163 | str 164 | the source code that is represented by the given source tree 165 | 166 | """ 167 | 168 | start_line, start_char = source_tree.start_point 169 | end_line, end_char = source_tree.end_point 170 | 171 | assert start_line <= end_line 172 | assert start_line != end_line or start_char <= end_char 173 | 174 | source_area = source_lines[start_line:end_line + 1] 175 | 176 | 
if start_line == end_line: 177 | return source_area[0][start_char:end_char] 178 | else: 179 | source_area[0] = source_area[0][start_char:] 180 | source_area[-1] = source_area[-1][:end_char] 181 | return "\n".join(source_area) 182 | 183 | 184 | # Auto Load Languages -------------------------------------------------- 185 | 186 | PATH_TO_LOCALCACHE = None 187 | 188 | def _path_to_local(): 189 | global PATH_TO_LOCALCACHE 190 | 191 | if PATH_TO_LOCALCACHE is None: 192 | current_path = os.path.abspath(__file__) 193 | 194 | while os.path.basename(current_path) != "code_ast": 195 | current_path = os.path.dirname(current_path) 196 | 197 | current_path = os.path.dirname(current_path) # Top dir 198 | PATH_TO_LOCALCACHE = os.path.join(current_path, "build") 199 | 200 | return PATH_TO_LOCALCACHE 201 | 202 | 203 | def _compile_lang(source_path, compiled_path): 204 | logger.debug("Compile language from %s" % compiled_path) 205 | 206 | Language.build_library( 207 | compiled_path, 208 | [ 209 | source_path 210 | ] 211 | ) 212 | 213 | 214 | def _lang_to_fnname(lang): 215 | """ 216 | dash is not supported for function names. Therefore, 217 | we assume that dashes represented by underscores. 218 | """ 219 | return lang.replace("-", "_") 220 | 221 | 222 | # Auto Clone from Github -------------------------------- 223 | 224 | def _exists_url(url): 225 | req = requests.get(url) 226 | return req.status_code == 200 227 | 228 | 229 | def _clone_parse_def_from_github(lang, cache_path): 230 | 231 | # Start by testing whethe repository exists 232 | REPO_URL = "https://github.com/tree-sitter/tree-sitter-%s" % lang 233 | 234 | if not _exists_url(REPO_URL): 235 | raise ValueError("There is no parsing def for language %s available." 
% lang) 236 | 237 | logger.warning("Start cloning the parser definition from Github.") 238 | try: 239 | Repo.clone_from(REPO_URL, cache_path) 240 | except Exception: 241 | raise ValueError("To autoload a parsing definition, git needs to be installed on the system!") 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | -------------------------------------------------------------------------------- /code_ast/transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | from dataclasses import dataclass 3 | 4 | from .visitor import ASTVisitor 5 | from .parsers import match_span 6 | 7 | 8 | class ASTTransformer(ASTVisitor): 9 | 10 | def __init__(self): 11 | super().__init__() 12 | self.code_lines = None 13 | self._edit_trees = [] 14 | 15 | # Code processing functions -------------------------------- 16 | 17 | def from_code_lines(self, code_lines): 18 | self.code_lines = code_lines 19 | 20 | def code(self): 21 | return self.edit().apply(self.code_lines) 22 | 23 | def edit(self): 24 | assert len(self._edit_trees) == 1, "Something went wrong during parsing" 25 | return self._edit_trees[0] 26 | 27 | def on_leave(self, original_node): 28 | node_update = super().on_leave(original_node) 29 | 30 | if isinstance(node_update, str): 31 | node_update = TextUpdate(node_update) 32 | 33 | num_children = original_node.child_count 34 | child_trees = [self._edit_trees.pop(-1) for _ in range(num_children)][::-1] 35 | 36 | if node_update is None or not isinstance(node_update, EditUpdate): 37 | self._edit_trees.append(EditTree(original_node, None, child_trees)) 38 | return 39 | 40 | self._edit_trees.append(EditTree(original_node, node_update, child_trees)) 41 | 42 | 43 | # Minimal edit tree ------------------------------------------------------------ 44 | 45 | class EditTree: 46 | 47 | def __init__(self, source_node, target_edit = None, children = []): 48 | self.source_node = source_node 49 | self.target_edit = 
target_edit 50 | self.children = children 51 | 52 | for c in children: 53 | if c.target_edit is None: c.children = [] 54 | 55 | if target_edit is None and any(c.target_edit is not None for c in self.children): 56 | self.target_edit = SubtreeUpdate() 57 | 58 | def apply(self, code_lines): 59 | return EditExecutor(self, code_lines).walk() 60 | 61 | def __repr__(self): 62 | return "\n".join(_edit_to_str(self)) 63 | 64 | # Edit operations ---------------------------------------------------------------- 65 | 66 | @dataclass 67 | class EditUpdate: 68 | 69 | def compile(self, sub_edits = None, code_lines = None): 70 | return "" 71 | 72 | @property 73 | def type(self): 74 | return self.__class__.__name__ 75 | 76 | 77 | @dataclass 78 | class SubtreeUpdate(EditUpdate): 79 | pass 80 | 81 | 82 | @dataclass 83 | class TextUpdate(EditUpdate): 84 | text : str 85 | 86 | def compile(self, sub_edits = None, code_lines = None): 87 | return self.text 88 | 89 | @dataclass 90 | class NodeUpdate(EditUpdate): 91 | node : Any 92 | 93 | def compile(self, sub_edits = None, code_lines = None): 94 | return match_span(self.node, code_lines) 95 | 96 | 97 | @dataclass 98 | class TreeUpdate(EditUpdate): 99 | node : Any 100 | 101 | def compile(self, sub_edits = None, code_lines = None): 102 | 103 | for sub_edit in sub_edits: 104 | if sub_edit.target_edit is None: continue 105 | if sub_edit.source_node == self.node: 106 | return sub_edit.target_edit.compile( 107 | sub_edit.children, code_lines 108 | ) 109 | 110 | return match_span(self.node, code_lines) 111 | 112 | 113 | @dataclass 114 | class FormattedUpdate(EditUpdate): 115 | format_str : str 116 | args : List[EditUpdate] 117 | 118 | def compile(self, sub_edits = None, code_lines = None): 119 | args = tuple(arg.compile(sub_edits, code_lines) 120 | for arg in self.args) 121 | return self.format_str % args 122 | 123 | 124 | # Edit to str ---------------------------------------------------------------- 125 | 126 | def _serialize_tree(edit_tree): 127 
| source = edit_tree.source_node 128 | if edit_tree.target_edit.type == "SubtreeUpdate": 129 | return f"{source.type} [{source.start_point[0]}, {source.start_point[1]}] - [{source.end_point[0]}, {source.end_point[1]}]" 130 | 131 | return f"{source.type} -> {edit_tree.target_edit.type} [{source.start_point[0]}, {source.start_point[1]}] - [{source.end_point[0]}, {source.end_point[1]}]" 132 | 133 | 134 | def _edit_to_str(edit_tree, indent = 0):  # Render an edit tree as one indented line per edited node; a node whose target_edit is None is dropped along with its children 135 | str_lines = [] 136 | if edit_tree.target_edit is None: return [] 137 | 138 | str_lines.append( 139 | " " * indent + _serialize_tree(edit_tree) 140 | ) 141 | str_lines.extend([l for c in edit_tree.children for l in _edit_to_str(c, indent = indent + 1)])  # children are rendered one level deeper 142 | 143 | return str_lines 144 | 145 | 146 | 147 | # A simple edit executor -------------------------------------------------------- 148 | 149 | class EditExecutor:  # Replays an edit tree over the original code_lines; walk() returns the rewritten source 150 | 151 | def __init__(self, edit_tree, code_lines):  # code_lines: presumably the source rows without line terminators ("\n" is appended after each row the cursor crosses) -- TODO confirm 152 | self.code_lines = code_lines 153 | 154 | self._edit_stack = [edit_tree]  # LIFO work list of edit-tree nodes still to execute 155 | self._target_lines = []  # output fragments, concatenated by walk() 156 | 157 | # Cursors 158 | self._cursor = (0, 0)  # (row, column) in code_lines up to which input has been emitted or skipped 159 | self._delay_move = (0, 0)  # pending copy target (end of the last unchanged node), flushed lazily 160 | 161 | def _move_cursor(self, position):  # Copy the original text between the cursor and position into the output 162 | assert position >= self._cursor  # (row, column) tuples compare lexicographically 163 | 164 | while self._cursor[0] < position[0]: 165 | if self._cursor[1] == 0: 166 | self._target_lines.append(self.code_lines[self._cursor[0]]) 167 | else: 168 | add_part = self.code_lines[self._cursor[0]][self._cursor[1]:] 169 | self._target_lines.append(add_part) 170 | 171 | self._target_lines.append("\n") 172 | self._cursor = (self._cursor[0] + 1, 0) 173 | 174 | if self._cursor[1] < position[1]:  # remaining columns on the target row 175 | add_part = self.code_lines[self._cursor[0]][self._cursor[1]:position[1]] 176 | self._target_lines.append(add_part) 177 | self._cursor = (self._cursor[0], position[1]) 178 | 179 | def _delay_cursor(self, position):  # Record a copy target without emitting any text yet 180 | assert position >= self._cursor 181 | self._delay_move = position 182 | 183 | def _execute_noop(self, edit_tree):  # Unchanged node: its original text is copied when pending text is next flushed 184 | node = edit_tree.source_node 185 | node_end =
node.end_point 186 | self._delay_cursor(node_end) 187 | 188 | def _execute(self, edit_tree):  # Apply a single edit-tree node 189 | 190 | if edit_tree.target_edit is None: 191 | self._execute_noop(edit_tree) 192 | return 193 | 194 | if edit_tree.target_edit.type == "SubtreeUpdate": 195 | self._edit_stack.extend(edit_tree.children[::-1])  # reversed so that pop(-1) yields the children first-to-last 196 | return 197 | 198 | if self._delay_move >= self._cursor:  # flush pending unchanged text before the replacement 199 | self._move_cursor(self._delay_move) 200 | self._delay_move = self._cursor 201 | 202 | self._cursor = edit_tree.source_node.end_point  # skip the replaced node's original text. NOTE(review): input between the cursor and this node's start_point (e.g. whitespace before it) is skipped as well, so the replacement text has to re-emit it 203 | self._target_lines.append( 204 | edit_tree.target_edit.compile(  # replacement text; must be a str because walk() joins _target_lines 205 | edit_tree.children, 206 | self.code_lines 207 | ) 208 | ) 209 | 210 | def walk(self):  # Execute all edits and return the rewritten source as a single string 211 | 212 | while len(self._edit_stack) > 0: 213 | self._execute(self._edit_stack.pop(-1)) 214 | 215 | if self._delay_move >= self._cursor:  # final flush; no input past the last processed node is copied 216 | self._move_cursor(self._delay_move) 217 | self._delay_move = self._cursor 218 | 219 | return "".join(self._target_lines) -------------------------------------------------------------------------------- /code_ast/visitor.py: -------------------------------------------------------------------------------- 1 | """ 2 | AST Visitors 3 | 4 | Can be directly executed on AST structures 5 | """ 6 | 7 | class ASTVisitor:  # Dispatches each node to visit_<type>/leave_<type> methods, falling back to visit/leave 8 | 9 | # Decreased version of visit (no edges are supported) ----------------------- 10 | 11 | def visit(self, node): 12 | """ 13 | Default visitor function 14 | 15 | Override this to capture all nodes that are not covered by a specific visitor. 16 | """ 17 | 18 | def leave(self, node): 19 | """ 20 | Default leave function 21 | 22 | This is called when a subtree rooted at the given node is left. 23 | Override this to capture all nodes that are not covered by a specific leave function.
24 | """ 25 | 26 | # Internal methods --------------------------------------------------------- 27 | 28 | def on_visit(self, node):  # True unless the matching visit_<type> (or visit) returns exactly False 29 | visitor_fn = getattr(self, "visit_%s" % node.type, self.visit) 30 | return visitor_fn(node) is not False 31 | 32 | def on_leave(self, node):  # Call leave_<type> (or leave) and return its result 33 | leave_fn = getattr(self, "leave_%s" % node.type, self.leave) 34 | return leave_fn(node) 35 | 36 | # Navigation ---------------------------------------------------------------- 37 | 38 | def walk(self, root_node):  # Iterative pre-order traversal; if on_visit is False the node's children are skipped, but its leave is still called 39 | if root_node is None: return 40 | 41 | cursor = root_node.walk() 42 | has_next = True 43 | 44 | while has_next: 45 | current_node = cursor.node 46 | # Step 1: Try to go to next child if we continue the subtree 47 | if self.on_visit(current_node): 48 | has_next = cursor.goto_first_child() 49 | else: 50 | has_next = False 51 | 52 | # Step 2: Try to go to next sibling 53 | if not has_next: 54 | self.on_leave(current_node) 55 | has_next = cursor.goto_next_sibling() 56 | 57 | # Step 3: Go up until sibling exists 58 | while not has_next and cursor.goto_parent(): 59 | self.on_leave(cursor.node) # We will never return back to this specific parent 60 | has_next = cursor.goto_next_sibling() 61 | 62 | 63 | def __call__(self, root_node): 64 | return self.walk(root_node) 65 | 66 | 67 | # Compositions ---------------------------------------------------------------- 68 | 69 | class VisitorComposition(ASTVisitor):  # Runs several visitors in one traversal; a subtree is pruned as soon as any visitor declines it 70 | 71 | def __init__(self, *visitors): 72 | super().__init__() 73 | self.visitors = visitors 74 | 75 | def on_visit(self, node): 76 | for base_visitor in self.visitors: 77 | if base_visitor.on_visit(node) is False: return False  # later visitors are not visited for this node 78 | return True 79 | 80 | def on_leave(self, node):  # NOTE(review): leave reaches every visitor, including ones skipped by the short-circuit in on_visit 81 | for base_visitor in self.visitors: 82 | base_visitor.on_leave(node) 83 | 84 | def __repr__(self): 85 | return str(self.visitors) 86 | 87 | 88 | class ResumingVisitorComposition(ASTVisitor): 89 | """ 90 | Unlike a standard composition, visitors 91 | are resumed even if one visitor stops for a
branch. 92 | 93 | This class should be equivalent to running N visitors 94 | in sequence. 95 | """ 96 | 97 | def __init__(self, *visitors): 98 | super().__init__() 99 | self.visitors = visitors 100 | 101 | self.__active_visitors = [True] * len(visitors)  # False while that visitor skips a subtree it declined 102 | self.__resume_on = {}  # visitor index -> node whose leave event re-activates that visitor 103 | 104 | def on_visit(self, node): 105 | for pos, base_visitor in enumerate(self.visitors): 106 | if not self.__active_visitors[pos]: continue 107 | 108 | if base_visitor.on_visit(node) is False: 109 | self.__active_visitors[pos] = False 110 | self.__resume_on[pos] = node 111 | 112 | return any(self.__active_visitors)  # descend only while at least one visitor still wants this subtree 113 | 114 | 115 | def on_leave(self, node): 116 | for pos, base_visitor in enumerate(self.visitors): 117 | if not self.__active_visitors[pos]: 118 | resume_node = self.__resume_on[pos] 119 | if resume_node == node:  # leaving the declined node itself: resume and deliver its leave event 120 | self.__active_visitors[pos] = True 121 | else: 122 | continue 123 | 124 | base_visitor.on_leave(node) 125 | 126 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "code_ast" 7 | version = "0.1.1" 8 | description = "Fast structural analysis of any programming language in Python" 9 | readme = "README.md" 10 | requires-python = ">= 3.8" 11 | license = { file = "LICENSE" } 12 | keywords = ["code", "ast", "cst", "syntax", "program", "language processing"] 13 | 14 | authors = [{name = "Cedric Richter", email = "cedricr.upb@gmail.com"}] 15 | maintainers = [{name = "Cedric Richter", email = "cedricr.upb@gmail.com"}] 16 | 17 | classifiers = [ 18 | "Development Status :: 3 - Alpha", 19 | "Intended Audience :: Developers", 20 | "Topic :: Software Development :: Build Tools", 21 | "License :: OSI Approved :: MIT License", 22 | "Programming Language :: Python :: 3", 23 | # "Programming Language :: Python :: 3.6" dropped: requires-python is >= 3.8 24 |
# "Programming Language :: Python :: 3.7" dropped: requires-python is >= 3.8 25 | "Programming Language :: Python :: 3.8", 26 | "Programming Language :: Python :: 3.9", 27 | "Programming Language :: Python :: 3.10", 28 | "Programming Language :: Python :: 3.11", 29 | "Programming Language :: Python :: 3.12", 30 | "Programming Language :: Python :: 3.13", 31 | "Programming Language :: Python :: 3 :: Only", 32 | ] 33 | 34 | dependencies = ["tree_sitter==0.21.3", "GitPython>=3.1.41", "requests>=2.32.0"]  # same pins as requirements.txt and setup.py's install_requires 35 | 36 | [project.urls] 37 | "Homepage" = "https://github.com/cedricrupb/code_ast" 38 | "Bug Reports" = "https://github.com/cedricrupb/code_ast/issues" 39 | "Source" = "https://github.com/cedricrupb/code_ast" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tree_sitter==0.21.3 2 | requests>=2.32.0 3 | GitPython>=3.1.41 -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | long_description = file: README.md 3 | long_description_content_type = text/markdown -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open("README.md", "r", encoding="utf-8") as f: 4 | long_description = f.read() 5 | 6 | setup( 7 | name = 'code_ast', 8 | packages = ['code_ast'], 9 | version = '0.1.1', 10 | license='MIT', 11 | description = 'Fast structural analysis of any programming language in Python', 12 | long_description = long_description, 13 | long_description_content_type="text/markdown", 14 | author = 'Cedric Richter', 15 | author_email = 'cedricr.upb@gmail.com', 16 | url = 'https://github.com/cedricrupb/code_ast', 17 | download_url = 'https://github.com/cedricrupb/code_ast/archive/refs/tags/v0.1.0.tar.gz',  # NOTE(review): points at the v0.1.0 tag while version is 0.1.1 -- confirm the intended release 18 | keywords = ['code', 'ast',
'syntax', 'program', 'language processing'], 19 | install_requires=[ 20 | 'tree_sitter==0.21.3', 21 | 'GitPython>=3.1.41', 22 | 'requests>=2.32.0', 23 | ], 24 | # `extra_requires=[...]` removed: setuptools has no such option (the keyword is 25 | # `extras_require`, a dict) and GitPython is already listed in install_requires 26 | 27 | classifiers=[ 28 | 'Development Status :: 3 - Alpha', 29 | 'Intended Audience :: Developers', 30 | 'Topic :: Software Development :: Build Tools', 31 | 'License :: OSI Approved :: MIT License', 32 | 'Programming Language :: Python :: 3', 33 | # Python 3.6 / 3.7 classifiers dropped: pyproject.toml declares requires-python >= 3.8 34 | 35 | 'Programming Language :: Python :: 3.8', 36 | 'Programming Language :: Python :: 3.9', 37 | 'Programming Language :: Python :: 3.10', 38 | 'Programming Language :: Python :: 3.11', 39 | 'Programming Language :: Python :: 3.12', 40 | 'Programming Language :: Python :: 3.13', 41 | ], 42 | ) -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedricrupb/code_ast/582bfe2125a40fa2dc0eb182170d3f3ab3a576ea/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_code_ast.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from code_ast import ast, ASTParser, ASTVisitor, ASTTransformer 4 | 5 | from code_ast.transformer import FormattedUpdate, TreeUpdate 6 | 7 | # Prepare parser 8 | ASTParser("python") # Bootstrap parser for all test runs 9 | 10 | # Test the general language parsers ---------------------------------------------------------------- 11 | 12 | class TestPythonParser(TestCase): 13 | 14 | def test_ast_fn(self): 15 | code_ast = ast("def foo():\n bar()", lang = "python") 16 | 17 | current_node = code_ast.root_node() 18 | self.assertEqual(current_node.type, "module") 19 | 20 | self.assertEqual(current_node.child_count, 1) 21 |
current_node = current_node.children[0] 22 | self.assertEqual(current_node.type, "function_definition") 23 | 24 | self.assertEqual(current_node.child_count, 5)  # children[3], not checked below, is presumably the ":" token 25 | self.assertEqual(current_node.children[0].type, "def") 26 | self.assertEqual(current_node.children[1].type, "identifier") 27 | self.assertEqual(current_node.children[2].type, "parameters") 28 | self.assertEqual(current_node.children[4].type, "block") 29 | 30 | current_node = current_node.children[4] 31 | self.assertEqual(current_node.child_count, 1) 32 | current_node = current_node.children[0] 33 | self.assertEqual(current_node.type, "expression_statement") 34 | 35 | self.assertEqual(current_node.child_count, 1) 36 | current_node = current_node.children[0] 37 | self.assertEqual(current_node.type, "call") 38 | 39 | def test_match_fn(self): 40 | code_ast = ast("def foo():\n bar()", lang = "python") 41 | 42 | current_node = code_ast.root_node() 43 | current_node = current_node.children[0] 44 | 45 | self.assertEqual(current_node.child_count, 5) 46 | self.assertEqual(current_node.children[1].type, "identifier") 47 | self.assertEqual( 48 | code_ast.match(current_node.children[1]), "foo" 49 | ) 50 | 51 | current_node = current_node.children[4] 52 | current_node = current_node.children[0] 53 | current_node = current_node.children[0] 54 | current_node = current_node.children[0]  # block -> expression_statement -> call -> identifier 55 | self.assertEqual(current_node.type, "identifier") 56 | 57 | self.assertEqual( 58 | code_ast.match(current_node), "bar" 59 | ) 60 | 61 | 62 | # Test visitors ------------------------------------------------------------------------------------ 63 | 64 | 65 | class TestVisitor(TestCase): 66 | 67 | def test_count_identifier(self): 68 | code_ast = ast("def foo():\n bar()", lang = "python") 69 | 70 | class IdCounter(ASTVisitor): 71 | 72 | def __init__(self): 73 | self.count = 0 74 | 75 | def visit_identifier(self, node):  # invoked by ASTVisitor.on_visit for every node of type "identifier" 76 | self.count += 1 77 | 78 | counter = IdCounter() 79 | code_ast.visit(counter) 80 | self.assertEqual(counter.count,
2)  # identifiers: foo, bar 81 | 82 | def test_count_identifier2(self): 83 | code_ast = ast("def foo(x, y):\n return x + y", lang = "python") 84 | 85 | class IdCounter(ASTVisitor): 86 | 87 | def __init__(self): 88 | self.count = 0 89 | 90 | def visit_identifier(self, node): 91 | self.count += 1 92 | 93 | counter = IdCounter() 94 | code_ast.visit(counter) 95 | self.assertEqual(counter.count, 5)  # foo, x, y in the signature plus x, y in the return expression 96 | 97 | 98 | # Test transforms ---------------------------------------------------------------------------------- 99 | 100 | class TestTransformer(TestCase): 101 | 102 | def test_transform_add(self): 103 | code_ast = ast("def foo(x, y):\n return x + y", lang = "python") 104 | 105 | class MirrorAddTransformer(ASTTransformer): 106 | def leave_binary_operator(self, node): 107 | if node.children[1].type == "+": 108 | return FormattedUpdate( 109 | " %s + %s",  # leading space presumably re-emits the gap after "return", which EditExecutor does not copy in front of a replaced node 110 | [ 111 | TreeUpdate(node.children[2]),  # right operand first... 112 | TreeUpdate(node.children[0])  # ...then the left one: mirrors x + y into y + x 113 | ] 114 | ) 115 | 116 | mirror_transformer = MirrorAddTransformer() 117 | code_ast.visit(mirror_transformer) 118 | 119 | source_edit = mirror_transformer.edit() 120 | 121 | current_node = source_edit.children[0] 122 | self.assertEqual(current_node.source_node.type, "function_definition") 123 | current_node = current_node.children[-1] 124 | self.assertEqual(current_node.source_node.type, "block") 125 | current_node = current_node.children[0] 126 | self.assertEqual(current_node.source_node.type, "return_statement") 127 | current_node = current_node.children[-1] 128 | self.assertTrue(current_node.target_edit is not None) 129 | 130 | 131 | transformed_code = mirror_transformer.code() 132 | self.assertEqual(transformed_code, "def foo(x, y):\n return y + x") 133 | 134 | 135 | 136 | --------------------------------------------------------------------------------