├── .gitignore ├── tests ├── invalid.py └── simpleone.epy ├── epython ├── __init__.py ├── blackscholes.py ├── validate.py ├── array.epy ├── importer.py ├── cython_backend.py ├── epython.py └── unparse.py ├── conda └── dev.yaml ├── development ├── environment.yaml └── Dockerfile ├── .github └── workflows │ └── main.yaml ├── setup.py ├── cython-ref ├── LICENSE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.egg-info/ 3 | -------------------------------------------------------------------------------- /tests/invalid.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def func(a : int) -> float: 4 | b : int = 3 5 | yield a * b* 1.0 6 | 7 | 8 | -------------------------------------------------------------------------------- /tests/simpleone.epy: -------------------------------------------------------------------------------- 1 | 2 | 3 | def func(a : int) -> float: 4 | b : int = 3 5 | return a * b* 1.0 6 | 7 | 8 | -------------------------------------------------------------------------------- /epython/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | 3 | import epython.importer 4 | 5 | from .epython import register_func 6 | -------------------------------------------------------------------------------- /conda/dev.yaml: -------------------------------------------------------------------------------- 1 | name: epython 2 | channels: 3 | - conda-forge 4 | - nodefaults 5 | dependencies: 6 | - python 3.9.* 7 | -------------------------------------------------------------------------------- /development/environment.yaml: -------------------------------------------------------------------------------- 1 | name: epython-dev 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python==3.9 6 | - pydantic 7 | - pip 8 | - pyyaml -------------------------------------------------------------------------------- /epython/blackscholes.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | _sqrt2 = math.sqrt(2) 5 | 6 | def cdf(z, mu=0.0, sigma=1.0): 7 | root2_sigma = sigma * _sqrt2 8 | return (math.erf((z-mu)/root2_sigma)+1)/2.0 9 | 10 | def callPrice(s, x, r, sigma, t): 11 | a = ((math.log(s/x) + (r + sigma * sigma/2.0) * t) / 12 | (sigma * math.sqrt(t))) 13 | b = a - sigma * math.sqrt(t) 14 | return s * cdf(a) - x * math.exp(-r * t) * cdf(b) 15 | -------------------------------------------------------------------------------- /epython/validate.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | 4 | disallowed_nodes = [ 5 | ast.AsyncFor, 6 | ast.AsyncFunctionDef, 7 | ast.AsyncWith, 8 | ast.Delete, 9 | ast.Raise, 10 | ast.Try, 11 | ast.GeneratorExp, 12 | ast.Await, 13 | ast.Yield, 14 | ast.YieldFrom, 15 | ast.Del, 16 | ast.ExceptHandler, 17 | ast.Starred, 18 | ast.With, 19 | ast.withitem, 20 | ast.Interactive, 21 | ] 22 | 23 | 24 | def validate(code): 25 | for node in ast.walk(code): 26 | if node.__class__ in disallowed_nodes: 27 | info = f"Invalid node {node.__class__}" 28 | if hasattr(node, "lineno"): 29 | info += f" at line {node.lineno}" 30 | return ValueError, info 31 | return None 32 | --------------------------------------------------------------------------------
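Note (illustrative; not a file in the repository): validate() reports a problem by returning an (exception_class, message) tuple instead of raising. A minimal sketch of how it is consumed, mirroring the logic in epython/epython.py:

```python
import ast
from epython.validate import validate

# Parse a source file and reject disallowed constructs (yield, try, with, ...).
with open("tests/invalid.py") as fi:
    tree = ast.parse(fi.read(), "invalid", "exec", type_comments=True)

result = validate(tree)
if result is not None:
    exc_class, message = result   # e.g. (ValueError, "Invalid node ... at line 5")
    raise exc_class(message)
```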
/.github/workflows/main.yaml: -------------------------------------------------------------------------------- 1 | name: epython 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | main: 11 | 12 | runs-on: ubuntu-latest 13 | timeout-minutes: 35 14 | defaults: 15 | run: 16 | shell: bash -l {0} 17 | concurrency: 18 | group: ci-${{ github.ref }} 19 | cancel-in-progress: true 20 | 21 | steps: 22 | - uses: actions/checkout@v2 23 | 24 | - uses: conda-incubator/setup-miniconda@v2 25 | with: 26 | miniconda-version: "latest" 27 | mamba-version: "*" 28 | environment-file: conda/dev.yaml 29 | channels: conda-forge,nodefaults 30 | activate-environment: epython 31 | use-mamba: true 32 | miniforge-variant: Mambaforge 33 | 34 | - name: installation 35 | run: | 36 | pip install . 37 | epython tests/simpleone.epy --backend=cpython 38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | # read the contents of your README file 4 | from os import path 5 | this_directory = path.abspath(path.dirname(__file__)) 6 | with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: 7 | long_description = f.read() 8 | 9 | setup( 10 | version='0.0.5', 11 | name='epython', 12 | url='https://github.com/epython-dev/epython', 13 | description='A typed subset of Python to be used as an extension language', 14 | long_description=long_description, 15 | long_description_content_type='text/markdown', 16 | author='Quansight Labs', 17 | author_email='labs@quansight.com', 18 | license='BSD-3', 19 | packages=['epython'], 20 | entry_points={ 21 | 'console_scripts': [ 22 | 'epython = epython.epython:main', 23 | ], 24 | }, 25 | zip_safe=False, 26 | classifiers=[ 27 | "Programming Language :: Python :: 3", 28 | "License :: OSI Approved :: BSD License", 29 | "Operating System :: OS Independent", 30 | ], 31 | python_requires='>=3.9', 32 | ) 33 | -------------------------------------------------------------------------------- /development/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pyodide/pyodide:0.20.0 2 | 3 | # Switch shell to bash 4 | SHELL ["/bin/bash", "-c"] 5 | 6 | # Set up epython-dev 7 | RUN mkdir -p /epython/ &&\ 8 | mkdir -p /epython/epython/ 9 | 10 | # Copy Development Files 11 | COPY ./development/environment.yaml . 12 | COPY ./epython/* /epython/epython/ 13 | COPY setup.py /epython/ 14 | COPY README.md /epython/ 15 | COPY tests /epython/tests/ 16 | 17 | # Add Conda to PATH 18 | ENV PATH /opt/conda/bin:$PATH 19 | 20 | # Install Miniconda 21 | RUN wget --quiet "https://repo.anaconda.com/miniconda/Miniconda3-py39_4.11.0-Linux-x86_64.sh" -O ~/miniconda.sh &&\ 22 | /bin/bash ~/miniconda.sh -q -b -p /opt/conda &&\ 23 | conda init &&\ 24 | conda env update --quiet -f ./environment.yaml &&\ 25 | echo "conda activate epython-dev" >> /root/.bashrc 26 | 27 | ENV PATH /opt/conda/envs/epython-dev/bin:$PATH 28 | 29 | SHELL ["conda", "run", "-n", "epython-dev", "/bin/bash", "-c"] 30 | 31 | RUN cd /epython/ && pip install .
&&\ 32 | cd /src/pyodide/ && pip install -e pyodide-build 33 | 34 | ENV PYTHONPATH /src/pyodide/pyodide-build/:$PYTHONPATH 35 | 36 | RUN cd ./pyodide/ && make 37 | 38 | CMD ["python", "-m", "pyodide_build", "serve"] 39 | -------------------------------------------------------------------------------- /cython-ref: -------------------------------------------------------------------------------- 1 | * cimport 2 | * from cimport 3 | * cdef extern from 4 | name (type) 5 | name (type, type) 6 | 7 | In EPython these should all just be import statements (no cimport or cdef required). 8 | EPython files will use the .epy extension to ensure they are compiled first for the 9 | targeted run-time. 10 | 11 | * absolute import to access variables in Python name-space 12 | 13 | There will be a python module available so that import python returns a namespace with 14 | the variables available in sys, etc. 15 | 16 | * cdef 17 | 18 | EPython: this uses standard type annotation 19 | 20 | <name> : <type> 21 | 22 | * @cython.internal 23 | * @cython.final 24 | 25 | EPython: change "cython" to "epython" 26 | 27 | @epython.internal and @epython.final 28 | 29 | * cdef class 30 | 31 | This is just a standard class 32 | 33 | * cdef public 34 | 35 | * cpdef name (types, ...) 36 | 37 | EPython: By default, classes, functions, and types are public and have both run-time-specific and Python implementations. 38 | Use @epython.internal to make certain functions not have external visibility. 39 | 40 | * cpdef <type> <name>(<type>, ...) except? <value>: 41 | 42 | EPython: Exception handling will be run-time specific and this is not supported yet. 43 | --------------------------------------------------------------------------------
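Note (illustrative; not a file in the repository): the mapping above, applied to the small test function in tests/simpleone.epy. The EPython form uses only standard annotations; the cdef form is roughly what epython/cython_backend.py is set up to emit — treat it as a sketch of the intent, not guaranteed output.

```python
# EPython (.epy) source: plain Python 3 annotations, no cimport or cdef.
def func(a: int) -> float:
    b: int = 3
    return a * b * 1.0

# Roughly the Cython spelling the CythonGenerator targets for the same function:
#
#     cdef float func(int a):
#         cdef int b = 3
#         return a * b * 1.0
```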
/epython/array.epy: -------------------------------------------------------------------------------- 1 | # Assume Tensors and DataFrames are builtin to the system (i.e. they produce code 2 | # based on backend selections). 3 | 4 | # A Tensor is a container of a Type. Zero-dimensional tensors exist 5 | # and are analogous to elements of the Type. 6 | # You must index the 0-d array to produce the element. 7 | 8 | # Elements of the Tensor must be explicitly extracted using the syntax a[()] 9 | # where a evaluates to a 0-d Tensor. This always copies the element into a new 10 | # object. 11 | 12 | # A Tensor has attributes: shape, ndim, dtype, T, kind 13 | 14 | # A Tensor kind is a meta-type which is one of the Types: 15 | # General > Gamma > Strided > C-contiguous > F-contiguous 16 | # General > Chunked 17 | 18 | # Based on the kind of Tensor we have private fields that are present or not 19 | # GeneralMeta: 20 | # index_map : a function of ndim arguments and a state Object that produces 21 | # a 0-d tensor. 22 | # General: 23 | # index_map : (a function of ndim integer arguments that produces a 0-d tensor) 24 | # Gamma adds: 25 | # gamma_map : (a function of ndim integer arguments that produces an integer) 26 | # datapointer : a 1-d array of bytes 27 | # Strided adds: 28 | # strides : a Tuple of integers indicating the number of bytes to jump to get to the next element along each dimension 29 | # C-Contiguous: 30 | # no additional information; the other attribute functions are computed as needed. 31 | # F-Contiguous: 32 | # no additional information; the other attributes are computed as needed. 33 | 34 | # Also builtin to the system are "generalized universal functions" 35 | 36 | # These operate on --------------------------------------------------------------------------------
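Note (illustrative; not a file in the repository): the kind hierarchy in array.epy is easiest to see through the strides arithmetic. For a C-contiguous Tensor, strides and gamma_map need not be stored because they can be computed from shape and itemsize. A plain-Python sketch of that computation follows; the helper names here are hypothetical, not EPython API.

```python
from math import prod

def c_contiguous_strides(shape, itemsize):
    """Bytes to jump to reach the next element along each dimension, C order."""
    return tuple(prod(shape[i + 1:]) * itemsize for i in range(len(shape)))

def gamma_map(index, strides):
    """Flat byte offset of a multi-dimensional index into the 1-d byte buffer."""
    return sum(i * s for i, s in zip(index, strides))

# A 2 x 3 tensor of 8-byte elements:
assert c_contiguous_strides((2, 3), 8) == (24, 8)
assert gamma_map((1, 2), (24, 8)) == 40
```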
/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, extpython 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # epython 2 | 3 | EPython is a typed subset of the Python language, useful for extending the language with new builtin types and methods. 4 | 5 | The goal is to be able to write things like NumPy, SciPy, Pandas, bitarray, and any other extension module of Python in this language and get equivalent or better performance than writing it using the C-API typically provided. 6 | 7 | This project is in development and extremely alpha. You should not use this for anything. 8 | 9 | Learn more in this talk [Travis Oliphant gave at PyData Austin 2019](https://www.youtube.com/watch?v=Z8vsTxzmorE). 10 | 11 | 12 | If you are interested in contributing to the design and goals, then join the discussion at [OpenTeams Slack](https://openteams.com/projects/epython) 13 | 14 | 15 | # Installation 16 | 17 | ```bash 18 | pip install epython 19 | ``` 20 | 21 | # Usage 22 | 23 | ```bash 24 | epython extmodule.epy --backend=cpython 25 | ``` 26 | 27 | Produces a compiled extension module for the given Python backend. 28 | 29 | ## Docker Development 30 | 31 | 32 | Install Docker, then run the following from the root of the repository: 33 | 34 | `docker build -t epython-wasm -f ./development/Dockerfile .` 35 | 36 | 37 | 38 | To run the interactive session: 39 | 40 | `docker run -p 8008:8000 -t epython-wasm:latest` 41 | 42 | # Development 43 | 44 | Create an environment for **epython**: 45 | 46 | ```bash 47 | $ conda env create --file conda/dev.yaml 48 | ``` 49 | 50 | Activate the **epython** environment: 51 | 52 | ```bash 53 | $ conda activate epython 54 | ``` 55 | 56 | Install it locally in development mode: 57 | 58 | ```bash 59 | $ pip install -e . 60 | ``` 61 | -------------------------------------------------------------------------------- /epython/importer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Importing this module allows importing .epy files like .py files. 3 | 4 | The main reason for naming epython files with the .epy file extension is to 5 | avoid confusion with regular Python modules. A package may contain a number 6 | of (sub-) modules of which only some are epython extensions. 7 | 8 | For development of epython packages, it is nevertheless very useful to import 9 | .epy files just like .py files, which is possible by simply importing epython 10 | first. E.g. 11 | 12 | import epython 13 | import myext # will import myext.epy 14 | 15 | Without importing epython first, myext will not work, which helps to avoid 16 | using epython extensions as pure Python modules (which would be quite slow). 17 | """ 18 | import sys 19 | import imp 20 | from os.path import isfile, join 21 | 22 | 23 | class EPY_Importer(object): 24 | 25 | def find_module(self, fullname, path=None): 26 | name = fullname.rsplit('.', 1)[-1] 27 | for dir_path in path or sys.path: 28 | self.path = join(dir_path, name + '.epy') 29 | if isfile(self.path): 30 | self.modtype = imp.PY_SOURCE 31 | return self 32 | return None 33 | 34 | def load_module(self, fullname): 35 | if fullname in sys.modules: 36 | return sys.modules[fullname] 37 | 38 | mod = imp.new_module(fullname) 39 | mod.__file__ = self.path 40 | mod.__loader__ = self 41 | with open(self.path, 'rb') as fi: 42 | code = fi.read() 43 | 44 | exec(code, mod.__dict__) 45 | sys.modules[fullname] = mod 46 | return mod 47 | 48 | 49 | sys.meta_path.append(EPY_Importer()) 50 | -------------------------------------------------------------------------------- /epython/cython_backend.py: -------------------------------------------------------------------------------- 1 | from ast import Name 2 | from .unparse import Unparser 3 | 4 | # For the Cython generator we re-use the Unparser code generator 5 | # and generate Cython code instead for nodes that have type information 6 | class CythonGenerator(Unparser): 7 | 8 | def get_type_from_comment(self, node): 9 | tp = self._type_ignores.get(node.lineno) or node.type_comment 10 | if tp is not None: 11 | return tp 12 | 13 | def visit_Assign(self, node): 14 | self.fill() 15 | for target in node.targets: 16 | if type_comment := self.get_type_from_comment(node): 17 | self.write(f"cdef {type_comment} ") 18 | self.traverse(target) 19 | self.write(" = ") 20 | self.traverse(node.value) 21 | 22 | def visit_AnnAssign(self, node): 23 | self.fill() 24 | self.write("cdef ") 25 | self.traverse(node.annotation) 26 | self.write(" ") 27 | with self.delimit_if("(", ")", not node.simple and 28 | isinstance(node.target, Name)): 29 | self.traverse(node.target) 30 | if node.value: 31 | self.write(" = ") 32 | self.traverse(node.value) 33 | 34 | def _function_helper(self, node, fill_suffix): 35 | self.maybe_newline() 36 | for deco in node.decorator_list: 37 | self.fill("@") 38 | self.traverse(deco) 39 | if node.returns: 40 |
self.fill("cdef ") 41 | self.traverse(node.returns) 42 | self.write(" " + node.name) 43 | else: 44 | def_str = fill_suffix + " " + node.name 45 | self.fill(def_str) 46 | with self.delimit("(", ")"): 47 | self.traverse(node.args) 48 | with self.block(extra=self.get_type_from_comment(node)): 49 | self._write_docstring_and_traverse_body(node) 50 | 51 | def visit_arg(self, node): 52 | if node.annotation: 53 | self.traverse(node.annotation) 54 | self.write(" ") 55 | self.write(node.arg) 56 | -------------------------------------------------------------------------------- /epython/epython.py: -------------------------------------------------------------------------------- 1 | """ 2 | EPython is a code-transformer that translates a statically typed subset of 3 | Python syntax into an extension of Python for a particular backend. 4 | 5 | The .epy file is first compiled into an AST 6 | The AST is validated to ensure it uses only the allowed subset of Python 7 | The AST is then fed to a transformer specific to the backend. 8 | 9 | """ 10 | import argparse 11 | import ast 12 | import os.path 13 | 14 | from epython import __version__ 15 | from .validate import validate 16 | 17 | # See https://greentreesnakes.readthedocs.io/en/latest/nodes.html 18 | 19 | _registry = {} 20 | 21 | def register_func(name_or_func): 22 | if isinstance(name_or_func, str): 23 | name = name_or_func 24 | func = None 25 | else: 26 | func = name_or_func 27 | name = func.__name__ 28 | if func is None: 29 | def decorator(new_func): 30 | _registry[name] = new_func 31 | return new_func 32 | return decorator 33 | else: 34 | _registry[name] = func 35 | return func 36 | 37 | # A transformation function needs to take as agruments 38 | # ast: the validated ast of the code 39 | # filename: the name to generate the artefacts 40 | # 41 | # It returns the PATH (or URL) of the created artefact 42 | 43 | @register_func('cpython') 44 | def transform(ast, name): 45 | return name + '.so' 46 | 47 | # @register_func 48 | # def pypy(mine): 49 | # return mine 50 | 51 | def main(): 52 | find_backends() 53 | parser = argparse.ArgumentParser(prog='epython', 54 | description="Compile statically typed subset of Python to a backend.") 55 | parser.add_argument("file") 56 | parser.add_argument("--backend", default="cpython") 57 | parser.add_argument("--name", default="none") 58 | parser.add_argument("--version", action='version', 59 | version='%(prog)s ' + __version__) 60 | args = parser.parse_args() 61 | 62 | if args.name == 'none': 63 | name = os.path.splitext(args.file)[0] 64 | else: 65 | name = args.name 66 | 67 | with open(args.file) as fi: 68 | source = fi.read() 69 | 70 | code = ast.parse(source, name, 'exec', type_comments=True) 71 | result = validate(code) 72 | if result is not None: 73 | raise result[0](result[1]) 74 | 75 | try: 76 | transformer = _registry[args.backend] 77 | except KeyError: 78 | raise RuntimeError(f"There is no epython backend registered for {args.backend}.") 79 | 80 | output = transformer(code, name) 81 | 82 | from .cython_backend import CythonGenerator 83 | translator = CythonGenerator() 84 | print(translator.visit(code)) 85 | 86 | 87 | # importing the backend should be sufficient to call the decorator(s) 88 | # that registers the function in _registry which is why the 89 | # dictionary created here is not returned or seemingly unused. 90 | def find_backends(): 91 | import importlib 92 | import pkgutil 93 | 94 | # importing the module registers the function. 
95 | discovered_plugins = { 96 | name: importlib.import_module(name) 97 | for finder, name, ispkg in pkgutil.iter_modules() 98 | if name.startswith('epython-') 99 | } 100 | 101 | if len(discovered_plugins) > len(_registry): 102 | print("Registry: ") 103 | print(_registry) 104 | print("\n\nPlugin Modules Found: ") 105 | print(discovered_plugins) 106 | raise ValueError("The number of Plugin Modules Found is larger " 107 | "than the number of transformations successfully registered.") 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /epython/unparse.py: -------------------------------------------------------------------------------- 1 | # This module contains the Unparser class (copied from the Python 3.9.5 ast 2 | # module). 3 | # By sub-classing Unparser, you can target other languages which are close 4 | # to Python, such as Cython. 5 | import ast 6 | import sys 7 | import tokenize 8 | from ast import (NodeVisitor, AsyncFunctionDef, FunctionDef, ClassDef, Module, 9 | Expr, Constant, Tuple, Name, If, Starred) 10 | from contextlib import contextmanager, nullcontext 11 | from enum import IntEnum, auto 12 | 13 | 14 | # Large float and imaginary literals get turned into infinities in the AST. 15 | # We unparse those infinities to INFSTR. 16 | _INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) 17 | 18 | class _Precedence(IntEnum): 19 | """Precedence table that originated from python grammar.""" 20 | 21 | TUPLE = auto() 22 | YIELD = auto() # 'yield', 'yield from' 23 | TEST = auto() # 'if'-'else', 'lambda' 24 | OR = auto() # 'or' 25 | AND = auto() # 'and' 26 | NOT = auto() # 'not' 27 | CMP = auto() # '<', '>', '==', '>=', '<=', '!=', 28 | # 'in', 'not in', 'is', 'is not' 29 | EXPR = auto() 30 | BOR = EXPR # '|' 31 | BXOR = auto() # '^' 32 | BAND = auto() # '&' 33 | SHIFT = auto() # '<<', '>>' 34 | ARITH = auto() # '+', '-' 35 | TERM = auto() # '*', '@', '/', '%', '//' 36 | FACTOR = auto() # unary '+', '-', '~' 37 | POWER = auto() # '**' 38 | AWAIT = auto() # 'await' 39 | ATOM = auto() 40 | 41 | def next(self): 42 | try: 43 | return self.__class__(self + 1) 44 | except ValueError: 45 | return self 46 | 47 | 48 | _SINGLE_QUOTES = ("'", '"') 49 | _MULTI_QUOTES = ('"""', "'''") 50 | _ALL_QUOTES = (*_SINGLE_QUOTES, *_MULTI_QUOTES) 51 | 52 | class Unparser(NodeVisitor): 53 | """Methods in this class recursively traverse an AST and 54 | output source code for the abstract syntax; original formatting 55 | is disregarded.""" 56 | 57 | def __init__(self, *, _avoid_backslashes=False): 58 | self._source = [] 59 | self._buffer = [] 60 | self._precedences = {} 61 | self._type_ignores = {} 62 | self._indent = 0 63 | self._avoid_backslashes = _avoid_backslashes 64 | 65 | def interleave(self, inter, f, seq): 66 | """Call f on each item in seq, calling inter() in between.""" 67 | seq = iter(seq) 68 | try: 69 | f(next(seq)) 70 | except StopIteration: 71 | pass 72 | else: 73 | for x in seq: 74 | inter() 75 | f(x) 76 | 77 | def items_view(self, traverser, items): 78 | """Traverse and separate the given *items* with a comma and append it to 79 | the buffer.
If *items* is a single item sequence, a trailing comma 80 | will be added.""" 81 | if len(items) == 1: 82 | traverser(items[0]) 83 | self.write(",") 84 | else: 85 | self.interleave(lambda: self.write(", "), traverser, items) 86 | 87 | def maybe_newline(self): 88 | """Adds a newline if it isn't the start of generated source""" 89 | if self._source: 90 | self.write("\n") 91 | 92 | def fill(self, text=""): 93 | """Indent a piece of text and append it, according to the current 94 | indentation level""" 95 | self.maybe_newline() 96 | self.write(" " * self._indent + text) 97 | 98 | def write(self, text): 99 | """Append a piece of text""" 100 | self._source.append(text) 101 | 102 | def buffer_writer(self, text): 103 | self._buffer.append(text) 104 | 105 | @property 106 | def buffer(self): 107 | value = "".join(self._buffer) 108 | self._buffer.clear() 109 | return value 110 | 111 | @contextmanager 112 | def block(self, *, extra = None): 113 | """A context manager for preparing the source for blocks. It adds 114 | the character':', increases the indentation on enter and decreases 115 | the indentation on exit. If *extra* is given, it will be directly 116 | appended after the colon character. 117 | """ 118 | self.write(":") 119 | if extra: 120 | self.write(extra) 121 | self._indent += 1 122 | yield 123 | self._indent -= 1 124 | 125 | @contextmanager 126 | def delimit(self, start, end): 127 | """A context manager for preparing the source for expressions. It adds 128 | *start* to the buffer and enters, after exit it adds *end*.""" 129 | 130 | self.write(start) 131 | yield 132 | self.write(end) 133 | 134 | def delimit_if(self, start, end, condition): 135 | if condition: 136 | return self.delimit(start, end) 137 | else: 138 | return nullcontext() 139 | 140 | def require_parens(self, precedence, node): 141 | """Shortcut to adding precedence related parens""" 142 | return self.delimit_if("(", ")", self.get_precedence(node) > precedence) 143 | 144 | def get_precedence(self, node): 145 | return self._precedences.get(node, _Precedence.TEST) 146 | 147 | def set_precedence(self, precedence, *nodes): 148 | for node in nodes: 149 | self._precedences[node] = precedence 150 | 151 | def get_raw_docstring(self, node): 152 | """If a docstring node is found in the body of the *node* parameter, 153 | return that docstring node, None otherwise. 
154 | 155 | Logic mirrored from ``_PyAST_GetDocString``.""" 156 | if not isinstance( 157 | node, (AsyncFunctionDef, FunctionDef, ClassDef, Module) 158 | ) or len(node.body) < 1: 159 | return None 160 | node = node.body[0] 161 | if not isinstance(node, Expr): 162 | return None 163 | node = node.value 164 | if isinstance(node, Constant) and isinstance(node.value, str): 165 | return node 166 | 167 | def get_type_comment(self, node): 168 | comment = self._type_ignores.get(node.lineno) or node.type_comment 169 | if comment is not None: 170 | return f" # type: {comment}" 171 | 172 | def traverse(self, node): 173 | if isinstance(node, list): 174 | for item in node: 175 | self.traverse(item) 176 | else: 177 | super().visit(node) 178 | 179 | def visit(self, node): 180 | """Outputs a source code string that, if converted back to an ast 181 | (using ast.parse) will generate an AST equivalent to *node*""" 182 | self._source = [] 183 | self.traverse(node) 184 | return "".join(self._source) 185 | 186 | def _write_docstring_and_traverse_body(self, node): 187 | if (docstring := self.get_raw_docstring(node)): 188 | self._write_docstring(docstring) 189 | self.traverse(node.body[1:]) 190 | else: 191 | self.traverse(node.body) 192 | 193 | def visit_Module(self, node): 194 | self._type_ignores = { 195 | ignore.lineno: f"ignore{ignore.tag}" 196 | for ignore in node.type_ignores 197 | } 198 | self._write_docstring_and_traverse_body(node) 199 | self._type_ignores.clear() 200 | 201 | def visit_FunctionType(self, node): 202 | with self.delimit("(", ")"): 203 | self.interleave( 204 | lambda: self.write(", "), self.traverse, node.argtypes 205 | ) 206 | 207 | self.write(" -> ") 208 | self.traverse(node.returns) 209 | 210 | def visit_Expr(self, node): 211 | self.fill() 212 | self.set_precedence(_Precedence.YIELD, node.value) 213 | self.traverse(node.value) 214 | 215 | def visit_NamedExpr(self, node): 216 | with self.require_parens(_Precedence.TUPLE, node): 217 | self.set_precedence(_Precedence.ATOM, node.target, node.value) 218 | self.traverse(node.target) 219 | self.write(" := ") 220 | self.traverse(node.value) 221 | 222 | def visit_Import(self, node): 223 | self.fill("import ") 224 | self.interleave(lambda: self.write(", "), self.traverse, node.names) 225 | 226 | def visit_ImportFrom(self, node): 227 | self.fill("from ") 228 | self.write("." 
* node.level) 229 | if node.module: 230 | self.write(node.module) 231 | self.write(" import ") 232 | self.interleave(lambda: self.write(", "), self.traverse, node.names) 233 | 234 | def visit_Assign(self, node): 235 | self.fill() 236 | for target in node.targets: 237 | self.traverse(target) 238 | self.write(" = ") 239 | self.traverse(node.value) 240 | if type_comment := self.get_type_comment(node): 241 | self.write(type_comment) 242 | 243 | def visit_AugAssign(self, node): 244 | self.fill() 245 | self.traverse(node.target) 246 | self.write(" " + self.binop[node.op.__class__.__name__] + "= ") 247 | self.traverse(node.value) 248 | 249 | def visit_AnnAssign(self, node): 250 | self.fill() 251 | with self.delimit_if("(", ")", not node.simple and 252 | isinstance(node.target, Name)): 253 | self.traverse(node.target) 254 | self.write(": ") 255 | self.traverse(node.annotation) 256 | if node.value: 257 | self.write(" = ") 258 | self.traverse(node.value) 259 | 260 | def visit_Return(self, node): 261 | self.fill("return") 262 | if node.value: 263 | self.write(" ") 264 | self.traverse(node.value) 265 | 266 | def visit_Pass(self, node): 267 | self.fill("pass") 268 | 269 | def visit_Break(self, node): 270 | self.fill("break") 271 | 272 | def visit_Continue(self, node): 273 | self.fill("continue") 274 | 275 | def visit_Delete(self, node): 276 | self.fill("del ") 277 | self.interleave(lambda: self.write(", "), self.traverse, node.targets) 278 | 279 | def visit_Assert(self, node): 280 | self.fill("assert ") 281 | self.traverse(node.test) 282 | if node.msg: 283 | self.write(", ") 284 | self.traverse(node.msg) 285 | 286 | def visit_Global(self, node): 287 | self.fill("global ") 288 | self.interleave(lambda: self.write(", "), self.write, node.names) 289 | 290 | def visit_Nonlocal(self, node): 291 | self.fill("nonlocal ") 292 | self.interleave(lambda: self.write(", "), self.write, node.names) 293 | 294 | def visit_Await(self, node): 295 | with self.require_parens(_Precedence.AWAIT, node): 296 | self.write("await") 297 | if node.value: 298 | self.write(" ") 299 | self.set_precedence(_Precedence.ATOM, node.value) 300 | self.traverse(node.value) 301 | 302 | def visit_Yield(self, node): 303 | with self.require_parens(_Precedence.YIELD, node): 304 | self.write("yield") 305 | if node.value: 306 | self.write(" ") 307 | self.set_precedence(_Precedence.ATOM, node.value) 308 | self.traverse(node.value) 309 | 310 | def visit_YieldFrom(self, node): 311 | with self.require_parens(_Precedence.YIELD, node): 312 | self.write("yield from ") 313 | if not node.value: 314 | raise ValueError("Node can't be used without a value " 315 | "attribute.") 316 | self.set_precedence(_Precedence.ATOM, node.value) 317 | self.traverse(node.value) 318 | 319 | def visit_Raise(self, node): 320 | self.fill("raise") 321 | if not node.exc: 322 | if node.cause: 323 | raise ValueError("Node can't use cause without an exception.") 324 | return 325 | self.write(" ") 326 | self.traverse(node.exc) 327 | if node.cause: 328 | self.write(" from ") 329 | self.traverse(node.cause) 330 | 331 | def visit_Try(self, node): 332 | self.fill("try") 333 | with self.block(): 334 | self.traverse(node.body) 335 | for ex in node.handlers: 336 | self.traverse(ex) 337 | if node.orelse: 338 | self.fill("else") 339 | with self.block(): 340 | self.traverse(node.orelse) 341 | if node.finalbody: 342 | self.fill("finally") 343 | with self.block(): 344 | self.traverse(node.finalbody) 345 | 346 | def visit_ExceptHandler(self, node): 347 | self.fill("except") 348 | if node.type: 
349 | self.write(" ") 350 | self.traverse(node.type) 351 | if node.name: 352 | self.write(" as ") 353 | self.write(node.name) 354 | with self.block(): 355 | self.traverse(node.body) 356 | 357 | def visit_ClassDef(self, node): 358 | self.maybe_newline() 359 | for deco in node.decorator_list: 360 | self.fill("@") 361 | self.traverse(deco) 362 | self.fill("class " + node.name) 363 | with self.delimit_if("(", ")", condition = node.bases or node.keywords): 364 | comma = False 365 | for e in node.bases: 366 | if comma: 367 | self.write(", ") 368 | else: 369 | comma = True 370 | self.traverse(e) 371 | for e in node.keywords: 372 | if comma: 373 | self.write(", ") 374 | else: 375 | comma = True 376 | self.traverse(e) 377 | 378 | with self.block(): 379 | self._write_docstring_and_traverse_body(node) 380 | 381 | def visit_FunctionDef(self, node): 382 | self._function_helper(node, "def") 383 | 384 | def visit_AsyncFunctionDef(self, node): 385 | self._function_helper(node, "async def") 386 | 387 | def _function_helper(self, node, fill_suffix): 388 | self.maybe_newline() 389 | for deco in node.decorator_list: 390 | self.fill("@") 391 | self.traverse(deco) 392 | def_str = fill_suffix + " " + node.name 393 | self.fill(def_str) 394 | with self.delimit("(", ")"): 395 | self.traverse(node.args) 396 | if node.returns: 397 | self.write(" -> ") 398 | self.traverse(node.returns) 399 | with self.block(extra=self.get_type_comment(node)): 400 | self._write_docstring_and_traverse_body(node) 401 | 402 | def visit_For(self, node): 403 | self._for_helper("for ", node) 404 | 405 | def visit_AsyncFor(self, node): 406 | self._for_helper("async for ", node) 407 | 408 | def _for_helper(self, fill, node): 409 | self.fill(fill) 410 | self.traverse(node.target) 411 | self.write(" in ") 412 | self.traverse(node.iter) 413 | with self.block(extra=self.get_type_comment(node)): 414 | self.traverse(node.body) 415 | if node.orelse: 416 | self.fill("else") 417 | with self.block(): 418 | self.traverse(node.orelse) 419 | 420 | def visit_If(self, node): 421 | self.fill("if ") 422 | self.traverse(node.test) 423 | with self.block(): 424 | self.traverse(node.body) 425 | # collapse nested ifs into equivalent elifs. 426 | while (node.orelse and len(node.orelse) == 1 and 427 | isinstance(node.orelse[0], If)): 428 | node = node.orelse[0] 429 | self.fill("elif ") 430 | self.traverse(node.test) 431 | with self.block(): 432 | self.traverse(node.body) 433 | # final else 434 | if node.orelse: 435 | self.fill("else") 436 | with self.block(): 437 | self.traverse(node.orelse) 438 | 439 | def visit_While(self, node): 440 | self.fill("while ") 441 | self.traverse(node.test) 442 | with self.block(): 443 | self.traverse(node.body) 444 | if node.orelse: 445 | self.fill("else") 446 | with self.block(): 447 | self.traverse(node.orelse) 448 | 449 | def visit_With(self, node): 450 | self.fill("with ") 451 | self.interleave(lambda: self.write(", "), self.traverse, node.items) 452 | with self.block(extra=self.get_type_comment(node)): 453 | self.traverse(node.body) 454 | 455 | def visit_AsyncWith(self, node): 456 | self.fill("async with ") 457 | self.interleave(lambda: self.write(", "), self.traverse, node.items) 458 | with self.block(extra=self.get_type_comment(node)): 459 | self.traverse(node.body) 460 | 461 | def _str_literal_helper( 462 | self, string, *, quote_types=_ALL_QUOTES, 463 | escape_special_whitespace=False 464 | ): 465 | """Helper for writing string literals, minimizing escapes. 
466 | Returns the tuple (string literal to write, possible quote types). 467 | """ 468 | def escape_char(c): 469 | # \n and \t are non-printable, but we only escape them if 470 | # escape_special_whitespace is True 471 | if not escape_special_whitespace and c in "\n\t": 472 | return c 473 | # Always escape backslashes and other non-printable characters 474 | if c == "\\" or not c.isprintable(): 475 | return c.encode("unicode_escape").decode("ascii") 476 | return c 477 | 478 | escaped_string = "".join(map(escape_char, string)) 479 | possible_quotes = quote_types 480 | if "\n" in escaped_string: 481 | possible_quotes = [q for q in possible_quotes if q in _MULTI_QUOTES] 482 | possible_quotes = [q for q in possible_quotes 483 | if q not in escaped_string] 484 | if not possible_quotes: 485 | # If there aren't any possible_quotes, fallback to using repr 486 | # on the original string. Try to use a quote from quote_types, 487 | # e.g., so that we use triple quotes for docstrings. 488 | string = repr(string) 489 | quote = next((q for q in quote_types if string[0] in q), string[0]) 490 | return string[1:-1], [quote] 491 | if escaped_string: 492 | # Sort so that we prefer '''"''' over """\"""" 493 | possible_quotes.sort(key=lambda q: q[0] == escaped_string[-1]) 494 | # If we're using triple quotes and we'd need to escape a final 495 | # quote, escape it 496 | if possible_quotes[0][0] == escaped_string[-1]: 497 | assert len(possible_quotes[0]) == 3 498 | escaped_string = escaped_string[:-1] + "\\" + escaped_string[-1] 499 | return escaped_string, possible_quotes 500 | 501 | def _write_str_avoiding_backslashes(self, string, *, 502 | quote_types=_ALL_QUOTES): 503 | """Write string literal value with a best effort attempt to avoid 504 | backslashes.""" 505 | string, quote_types = self._str_literal_helper(string, 506 | quote_types=quote_types) 507 | quote_type = quote_types[0] 508 | self.write(f"{quote_type}{string}{quote_type}") 509 | 510 | def visit_JoinedStr(self, node): 511 | self.write("f") 512 | if self._avoid_backslashes: 513 | self._fstring_JoinedStr(node, self.buffer_writer) 514 | self._write_str_avoiding_backslashes(self.buffer) 515 | return 516 | 517 | # If we don't need to avoid backslashes globally (i.e., we only need 518 | # to avoid them inside FormattedValues), it's cosmetically preferred 519 | # to use escaped whitespace. That is, it's preferred to use backslashes 520 | # for cases like: f"{x}\n". To accomplish this, we keep track of what 521 | # in our buffer corresponds to FormattedValues and what corresponds to 522 | # Constant parts of the f-string, and allow escapes accordingly. 
523 | buffer = [] 524 | for value in node.values: 525 | meth = getattr(self, "_fstring_" + type(value).__name__) 526 | meth(value, self.buffer_writer) 527 | buffer.append((self.buffer, isinstance(value, Constant))) 528 | new_buffer = [] 529 | quote_types = _ALL_QUOTES 530 | for value, is_constant in buffer: 531 | # Repeatedly narrow down the list of possible quote_types 532 | value, quote_types = self._str_literal_helper( 533 | value, quote_types=quote_types, 534 | escape_special_whitespace=is_constant 535 | ) 536 | new_buffer.append(value) 537 | value = "".join(new_buffer) 538 | quote_type = quote_types[0] 539 | self.write(f"{quote_type}{value}{quote_type}") 540 | 541 | def visit_FormattedValue(self, node): 542 | self.write("f") 543 | self._fstring_FormattedValue(node, self.buffer_writer) 544 | self._write_str_avoiding_backslashes(self.buffer) 545 | 546 | def _fstring_JoinedStr(self, node, write): 547 | for value in node.values: 548 | meth = getattr(self, "_fstring_" + type(value).__name__) 549 | meth(value, write) 550 | 551 | def _fstring_Constant(self, node, write): 552 | if not isinstance(node.value, str): 553 | raise ValueError("Constants inside JoinedStr should be a string.") 554 | value = node.value.replace("{", "{{").replace("}", "}}") 555 | write(value) 556 | 557 | def _fstring_FormattedValue(self, node, write): 558 | write("{") 559 | unparser = type(self)(_avoid_backslashes=True) 560 | unparser.set_precedence(_Precedence.TEST.next(), node.value) 561 | expr = unparser.visit(node.value) 562 | if expr.startswith("{"): 563 | write(" ") # Separate pair of opening brackets as "{ {" 564 | if "\\" in expr: 565 | raise ValueError("Unable to avoid backslash in f-string " 566 | "expression part") 567 | write(expr) 568 | if node.conversion != -1: 569 | conversion = chr(node.conversion) 570 | if conversion not in "sra": 571 | raise ValueError("Unknown f-string conversion.") 572 | write(f"!{conversion}") 573 | if node.format_spec: 574 | write(":") 575 | meth = getattr(self, "_fstring_" + type(node.format_spec).__name__) 576 | meth(node.format_spec, write) 577 | write("}") 578 | 579 | def visit_Name(self, node): 580 | self.write(node.id) 581 | 582 | def _write_docstring(self, node): 583 | self.fill() 584 | if node.kind == "u": 585 | self.write("u") 586 | self._write_str_avoiding_backslashes(node.value, 587 | quote_types=_MULTI_QUOTES) 588 | 589 | def _write_constant(self, value): 590 | if isinstance(value, (float, complex)): 591 | # Substitute overflowing decimal literal for AST infinities, 592 | # and inf - inf for NaNs. 
593 | self.write( 594 | repr(value) 595 | .replace("inf", _INFSTR) 596 | .replace("nan", f"({_INFSTR}-{_INFSTR})") 597 | ) 598 | elif self._avoid_backslashes and isinstance(value, str): 599 | self._write_str_avoiding_backslashes(value) 600 | else: 601 | self.write(repr(value)) 602 | 603 | def visit_Constant(self, node): 604 | value = node.value 605 | if isinstance(value, tuple): 606 | with self.delimit("(", ")"): 607 | self.items_view(self._write_constant, value) 608 | elif value is ...: 609 | self.write("...") 610 | else: 611 | if node.kind == "u": 612 | self.write("u") 613 | self._write_constant(node.value) 614 | 615 | def visit_List(self, node): 616 | with self.delimit("[", "]"): 617 | self.interleave(lambda: self.write(", "), self.traverse, node.elts) 618 | 619 | def visit_ListComp(self, node): 620 | with self.delimit("[", "]"): 621 | self.traverse(node.elt) 622 | for gen in node.generators: 623 | self.traverse(gen) 624 | 625 | def visit_GeneratorExp(self, node): 626 | with self.delimit("(", ")"): 627 | self.traverse(node.elt) 628 | for gen in node.generators: 629 | self.traverse(gen) 630 | 631 | def visit_SetComp(self, node): 632 | with self.delimit("{", "}"): 633 | self.traverse(node.elt) 634 | for gen in node.generators: 635 | self.traverse(gen) 636 | 637 | def visit_DictComp(self, node): 638 | with self.delimit("{", "}"): 639 | self.traverse(node.key) 640 | self.write(": ") 641 | self.traverse(node.value) 642 | for gen in node.generators: 643 | self.traverse(gen) 644 | 645 | def visit_comprehension(self, node): 646 | if node.is_async: 647 | self.write(" async for ") 648 | else: 649 | self.write(" for ") 650 | self.set_precedence(_Precedence.TUPLE, node.target) 651 | self.traverse(node.target) 652 | self.write(" in ") 653 | self.set_precedence(_Precedence.TEST.next(), node.iter, *node.ifs) 654 | self.traverse(node.iter) 655 | for if_clause in node.ifs: 656 | self.write(" if ") 657 | self.traverse(if_clause) 658 | 659 | def visit_IfExp(self, node): 660 | with self.require_parens(_Precedence.TEST, node): 661 | self.set_precedence(_Precedence.TEST.next(), node.body, node.test) 662 | self.traverse(node.body) 663 | self.write(" if ") 664 | self.traverse(node.test) 665 | self.write(" else ") 666 | self.set_precedence(_Precedence.TEST, node.orelse) 667 | self.traverse(node.orelse) 668 | 669 | def visit_Set(self, node): 670 | if node.elts: 671 | with self.delimit("{", "}"): 672 | self.interleave(lambda: self.write(", "), self.traverse, 673 | node.elts) 674 | else: 675 | # `{}` would be interpreted as a dictionary literal, and 676 | # `set` might be shadowed. 
Thus: 677 | self.write('{*()}') 678 | 679 | def visit_Dict(self, node): 680 | def write_key_value_pair(k, v): 681 | self.traverse(k) 682 | self.write(": ") 683 | self.traverse(v) 684 | 685 | def write_item(item): 686 | k, v = item 687 | if k is None: 688 | # for dictionary unpacking operator in dicts {**{'y': 2}} 689 | # see PEP 448 for details 690 | self.write("**") 691 | self.set_precedence(_Precedence.EXPR, v) 692 | self.traverse(v) 693 | else: 694 | write_key_value_pair(k, v) 695 | 696 | with self.delimit("{", "}"): 697 | self.interleave( 698 | lambda: self.write(", "), write_item, 699 | zip(node.keys, node.values) 700 | ) 701 | 702 | def visit_Tuple(self, node): 703 | with self.delimit("(", ")"): 704 | self.items_view(self.traverse, node.elts) 705 | 706 | unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"} 707 | unop_precedence = { 708 | "not": _Precedence.NOT, 709 | "~": _Precedence.FACTOR, 710 | "+": _Precedence.FACTOR, 711 | "-": _Precedence.FACTOR, 712 | } 713 | 714 | def visit_UnaryOp(self, node): 715 | operator = self.unop[node.op.__class__.__name__] 716 | operator_precedence = self.unop_precedence[operator] 717 | with self.require_parens(operator_precedence, node): 718 | self.write(operator) 719 | # factor prefixes (+, -, ~) shouldn't be seperated 720 | # from the value they belong, (e.g: +1 instead of + 1) 721 | if operator_precedence is not _Precedence.FACTOR: 722 | self.write(" ") 723 | self.set_precedence(operator_precedence, node.operand) 724 | self.traverse(node.operand) 725 | 726 | binop = { 727 | "Add": "+", 728 | "Sub": "-", 729 | "Mult": "*", 730 | "MatMult": "@", 731 | "Div": "/", 732 | "Mod": "%", 733 | "LShift": "<<", 734 | "RShift": ">>", 735 | "BitOr": "|", 736 | "BitXor": "^", 737 | "BitAnd": "&", 738 | "FloorDiv": "//", 739 | "Pow": "**", 740 | } 741 | 742 | binop_precedence = { 743 | "+": _Precedence.ARITH, 744 | "-": _Precedence.ARITH, 745 | "*": _Precedence.TERM, 746 | "@": _Precedence.TERM, 747 | "/": _Precedence.TERM, 748 | "%": _Precedence.TERM, 749 | "<<": _Precedence.SHIFT, 750 | ">>": _Precedence.SHIFT, 751 | "|": _Precedence.BOR, 752 | "^": _Precedence.BXOR, 753 | "&": _Precedence.BAND, 754 | "//": _Precedence.TERM, 755 | "**": _Precedence.POWER, 756 | } 757 | 758 | binop_rassoc = frozenset(("**",)) 759 | def visit_BinOp(self, node): 760 | operator = self.binop[node.op.__class__.__name__] 761 | operator_precedence = self.binop_precedence[operator] 762 | with self.require_parens(operator_precedence, node): 763 | if operator in self.binop_rassoc: 764 | left_precedence = operator_precedence.next() 765 | right_precedence = operator_precedence 766 | else: 767 | left_precedence = operator_precedence 768 | right_precedence = operator_precedence.next() 769 | 770 | self.set_precedence(left_precedence, node.left) 771 | self.traverse(node.left) 772 | self.write(f" {operator} ") 773 | self.set_precedence(right_precedence, node.right) 774 | self.traverse(node.right) 775 | 776 | cmpops = { 777 | "Eq": "==", 778 | "NotEq": "!=", 779 | "Lt": "<", 780 | "LtE": "<=", 781 | "Gt": ">", 782 | "GtE": ">=", 783 | "Is": "is", 784 | "IsNot": "is not", 785 | "In": "in", 786 | "NotIn": "not in", 787 | } 788 | 789 | def visit_Compare(self, node): 790 | with self.require_parens(_Precedence.CMP, node): 791 | self.set_precedence(_Precedence.CMP.next(), node.left, 792 | *node.comparators) 793 | self.traverse(node.left) 794 | for o, e in zip(node.ops, node.comparators): 795 | self.write(" " + self.cmpops[o.__class__.__name__] + " ") 796 | self.traverse(e) 797 | 798 | 
boolops = {"And": "and", "Or": "or"} 799 | boolop_precedence = {"and": _Precedence.AND, "or": _Precedence.OR} 800 | 801 | def visit_BoolOp(self, node): 802 | operator = self.boolops[node.op.__class__.__name__] 803 | operator_precedence = self.boolop_precedence[operator] 804 | 805 | def increasing_level_traverse(node): 806 | nonlocal operator_precedence 807 | operator_precedence = operator_precedence.next() 808 | self.set_precedence(operator_precedence, node) 809 | self.traverse(node) 810 | 811 | with self.require_parens(operator_precedence, node): 812 | s = f" {operator} " 813 | self.interleave(lambda: self.write(s), 814 | increasing_level_traverse, node.values) 815 | 816 | def visit_Attribute(self, node): 817 | self.set_precedence(_Precedence.ATOM, node.value) 818 | self.traverse(node.value) 819 | # Special case: 3.__abs__() is a syntax error, so if node.value 820 | # is an integer literal then we need to either parenthesize 821 | # it or add an extra space to get 3 .__abs__(). 822 | if isinstance(node.value, Constant) and isinstance(node.value.value, 823 | int): 824 | self.write(" ") 825 | self.write(".") 826 | self.write(node.attr) 827 | 828 | def visit_Call(self, node): 829 | self.set_precedence(_Precedence.ATOM, node.func) 830 | self.traverse(node.func) 831 | with self.delimit("(", ")"): 832 | comma = False 833 | for e in node.args: 834 | if comma: 835 | self.write(", ") 836 | else: 837 | comma = True 838 | self.traverse(e) 839 | for e in node.keywords: 840 | if comma: 841 | self.write(", ") 842 | else: 843 | comma = True 844 | self.traverse(e) 845 | 846 | def visit_Subscript(self, node): 847 | def is_simple_tuple(slice_value): 848 | # when unparsing a non-empty tuple, the parantheses can be safely 849 | # omitted if there aren't any elements that explicitly requires 850 | # parantheses (such as starred expressions). 
851 | return ( 852 | isinstance(slice_value, Tuple) 853 | and slice_value.elts 854 | and not any(isinstance(elt, Starred) 855 | for elt in slice_value.elts) 856 | ) 857 | 858 | self.set_precedence(_Precedence.ATOM, node.value) 859 | self.traverse(node.value) 860 | with self.delimit("[", "]"): 861 | if is_simple_tuple(node.slice): 862 | self.items_view(self.traverse, node.slice.elts) 863 | else: 864 | self.traverse(node.slice) 865 | 866 | def visit_Starred(self, node): 867 | self.write("*") 868 | self.set_precedence(_Precedence.EXPR, node.value) 869 | self.traverse(node.value) 870 | 871 | def visit_Ellipsis(self, node): 872 | self.write("...") 873 | 874 | def visit_Slice(self, node): 875 | if node.lower: 876 | self.traverse(node.lower) 877 | self.write(":") 878 | if node.upper: 879 | self.traverse(node.upper) 880 | if node.step: 881 | self.write(":") 882 | self.traverse(node.step) 883 | 884 | def visit_arg(self, node): 885 | self.write(node.arg) 886 | if node.annotation: 887 | self.write(": ") 888 | self.traverse(node.annotation) 889 | 890 | def visit_arguments(self, node): 891 | first = True 892 | # normal arguments 893 | all_args = node.posonlyargs + node.args 894 | defaults = [None] * (len(all_args) - len(node.defaults)) + node.defaults 895 | for index, elements in enumerate(zip(all_args, defaults), 1): 896 | a, d = elements 897 | if first: 898 | first = False 899 | else: 900 | self.write(", ") 901 | self.traverse(a) 902 | if d: 903 | self.write("=") 904 | self.traverse(d) 905 | if index == len(node.posonlyargs): 906 | self.write(", /") 907 | 908 | # varargs, or bare '*' if no varargs but keyword-only arguments present 909 | if node.vararg or node.kwonlyargs: 910 | if first: 911 | first = False 912 | else: 913 | self.write(", ") 914 | self.write("*") 915 | if node.vararg: 916 | self.write(node.vararg.arg) 917 | if node.vararg.annotation: 918 | self.write(": ") 919 | self.traverse(node.vararg.annotation) 920 | 921 | # keyword-only arguments 922 | if node.kwonlyargs: 923 | for a, d in zip(node.kwonlyargs, node.kw_defaults): 924 | self.write(", ") 925 | self.traverse(a) 926 | if d: 927 | self.write("=") 928 | self.traverse(d) 929 | 930 | # kwargs 931 | if node.kwarg: 932 | if first: 933 | first = False 934 | else: 935 | self.write(", ") 936 | self.write("**" + node.kwarg.arg) 937 | if node.kwarg.annotation: 938 | self.write(": ") 939 | self.traverse(node.kwarg.annotation) 940 | 941 | def visit_keyword(self, node): 942 | if node.arg is None: 943 | self.write("**") 944 | else: 945 | self.write(node.arg) 946 | self.write("=") 947 | self.traverse(node.value) 948 | 949 | def visit_Lambda(self, node): 950 | with self.require_parens(_Precedence.TEST, node): 951 | self.write("lambda ") 952 | self.traverse(node.args) 953 | self.write(": ") 954 | self.set_precedence(_Precedence.TEST, node.body) 955 | self.traverse(node.body) 956 | 957 | def visit_alias(self, node): 958 | self.write(node.name) 959 | if node.asname: 960 | self.write(" as " + node.asname) 961 | 962 | def visit_withitem(self, node): 963 | self.traverse(node.context_expr) 964 | if node.optional_vars: 965 | self.write(" as ") 966 | self.traverse(node.optional_vars) 967 | 968 | 969 | def unparse(ast_obj): 970 | unparser = Unparser() 971 | return unparser.visit(ast_obj) 972 | 973 | 974 | def roundtrip(filename): 975 | with open(filename, "rb") as pyfile: 976 | encoding = tokenize.detect_encoding(pyfile.readline)[0] 977 | with open(filename, "r", encoding=encoding) as pyfile: 978 | source = pyfile.read() 979 | tree = ast.parse(source, 
filename, type_comments=True) 980 | print(unparse(tree)) 981 | 982 | 983 | if __name__ == '__main__': 984 | roundtrip(sys.argv[1]) 985 | --------------------------------------------------------------------------------