├── test ├── __init__.py ├── data │ ├── test.vimrc │ ├── README.md │ ├── pep-0563-annotations.py │ └── grammar35.py ├── test_fuzz.py ├── conftest.py ├── test_plugin.py └── test_parser.py ├── semshi ├── rplugin └── python3 │ └── semshi │ ├── __init__.py │ ├── util.py │ ├── node.py │ ├── plugin.py │ ├── parser.py │ ├── handler.py │ └── visitor.py ├── .flake8 ├── .gitignore ├── .travis.yml ├── pyproject.toml ├── setup.py ├── script └── dev.vimrc ├── .github └── workflows │ └── ci.yml ├── plugin └── semshi.vim └── README.md /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /semshi: -------------------------------------------------------------------------------- 1 | rplugin/python3/semshi -------------------------------------------------------------------------------- /rplugin/python3/semshi/__init__.py: -------------------------------------------------------------------------------- 1 | from .plugin import Plugin 2 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E731,F402,E261,E306,E302,E305,W504 3 | exclude = __init__.py,lib/ 4 | -------------------------------------------------------------------------------- /test/data/test.vimrc: -------------------------------------------------------------------------------- 1 | let &runtimepath .= ',' . 
getcwd() 2 | 3 | set noswapfile 4 | set hidden 5 | set viminfo="NONE" 6 | set shada="NONE" 7 | -------------------------------------------------------------------------------- /test/data/README.md: -------------------------------------------------------------------------------- 1 | To get test_grammar file, see the history from: 2 | 3 | https://github.com/python/cpython/blob/main/Lib/test/test_grammar.py 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | *.egg-info/ 5 | .eggs/ 6 | .coverage 7 | .cache/ 8 | .report.json 9 | TODO 10 | .tox 11 | build/ 12 | dist/ 13 | .pytest_cache/ 14 | rplugin.vim 15 | .coverage.* 16 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.5" 4 | - "3.6" 5 | - "3.7-dev" 6 | - "3.8-dev" 7 | jobs: 8 | include: 9 | - python: 3.7 10 | env: TOXENV=lint 11 | before_install: 12 | - wget -O nvim https://github.com/neovim/neovim/releases/latest/download/nvim.appimage 13 | - chmod +x nvim 14 | - export PATH="$PATH:." 
def test_multiple_files():
    """Attempt to parse every Python file of the running interpreter's stdlib.

    The directory is taken from ``sysconfig`` instead of a hard-coded
    ``/usr/lib/python3.8/``: on any machine without that exact directory
    ``os.walk`` yields nothing and the test would pass without parsing a
    single file.

    Files the parser reports as unparsable are printed and skipped. Any other
    exception dumps the symbol table of the offending source and is re-raised
    so the test fails.
    """
    # Local import: keeps this block self-contained without touching the
    # module's import header.
    import sysconfig

    stdlib_dir = sysconfig.get_path('stdlib')
    for root, dirs, files in os.walk(stdlib_dir):
        # Third-party packages installed below the stdlib directory are not
        # part of the corpus; pruning in place stops os.walk descending.
        dirs[:] = [
            d for d in dirs if d not in ('site-packages', 'dist-packages')
        ]
        for file in files:
            if not file.endswith('.py'):
                continue
            path = os.path.join(root, file)
            print(path)
            # Some stdlib test fixtures are deliberately not valid UTF-8.
            with open(path, encoding='utf-8', errors='ignore') as f:
                code = f.read()
            try:
                parse(code)
            except UnparsableError as e:
                print('unparsable', path, e.error)
                continue
            except Exception:
                # Give some context for debugging before failing the test.
                dump_symtable(code)
                raise
spaces_before_comment = 2 32 | 33 | [tool.yapfignore] 34 | ignore_patterns = ['test/data/*.py'] 35 | -------------------------------------------------------------------------------- /test/data/pep-0563-annotations.py: -------------------------------------------------------------------------------- 1 | # See https://peps.python.org/pep-0563/ 2 | 3 | class C: 4 | 5 | class D: 6 | field2 = 'd_field' 7 | def method(self) -> C.D.field2: # this is OK 8 | ... 9 | 10 | def method(self) -> D.field2: # this FAILS, class D is local to C 11 | ... # and is therefore only available 12 | # as C.D. This was already true 13 | # before the PEP. 14 | 15 | def method(self) -> field2: # this is OK 16 | ... 17 | 18 | def method(self) -> field: # this FAILS, field is local to C and 19 | ... # is therefore not visible to D unless 20 | # accessed as C.field. This was already 21 | # true before the PEP. 22 | 23 | field = 'c_field' 24 | def method(self) -> C.field: # this is OK 25 | ... 26 | 27 | def method(self) -> field: # this is OK 28 | ... 29 | 30 | def method(self) -> C.D: # this is OK 31 | ... 32 | a = C.D 33 | 34 | def method(self) -> D: # this is OK 35 | ... 36 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | # The setup is currently only used for tests. See the README for installation 4 | # instructions in Neovim. 
5 | setup( 6 | name='semshi', 7 | description='Semantic Highlighting for Python in Neovim', 8 | version='0.5.0.dev0', 9 | packages=['semshi'], 10 | # Original repo: https://github.com/numirias/semshi 11 | # author='numirias', 12 | # author_email='numirias@users.noreply.github.com', 13 | author='wookayin', 14 | author_email='wookayin@gmail.com', 15 | url='https://github.com/wookayin/semshi', 16 | license='MIT', 17 | python_requires='>=3.7', 18 | install_requires=[ 19 | 'pytest>=7.0', 20 | 'pytest-pudb', 21 | 'pynvim>=0.4.3', 22 | ], 23 | classifiers=[ 24 | 'Development Status :: 4 - Beta', 25 | 'Programming Language :: Python', 26 | 'Programming Language :: Python :: 3', 27 | 'Programming Language :: Python :: 3.7', 28 | 'Programming Language :: Python :: 3.8', 29 | 'Programming Language :: Python :: 3.9', 30 | 'Programming Language :: Python :: 3.10', 31 | 'Programming Language :: Python :: 3.11', 32 | 'Programming Language :: Python :: 3.12', 33 | 'Programming Language :: Python :: 3.13', 34 | 'Topic :: Text Editors', 35 | ], 36 | ) 37 | -------------------------------------------------------------------------------- /script/dev.vimrc: -------------------------------------------------------------------------------- 1 | "A minimal vimrc for development 2 | 3 | syntax on 4 | set nocompatible 5 | colorscheme zellner 6 | 7 | set noswapfile 8 | set hidden 9 | set tabstop=8 10 | set shiftwidth=4 11 | set softtabstop=4 12 | set smarttab 13 | set expandtab 14 | set number 15 | 16 | let &runtimepath .= ',' . getcwd() 17 | let $NVIM_RPLUGIN_MANIFEST = './script/rplugin.vim' 18 | 19 | let mapleader = ',' 20 | 21 | noremap 4j 22 | noremap 4k 23 | noremap q :q 24 | noremap Q :qa! 25 | noremap :bnext 26 | noremap :bprev 27 | 28 | 29 | function! 
def lines_to_code(lines):
    """Join a sequence of lines into one newline-separated source string."""
    return '\n'.join(lines)


def code_to_lines(code):
    """Split a source string on newlines (the inverse of `lines_to_code`)."""
    return code.split('\n')


def debug_time(label_or_callable=None, detail=None):
    """Decorator that logs the run time of the wrapped callable at DEBUG level.

    Usable both bare (``@debug_time``) and with arguments
    (``@debug_time('label', detail)``).

    Args:
        label_or_callable: Either the label to put in the log line, or (when
            used as a bare decorator) the callable to wrap. If it is not a
            string, the wrapped callable's ``__name__`` (or, failing that,
            its class name) is used as label.
        detail: Optional extra text appended to the log line. A callable is
            invoked with the wrapped call's arguments and must return a
            string; any other value is used as a ``str.format`` template with
            those arguments.

    Returns:
        The wrapped callable, or a decorator when called with arguments.
    """

    def inner(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic; time.time() can jump when the system
            # clock is adjusted and then reports bogus durations.
            start = time.perf_counter()
            res = func(*args, **kwargs)
            elapsed = time.perf_counter() - start
            # Don't pay for building the message (or evaluating `detail`)
            # when it would be discarded anyway.
            if not logger.isEnabledFor(logging.DEBUG):
                return res
            label = label_or_callable
            if not isinstance(label, str):
                try:
                    label = func.__name__
                except AttributeError:
                    label = func.__class__.__name__
            text = 'TIME %s: %f ' % (label, elapsed)
            if detail is not None:
                if callable(detail):
                    text += detail(*args, **kwargs)
                else:
                    text += detail.format(*args, **kwargs)
            logger.debug(text)
            return res

        return wrapper

    if callable(label_or_callable):
        return inner(label_or_callable)
    return inner


def make_logger():
    """Create and configure the 'semshi' logger from the environment.

    Without ``SEMSHI_LOG_FILE`` the logger has level ERROR and no handler of
    its own. With it, records are written to that file and the level is taken
    from ``SEMSHI_LOG_LEVEL``: a level name in any letter case, or a
    non-negative integer. A missing or unrecognised value leaves the level at
    ERROR instead of raising while the module is being imported.

    Returns:
        The configured ``logging.Logger``.
    """
    log = logging.getLogger('semshi')
    log.setLevel(logging.ERROR)
    log_file = os.environ.get('SEMSHI_LOG_FILE')
    if log_file:
        # FileHandler stores the absolute path in `baseFilename`; compare
        # against it so calling this function again doesn't attach a second
        # handler (which would duplicate every log line).
        target = os.path.abspath(log_file)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in log.handlers)
        if not already_attached:
            log.addHandler(logging.FileHandler(log_file))

        level = os.environ.get('SEMSHI_LOG_LEVEL', '').strip().upper()
        if level:
            try:
                log.setLevel(int(level) if level.isdigit() else level)
            except (ValueError, TypeError):
                # Unknown level name: keep the ERROR default set above.
                pass
    log.debug('Semshi logger started.')
    return log


logger = make_logger()
def make_parser(code):
    """Return a fresh `Parser` that has already parsed the dedented `code`.

    Unlike `parse`, the parse result itself is not needed here; callers get
    the parser object so they can inspect or re-use its state.
    """
    parser = Parser()
    # The return value of parse() is intentionally discarded.
    parser.parse(dedent(code))
    return parser
$HOME/.local/bin/nvim 66 | 67 | if [ -n "${{ matrix.PYNVIM_MASTER }}" ]; then 68 | python3 -m pip install 'pynvim @ git+https://github.com/neovim/pynvim.git' 69 | else 70 | python3 -m pip install pynvim 71 | fi 72 | python3 -c 'import pynvim; print("__version__ =", pynvim.__version__)' 73 | python3 -c 'import pynvim; print(pynvim.__file__)' 74 | echo "$HOME/.local/bin" >> $GITHUB_PATH 75 | 76 | - name: Check neovim version 77 | run: | 78 | nvim --version 79 | 80 | - name: Install dependencies 81 | run: | 82 | python3 -m pip install pytest codecov pytest-cov 83 | 84 | - name: Run tests 85 | run: | 86 | pytest -xsv --cov semshi/ --cov-report term-missing:skip-covered --ignore test/test_fuzz.py test/ 87 | 88 | - name: Test coverage 89 | run: codecov 90 | 91 | -------------------------------------------------------------------------------- /plugin/semshi.vim: -------------------------------------------------------------------------------- 1 | " plugin/semshi.vim 2 | " vim: set ts=4 sts=4 sw=4: 3 | 4 | " These options can't be initialized in the Python plugin since they must be 5 | " known immediately. 6 | let g:semshi#filetypes = get(g:, 'semshi#filetypes', ['python']) 7 | let g:semshi#simplify_markup = get(g:, 'semshi#simplify_markup', v:true) 8 | let g:semshi#no_default_builtin_highlight = get(g:, 'semshi#no_default_builtin_highlight', v:true) 9 | 10 | function! s:simplify_markup() 11 | autocmd FileType python call s:simplify_markup_extra() 12 | 13 | " For python-syntax plugin 14 | let g:python_highlight_operators = 0 15 | endfunction 16 | 17 | function! 
" Ensure the rplugin manifest is up to date, i.e. the :Semshi command exists.
"
" Returns v:true if :Semshi is available. Otherwise defines a placeholder
" :Semshi command that prints a hint, schedules a one-time asynchronous
" notification (when Lua is available) and returns v:false. After the first
" failure every further call returns v:false right away.
function! s:check_rplugin_manifest() abort
  if exists('s:semshi_rplugin_error') > 0
    " Already reported once; don't notify again.
    return v:false
  endif
  if exists(':Semshi') > 0
    return v:true
  endif
  let s:semshi_rplugin_error = 1
  command! -nargs=* Semshi call nvim_err_writeln(":Semshi not found. Run :UpdateRemotePlugins.")

  " notify with an asynchronous error message
  if exists(':lua') && has('nvim-0.5.0') > 0
    " The level must be the numeric vim.log.levels.ERROR: the built-in
    " vim.notify() compares against that value, so the string 'ERROR' would
    " be displayed as an ordinary message.
    lua << EOF
    vim.schedule(function()
      vim.notify(":Semshi not found. Run :UpdateRemotePlugins.", vim.log.levels.ERROR, { title = "semshi" })
    end)
EOF
  endif
  return v:false
endfunction
74 | " Instead, an asynchronous notification that something is broken will be made. 75 | return 76 | endif 77 | 78 | let l:ft = expand('') 79 | if index(g:semshi#filetypes, l:ft) != -1 80 | if !get(b:, 'semshi_attached', v:false) 81 | Semshi enable 82 | endif 83 | else 84 | if get(b:, 'semshi_attached', v:false) 85 | Semshi disable 86 | endif 87 | endif 88 | endfunction 89 | 90 | function! semshi#buffer_attach() 91 | if get(b:, 'semshi_attached', v:false) 92 | return 93 | endif 94 | let b:semshi_attached = v:true 95 | augroup SemshiEvents 96 | autocmd! * 97 | autocmd BufEnter call SemshiBufEnter(+expand(''), line('w0'), line('w$')) 98 | autocmd BufLeave call SemshiBufLeave() 99 | autocmd VimResized call SemshiVimResized(line('w0'), line('w$')) 100 | autocmd TextChanged call SemshiTextChanged() 101 | autocmd TextChangedI call SemshiTextChanged() 102 | autocmd CursorMoved call SemshiCursorMoved(line('w0'), line('w$')) 103 | autocmd CursorMovedI call SemshiCursorMoved(line('w0'), line('w$')) 104 | augroup END 105 | call SemshiBufEnter(bufnr('%'), line('w0'), line('w$')) 106 | endfunction 107 | 108 | function! semshi#buffer_detach() 109 | let b:semshi_attached = v:false 110 | augroup SemshiEvents 111 | autocmd! * 112 | augroup END 113 | endfunction 114 | 115 | function! semshi#buffer_wipeout() 116 | try 117 | call SemshiBufWipeout(+expand('')) 118 | catch /:E117:/ 119 | " UpdateRemotePlugins probably not done yet, ignore 120 | endtry 121 | endfunction 122 | 123 | function! 
# e.g. "global" -> "semshiGlobal"
hl_groups: Dict[str, str] = {}


def group(s):
    """Register and return the highlight group name for the short name `s`.

    The name is 'semshi' followed by `s` with its first character
    capitalized. The mapping short name -> group name is recorded in
    `hl_groups`.
    """
    label = 'semshi' + s[0].capitalize() + s[1:]
    hl_groups[s] = label
    return label


UNRESOLVED = group('unresolved')
ATTRIBUTE = group('attribute')
BUILTIN = group('builtin')
FREE = group('free')
GLOBAL = group('global')
PARAMETER = group('parameter')
PARAMETER_UNUSED = group('parameterUnused')
SELF = group('self')
IMPORTED = group('imported')
LOCAL = group('local')
SELECTED = group('selected')

# Names treated as builtins in addition to those of the `builtins` module.
more_builtins = {'__file__', '__path__', '__cached__'}
# NOTE: this rebinds the name of the imported `builtins` module to a set.
builtins = set(vars(builtins)) | more_builtins


class Node:
    """A node in the source code.

    A node is an occurrence of a name at a position, together with the chain
    of symbol tables (`env`, outermost first) it appears in and the highlight
    group derived from that.
    """
    # Highlight ID for selected nodes
    MARK_ID = 31400
    # Highlight ID counter (chosen arbitrarily)
    id_counter = count(314001)

    __slots__ = [
        'id', 'name', 'lineno', 'col', 'end', 'env', 'symname', 'symbol',
        'hl_group', 'target', '_tup'
    ]

    def __init__(self, name, lineno, col, env, target=None, hl_group=None):
        """Create a node.

        Args:
            name: The name as written in the source.
            lineno: Line number of the name.
            col: Byte column where the name starts.
            env: List of symbol tables enclosing the name, outermost first.
            target: The target node for an attribute.
            hl_group: Highlight group; derived from the symbol when None.
                Nodes created with ATTRIBUTE get no symbol lookup.
        """
        self.id = next(Node.id_counter)
        self.name = name
        self.lineno = lineno
        self.col = col
        # Encode the name to get the byte length, not the number of chars
        self.end = self.col + len(self.name.encode('utf-8'))
        self.env = env
        self.symname = self._make_symname(name)
        # The target node for an attribute
        self.target = target

        if hl_group == ATTRIBUTE:
            self.symbol = None
        else:
            self.symbol = self._lookup_symbol(self.env, self.symname)

        if hl_group is None:
            hl_group = self._make_hl_group()

        self.hl_group = hl_group
        self.update_tup()

    def update_tup(self):
        """Update tuple used for comparing with other nodes."""
        self._tup = (self.lineno, self.col, self.hl_group, self.name)

    def __lt__(self, other):
        # Returning NotImplemented lets Python raise the usual TypeError for
        # unsupported operands instead of an AttributeError on `_tup`.
        if not isinstance(other, Node):
            return NotImplemented
        return self._tup < other._tup  # pylint: disable=protected-access

    def __eq__(self, other):
        # A node never equals a non-node; don't blow up on e.g. `node == None`.
        if not isinstance(other, Node):
            return NotImplemented
        return self._tup == other._tup  # pylint: disable=protected-access

    def __hash__(self):
        # Currently only required for tests
        return hash(self._tup)

    def __repr__(self):
        return '<%s %s %s (%s, %s) %d>' % (
            self.name,
            self.hl_group[6:],  # strip the 'semshi' prefix
            '.'.join([x.get_name() for x in self.env]),
            self.lineno,
            self.col,
            self.id,
        )

    @staticmethod
    def _lookup_symbol(env, symname):
        """Return the symbol for `symname` from the innermost table of `env`
        that knows it, or None if no table does."""
        # Lookup a symbol in the current context,
        # from possibly nested symbol tables.
        for table in reversed(env):
            try:
                return table.lookup(symname)
            except KeyError:
                pass  # move up to the parent scope/symtable

        # no matching symbol found
        return None

    def _make_hl_group(self):
        """Return highlight group the node belongs to.

        Side effect: when the node refers to a parameter, its name is removed
        from the `unused_params` mapping of the table that owns the parameter.
        (`unused_params` and `self_param` are attributes attached to the
        tables outside this module.)
        """
        sym = self.symbol
        name = self.name

        if sym is None:
            # With PEP-563 (postponed annotations), symtable does not
            # return a symbol for an unresolved node.
            if name in builtins:
                return BUILTIN
            return UNRESOLVED

        if sym.is_parameter():
            table = self.env[-1]
            # We have seen the node, so remove from unused parameters
            table.unused_params.pop(self.name, None)
            try:
                self_param = table.self_param
            except AttributeError:
                pass
            else:
                if self_param == name:
                    return SELF
            return PARAMETER
        if sym.is_free():
            table = self._ref_function_table()
            if table is not None:
                table.unused_params.pop(self.name, None)
            return FREE
        if sym.is_imported():
            return IMPORTED
        if sym.is_local() and not sym.is_global():
            return LOCAL
        if sym.is_global():
            try:
                global_sym = self.env[0].lookup(name)
            except KeyError:
                pass
            else:
                if global_sym.is_assigned():
                    return GLOBAL
                if name in builtins:
                    return BUILTIN
                if global_sym.is_imported():
                    return IMPORTED
                return UNRESOLVED
        if name in builtins:
            return BUILTIN
        return UNRESOLVED

    def _make_symname(self, name):
        """Return actual symbol name.

        The symname may be different due to name mangling.
        """
        # Check if the name is a candidate for name mangling
        if not name.startswith('__') or name.endswith('__'):
            return name
        try:
            cls = next(t for t in reversed(self.env)
                       if t.get_type() == 'class')
        except StopIteration:
            # Not inside a class, so no candidate for name mangling
            return name
        stripped = cls.get_name().lstrip('_')
        if not stripped:
            # Python does not mangle private names inside a class whose name
            # consists only of underscores.
            return name
        return '_' + stripped + name

    def _ref_function_table(self):
        """Return the innermost table in which the name is a parameter, or
        None if there is none."""
        for table in reversed(self.env):
            try:
                symbol = table.lookup(self.name)
            except KeyError:
                continue
            if symbol.is_parameter():
                return table
        return None

    def base_table(self):
        """Return base symtable.

        The base symtable is the lowest scope with an associated symbol.
        Returns None if no such scope is found.
        """
        if self.hl_group == ATTRIBUTE:
            return self.env[-1]

        if self.symbol:
            if self.symbol.is_global():
                return self.env[0]
            if self.symbol.is_local() and not self.symbol.is_free():
                return self.env[-1]

        for table in reversed(self.env):
            # Class scopes don't extend to enclosed scopes
            if table.get_type() == 'class':
                continue
            try:
                symbol = table.lookup(self.name)
            except KeyError:
                continue
            if symbol.is_local() and not symbol.is_free():
                return table
        return None

    @property
    def pos(self):
        """The (lineno, col) position of the node."""
        return (self.lineno, self.col)
| from .handler import BufferHandler 14 | from .node import hl_groups 15 | 16 | # pylint: disable=consider-using-f-string 17 | 18 | _subcommands = {} 19 | 20 | 21 | def subcommand(func=None, needs_handler=False, silent_fail=True): 22 | """Decorator to register `func` as a ":Semshi [...]" subcommand. 23 | 24 | If `needs_handler`, the subcommand will fail if no buffer handler is 25 | currently active. If `silent_fail`, it will fail silently, otherwise an 26 | error message is printed. 27 | """ 28 | if func is None: 29 | return partial(subcommand, 30 | needs_handler=needs_handler, 31 | silent_fail=silent_fail) 32 | 33 | @wraps(func) 34 | def wrapper(self, *args, **kwargs): 35 | # pylint: disable=protected-access 36 | if self._options is None: 37 | self._init_with_vim() 38 | if needs_handler and self._cur_handler is None: 39 | if not silent_fail: 40 | self.echo_error('Semshi is not enabled in this buffer!') 41 | return 42 | func(self, *args, **kwargs) 43 | 44 | _subcommands[func.__name__] = wrapper 45 | return wrapper 46 | 47 | 48 | @pynvim.plugin 49 | class Plugin: 50 | """Semshi Neovim plugin. 51 | 52 | The plugin handles vim events and commands, and delegates them to a buffer 53 | handler. (Each buffer is handled by a semshi.BufferHandler instance.) 54 | """ 55 | 56 | def __init__(self, vim: pynvim.api.Nvim): 57 | self._vim = vim 58 | 59 | # A mapping (buffer number -> buffer handler) 60 | self._handlers = {} 61 | # The currently active buffer handler 62 | self._cur_handler: Optional[BufferHandler] = None 63 | self._options = None 64 | 65 | # Python version check 66 | if (3, 7) <= sys.version_info <= (3, 13, 9999): 67 | self._disabled = False 68 | else: 69 | self._disabled = True 70 | self.echom("Semshi currently supports Python 3.7 - 3.13. 
" + 71 | "(Current: {})".format(sys.version.split()[0])) 72 | 73 | def echom(self, msg: str): 74 | args = ([[msg, "WarningMsg"]], True, {}) 75 | self._vim.api.echo(*args) 76 | 77 | def _init_with_vim(self): 78 | """Initialize with vim available. 79 | 80 | Initialization code which interacts with vim can't be safely put in 81 | __init__ because vim itself may not be fully started up. 82 | """ 83 | self._options = Options(self._vim) 84 | 85 | def echo(self, *msgs): 86 | msg = ' '.join([str(m) for m in msgs]) 87 | self._vim.out_write(msg + '\n') 88 | 89 | def echo_error(self, *msgs): 90 | msg = ' '.join([str(m) for m in msgs]) 91 | self._vim.err_write(msg + '\n') 92 | 93 | # Must not be async here because we have to make sure that switching the 94 | # buffer handler is completed before other events are handled. 95 | @pynvim.function('SemshiBufEnter', sync=True) 96 | def event_buf_enter(self, args): 97 | buf_num, view_start, view_stop = args 98 | self._select_handler(buf_num) 99 | assert self._cur_handler is not None 100 | self._update_viewport(view_start, view_stop) 101 | self._cur_handler.update() 102 | self._mark_selected() 103 | 104 | @pynvim.function('SemshiBufLeave', sync=True) 105 | def event_buf_leave(self, _): 106 | self._cur_handler = None 107 | 108 | @pynvim.function('SemshiBufWipeout', sync=True) 109 | def event_buf_wipeout(self, args): 110 | self._remove_handler(args[0]) 111 | 112 | @pynvim.function('SemshiVimResized', sync=False) 113 | def event_vim_resized(self, args): 114 | self._update_viewport(*args) 115 | self._mark_selected() 116 | 117 | @pynvim.function('SemshiCursorMoved', sync=False) 118 | def event_cursor_moved(self, args): 119 | if self._cur_handler is None: 120 | # CursorMoved may trigger before BufEnter, so select the buffer if 121 | # we didn't enter it yet. 
122 | self.event_buf_enter((self._vim.current.buffer.number, *args)) 123 | return 124 | self._update_viewport(*args) 125 | self._mark_selected() 126 | 127 | @pynvim.function('SemshiTextChanged', sync=False) 128 | def event_text_changed(self, _): 129 | if self._cur_handler is None: 130 | return 131 | # Note: TextChanged event doesn't trigger if text was changed in 132 | # unfocused buffer via e.g. nvim_buf_set_lines(). 133 | self._cur_handler.update() 134 | 135 | @pynvim.autocmd('VimLeave', sync=True) 136 | def event_vim_leave(self): 137 | for handler in self._handlers.values(): 138 | handler.shutdown() 139 | 140 | @pynvim.command( 141 | 'Semshi', 142 | nargs='*', # type: ignore 143 | complete='customlist,SemshiComplete', 144 | sync=True, 145 | ) 146 | def cmd_semshi(self, args): 147 | if not args: 148 | filetype = cast(pynvim.api.Buffer, 149 | self._vim.current.buffer).options.get('filetype') 150 | py_filetypes = self._vim.vars.get('semshi#filetypes', []) 151 | if filetype in py_filetypes: # for python buffers 152 | self._vim.command('Semshi status') 153 | else: # non-python 154 | self.echo('This is semshi.') 155 | return 156 | 157 | try: 158 | func = _subcommands[args[0]] 159 | except KeyError: 160 | self.echo_error('Subcommand not found: %s' % args[0]) 161 | return 162 | func(self, *args[1:]) 163 | 164 | @staticmethod 165 | @pynvim.function('SemshiComplete', sync=True) 166 | def func_complete(arg): 167 | lead, *_ = arg 168 | return [c for c in _subcommands if c.startswith(lead)] 169 | 170 | @pynvim.function('SemshiInternalEval', sync=True) 171 | def _internal_eval(self, args): 172 | """Eval Python code in plugin context. 173 | 174 | Only used for testing. 
175 | """ 176 | plugin = self # noqa pylint: disable=unused-variable 177 | return eval(args[0]) # pylint: disable=eval-used 178 | 179 | @subcommand 180 | def enable(self): 181 | if self._disabled: 182 | return 183 | self._attach_listeners() 184 | self._select_handler(self._vim.current.buffer) 185 | self._update_viewport(*self._vim.eval('[line("w0"), line("w$")]')) 186 | self.highlight() 187 | 188 | @subcommand(needs_handler=True) 189 | def disable(self): 190 | self.clear() 191 | self._detach_listeners() 192 | self._cur_handler = None 193 | self._remove_handler(self._vim.current.buffer) 194 | 195 | @subcommand 196 | def toggle(self): 197 | if self._listeners_attached(): 198 | self.disable() 199 | else: 200 | self.enable() 201 | 202 | @subcommand(needs_handler=True) 203 | def pause(self): 204 | self._detach_listeners() 205 | 206 | @subcommand(needs_handler=True, silent_fail=False) 207 | def highlight(self): 208 | assert self._cur_handler 209 | self._cur_handler.update(force=True, sync=True) 210 | 211 | @subcommand(needs_handler=True) 212 | def clear(self): 213 | assert self._cur_handler 214 | self._cur_handler.clear_highlights() 215 | 216 | @subcommand(needs_handler=True, silent_fail=False) 217 | def rename(self, new_name=None): 218 | assert self._cur_handler 219 | self._cur_handler.rename(self._vim.current.window.cursor, new_name) 220 | 221 | @subcommand(needs_handler=True, silent_fail=False) 222 | def goto(self, *args, **kwargs): 223 | assert self._cur_handler 224 | self._cur_handler.goto(*args, **kwargs) 225 | 226 | @subcommand(needs_handler=True, silent_fail=False) 227 | def error(self): 228 | assert self._cur_handler 229 | self._cur_handler.show_error() 230 | 231 | @subcommand 232 | def status(self): 233 | if self._disabled: 234 | self.echo('Semshi is disabled: unsupported python version.') 235 | return 236 | 237 | buffer: pynvim.api.Buffer = self._vim.current.buffer 238 | attached: bool = buffer.vars.get('semshi_attached', False) 239 | 240 | syntax_error = 
'(not attached)' 241 | if self._cur_handler: 242 | syntax_error = str(self._cur_handler.syntax_error or '(none)') 243 | 244 | self.echo('\n'.join([ 245 | 'Semshi is {attached} on (bufnr={bufnr})', 246 | '- current handler: {handler}', 247 | '- handlers: {handlers}', 248 | '- syntax error: {syntax_error}', 249 | ]).format( 250 | attached=attached and "attached" or "detached", 251 | bufnr=str(buffer.number), 252 | handler=self._cur_handler, 253 | handlers=self._handlers, 254 | syntax_error=syntax_error, 255 | )) 256 | 257 | def _select_handler(self, buf_or_buf_num): 258 | """Select handler for `buf_or_buf_num`.""" 259 | if isinstance(buf_or_buf_num, int): 260 | buf = None 261 | buf_num = buf_or_buf_num 262 | else: 263 | buf = buf_or_buf_num 264 | buf_num = buf.number 265 | try: 266 | handler = self._handlers[buf_num] 267 | except KeyError: 268 | if buf is None: 269 | buf = self._vim.buffers[buf_num] 270 | assert self._options is not None, "must have been initialized" 271 | handler = BufferHandler(buf, self._vim, self._options) 272 | self._handlers[buf_num] = handler 273 | self._cur_handler = handler 274 | 275 | def _remove_handler(self, buf_or_buf_num): 276 | """Remove handler for buffer with the number `buf_num`.""" 277 | if isinstance(buf_or_buf_num, int): 278 | buf_num = buf_or_buf_num 279 | else: 280 | buf_num = buf_or_buf_num.number 281 | try: 282 | handler = self._handlers.pop(buf_num) 283 | except KeyError: 284 | return 285 | else: 286 | handler.shutdown() 287 | 288 | def _update_viewport(self, start, stop): 289 | if self._cur_handler: 290 | self._cur_handler.viewport(start, stop) 291 | 292 | def _mark_selected(self): 293 | assert self._options is not None, "must have been initialized" 294 | if not self._options.mark_selected_nodes: 295 | return 296 | try: 297 | handler = self._cur_handler 298 | if handler: 299 | cursor = self._vim.current.window.cursor 300 | handler.mark_selected(cursor) 301 | except pynvim.api.NvimError as ex: 302 | # Ignore "Invalid window 
ID" errors (see wookayin/semshi#3) 303 | if str(ex).startswith("Invalid window id:"): 304 | return 305 | 306 | raise ex # Re-raise other errors. 307 | 308 | def _attach_listeners(self): 309 | self._vim.call('semshi#buffer_attach') 310 | 311 | def _detach_listeners(self): 312 | self._vim.call('semshi#buffer_detach') 313 | 314 | def _listeners_attached(self): 315 | """Return whether event listeners are attached to the current buffer. 316 | """ 317 | return self._vim.eval('get(b:, "semshi_attached", v:false)') 318 | 319 | 320 | class Options: 321 | """Plugin options. 322 | 323 | The options will only be read and set once on init. 324 | """ 325 | _defaults = { 326 | 'filetypes': ['python'], 327 | 'excluded_hl_groups': ['local'], 328 | 'mark_selected_nodes': 1, 329 | 'no_default_builtin_highlight': True, 330 | 'simplify_markup': True, 331 | 'error_sign': True, 332 | 'error_sign_delay': 1.5, 333 | 'always_update_all_highlights': False, 334 | 'tolerate_syntax_errors': True, 335 | 'update_delay_factor': .0, 336 | 'self_to_attribute': True, 337 | } 338 | filetypes: List[str] 339 | excluded_hl_groups: List[str] 340 | mark_selected_nodes: Literal[0, 1, 2] 341 | no_default_builtin_highlight: bool 342 | simplify_markup: bool 343 | error_sign: bool 344 | error_sign_delay: float 345 | always_update_all_highlights: bool 346 | tolerate_syntax_errors: bool 347 | update_delay_factor: float 348 | self_to_attribute: bool 349 | 350 | def __init__(self, vim: pynvim.api.Nvim): 351 | for key, val_default in Options._defaults.items(): 352 | val = vim.vars.get('semshi#' + key, val_default) 353 | # vim.vars doesn't support setdefault(), so set value manually 354 | vim.vars['semshi#' + key] = val 355 | try: 356 | converter = getattr(Options, '_convert_' + key) 357 | except AttributeError: 358 | pass 359 | else: 360 | val = converter(val) 361 | setattr(self, key, val) 362 | 363 | @staticmethod 364 | def _convert_excluded_hl_groups(items: Sequence[str]) -> List[str]: 365 | try: 366 | return 
[hl_groups[g] for g in items] 367 | except KeyError as e: 368 | # TODO Use err_write instead? 369 | raise ValueError( 370 | f'"{e.args[0]}" is an unknown highlight group.') from e 371 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semshi (A maintained fork) 2 | 3 | [![Build Status](https://github.com/wookayin/semshi/actions/workflows/ci.yml/badge.svg)](https://github.com/wookayin/semshi/actions) 4 | ![Python Versions](https://img.shields.io/badge/python-3.7,%203.8,%203.9,%203.10,%203.11,%203.12,%203.13-blue.svg) 5 | 6 | 7 | 8 | Semshi provides semantic highlighting for Python in Neovim. 9 | 10 | > [!NOTE] 11 | > This is a maintained fork of [the original repository, numirias/semshi](https://github.com/numirias/semshi), which has been abandoned since 2021. 12 | > Please see [release notes](https://github.com/wookayin/semshi/releases) to see what has been changed and updated in this fork. 13 | 14 | > [!NOTE] 15 | > For [LSP-based semantic highlight](https://neovim.io/doc/user/lsp.html#lsp-semantic-highlight), I strongly recommend 16 | > [basedpyright](https://github.com/DetachHead/basedpyright), a Python language server that natively supports LSP's semantic tokens, in place of semshi. 17 | 18 | 19 | Unlike regex-based syntax highlighters, Semshi understands Python code and performs static analysis as you type. It builds a syntax tree and symbol tables to highlight names based on their scope and context. This makes code easier to read and lets you quickly identify missing imports, unused arguments, misspelled names, and more.
20 | 21 | | With Semshi | Without Semshi | 22 | | --- | --- | 23 | | ![After](https://i.imgur.com/rDBSM8s.png) | ![Before](https://i.imgur.com/t40TNZ6.png) | 24 | 25 | In the above example, you can easily distinguish arguments (blue), instance attributes (teal), globals (orange), unresolved globals (yellow underlined), etc. Also, Semshi understands that the first `list` is assigned locally, while the default highlighter still shows it as builtin. 26 | 27 | ## Features 28 | 29 | - Different highlighting of locals, globals, imports, used and unused function parameters, builtins, attributes, free and unresolved names. 30 | 31 | - Scope-aware marking and renaming of related nodes. 32 | 33 | ![Renaming](https://i.imgur.com/5zWRFyg.gif) 34 | 35 | - Indication of syntax errors. 36 | 37 | ![Syntax errors](https://i.imgur.com/tCj9myJ.gif) 38 | 39 | - Jumping between classes, functions and related names. (In modern neovim, using LSP is more recommended.) 40 | 41 | ## Installation 42 | 43 | - You need Neovim with Python 3 support (`:echo has("python3")`) with Python >= 3.7. To install the Python provider run: 44 | 45 | python3 -m pip install pynvim --upgrade 46 | 47 | ...for `python3` as configured in `g:python3_host_prog`. If you don't set this variable, neovim will use `python3` with respect to `$PATH`. 48 | 49 | - Note: Python 3.6 support has been dropped. If you have to use Python 3.6, use the v0.2.0 version. 50 | 51 | - Add `wookayin/semshi` via your favorite plugin manager. 52 | 53 | - [vim-plug](https://github.com/junegunn/vim-plug): 54 | ```vim 55 | Plug 'wookayin/semshi', { 'do': ':UpdateRemotePlugins', 'tag': '*' } 56 | ``` 57 | ... and then run `:PlugInstall`. 
58 | 59 | - [lazy.nvim](https://github.com/folke/lazy.nvim): 60 | ```lua 61 | { 62 | "wookayin/semshi", 63 | build = ":UpdateRemotePlugins", 64 | version = "*", -- Recommended to use the latest release 65 | init = function() -- example, skip if you're OK with the default config 66 | vim.g['semshi#error_sign'] = false 67 | end, 68 | config = function() 69 | -- any config or setup that would need to be done after plugin loading 70 | end, 71 | } 72 | ``` 73 | 74 | 75 | 76 | - Make sure you run `:UpdateRemotePlugins` to update the plugin manifest, 77 | whenever the plugin is installed or updated; also when semshi doesn't work (e.g., 78 | [command `:Semshi` not found](https://github.com/numirias/semshi/issues/74), 79 | [Unknown function](https://github.com/numirias/semshi/issues/60), etc.) 80 | 81 | - Using [deoplete.nvim](https://github.com/Shougo/deoplete.nvim)? [Make sure it doesn't slow down Semshi.](#semshi-is-slow-together-with-deopletenvim) 82 | 83 | 84 | ## Configuration 85 | 86 | ### Options 87 | 88 | You can set these options in your vimrc (`~/.config/nvim/init.vim`), 89 | or if you're using [lazy.nvim](https://github.com/folke/lazy.nvim), in the `init` function of the plugin spec: 90 | 91 | | Option | Default | Description | 92 | | --- | --- | --- | 93 | | `g:semshi#filetypes` | `['python']` | List of file types on which to enable Semshi automatically. | 94 | | `g:semshi#excluded_hl_groups` | `['local']` | List of highlight groups not to highlight. Choose from `local`, `unresolved`, `attribute`, `builtin`, `free`, `global`, `parameter`, `parameterUnused`, `self`, `imported`. (It's recommended to keep `local` in the list because highlighting all locals in a large file can cause performance issues.) | 95 | | `g:semshi#mark_selected_nodes ` | `1` | Mark selected nodes (those with the same name and scope as the one under the cursor). Set to `2` to highlight the node currently under the cursor, too. 
| 96 | | `g:semshi#no_default_builtin_highlight` | `v:true` | Disable highlighting of builtins (`list`, `len`, etc.) by Vim's own Python syntax highlighter, because that's Semshi's job. If you turn it off, Vim may add incorrect highlights. | 97 | | `g:semshi#simplify_markup` | `v:true` | Simplify Python markup. Semshi introduces lots of new colors, so this option makes the highlighting of other syntax elements less distracting, binding most of them to `pythonStatement`. If you think Semshi messes with your colorscheme too much, try turning this off. | 98 | | `g:semshi#error_sign` | `v:true` | Show a sign in the sign column if a syntax error occurred. | 99 | | `g:semshi#error_sign_delay` | `1.5` | Delay in seconds until the syntax error sign is displayed. (A low delay time may distract while typing.) | 100 | | `g:semshi#always_update_all_highlights` | `v:false` | Update all visible highlights for every change. (Semshi tries to detect small changes and update only changed highlights. This can lead to some missing highlights. Turn this on for more reliable highlighting, but a small additional overhead.) | 101 | | `g:semshi#tolerate_syntax_errors` | `v:true` | Tolerate some minor syntax errors to update highlights even when the syntax is (temporarily) incorrect. (Smoother experience, but comes with some overhead.) | 102 | | `g:semshi#update_delay_factor` | `0.0` | Factor to delay updating of highlights. Updates will be delayed by `factor * number of lines` seconds. This is useful if instant re-parsing while editing large files stresses your CPU too much. A good starting point may be a factor of `0.0001` (that is, in a file with 1000 lines, parsing will be delayed by 0.1 seconds). | 103 | | `g:semshi#self_to_attribute` | `v:true` | Prefer the attribute of `self`/`cls` nodes. That is, when selecting the `self` in `self.foo`, Semshi will use the instance attribute `foo` instead. 
| 104 | 105 | ### Highlights 106 | 107 | Semshi sets these highlights/signs (which work best on dark backgrounds): 108 | 109 | ```VimL 110 | hi semshiLocal ctermfg=209 guifg=#ff875f 111 | hi semshiGlobal ctermfg=214 guifg=#ffaf00 112 | hi semshiImported ctermfg=214 guifg=#ffaf00 cterm=bold gui=bold 113 | hi semshiParameter ctermfg=75 guifg=#5fafff 114 | hi semshiParameterUnused ctermfg=117 guifg=#87d7ff cterm=underline gui=underline 115 | hi semshiFree ctermfg=218 guifg=#ffafd7 116 | hi semshiBuiltin ctermfg=207 guifg=#ff5fff 117 | hi semshiAttribute ctermfg=49 guifg=#00ffaf 118 | hi semshiSelf ctermfg=249 guifg=#b2b2b2 119 | hi semshiUnresolved ctermfg=226 guifg=#ffff00 cterm=underline gui=underline 120 | hi semshiSelected ctermfg=231 guifg=#ffffff ctermbg=161 guibg=#d7005f 121 | 122 | hi semshiErrorSign ctermfg=231 guifg=#ffffff ctermbg=160 guibg=#d70000 123 | hi semshiErrorChar ctermfg=231 guifg=#ffffff ctermbg=160 guibg=#d70000 124 | sign define semshiError text=E> texthl=semshiErrorSign 125 | ``` 126 | If you want to overwrite them in your vimrc, make sure to apply them *after* Semshi has set the defaults, e.g. in a function: 127 | 128 | ```VimL 129 | function MyCustomHighlights() 130 | hi semshiGlobal ctermfg=red guifg=#ff0000 131 | endfunction 132 | autocmd FileType python call MyCustomHighlights() 133 | ``` 134 | 135 | Also, if you want the highlight groups to persist across colorscheme switches, add: 136 | 137 | ```VimL 138 | autocmd ColorScheme * call MyCustomHighlights() 139 | ``` 140 | 141 | ## Usage 142 | 143 | Semshi parses and highlights code in all files with a `.py` extension. With every change to the buffer, the code is re-parsed and highlights are updated. When moving the cursor above a name, all nodes with the same name in the same scope are additionally marked. Semshi also attempts to tolerate syntax errors as you type. 
144 | 145 | 146 | ### Commands 147 | 148 | The following commands can be executed via `:Semshi <command>`: 149 | 150 | | Command | Description | 151 | | --- | --- | 152 | | `enable` | Enable highlighting for current buffer. | 153 | | `disable` | Disable highlighting for current buffer. | 154 | | `toggle` | Toggle highlighting for current buffer. | 155 | | `pause` | Like `disable`, but doesn't clear the highlights. | 156 | | `highlight` | Force update of highlights for current buffer. (Useful when for some reason highlighting hasn't been triggered.) | 157 | | `clear` | Clear all highlights in current buffer. | 158 | | `rename [new_name]` | Rename node under the cursor. If `new_name` isn't set, you're interactively prompted for the new name. | 159 | | `error` | Echo current syntax error message. | 160 | | `goto error` | Jump to current syntax error. | 161 | | `goto (name\|function\|class) (next\|prev\|first\|last)` | Jump to next/previous/first/last name/function/class. (See below for sample mappings.) | 162 | | `goto [highlight_group] (next\|prev\|first\|last)` | Jump to next/previous/first/last node with given highlight group. (Groups: `local`, `unresolved`, `attribute`, `builtin`, `free`, `global`, `parameter`, `parameterUnused`, `self`, `imported`) | 163 | 164 | Here are some possible mappings: 165 | 166 | ```VimL 167 | nmap <silent> <leader>rr :Semshi rename<CR> 168 | 169 | nmap <silent> <Tab> :Semshi goto name next<CR> 170 | nmap <silent> <S-Tab> :Semshi goto name prev<CR> 171 | 172 | nmap <silent> <leader>c :Semshi goto class next<CR> 173 | nmap <silent> <leader>C :Semshi goto class prev<CR> 174 | 175 | nmap <silent> <leader>f :Semshi goto function next<CR> 176 | nmap <silent> <leader>F :Semshi goto function prev<CR> 177 | 178 | nmap <silent> <leader>gu :Semshi goto unresolved first<CR> 179 | nmap <silent> <leader>gp :Semshi goto parameterUnused first<CR> 180 | 181 | nmap <silent> <leader>ee :Semshi error<CR> 182 | nmap <silent> <leader>ge :Semshi goto error<CR> 183 | ``` 184 | 185 | ## Limitations 186 | 187 | - Features like wildcard imports (`from foo import *`) or fancy metaprogramming may hide name bindings from simple static analysis.
In that case, Semshi can't pick them up and may show these names as unresolved or highlight them incorrectly. 188 | 189 | - While a syntax error is present (which can't be automatically compensated for), Semshi can't update any highlights. So, highlights may be temporarily incorrect or misplaced while typing. 190 | 191 | - Although Semshi parses the code asynchronously and is not *that* slow, editing large files may stress your CPU and cause highlighting delays. 192 | 193 | - Semshi works with the same syntax version as your Neovim Python 3 provider. This means you can't use Semshi on code that's Python 2-only or uses incompatible syntax features. (Support for different versions is planned. See [#19](https://github.com/numirias/semshi/issues/19)) 194 | 195 | 196 | ## FAQ 197 | 198 | ### How does Semshi compare to refactoring/completion plugins like [jedi-vim](https://github.com/davidhalter/jedi-vim)? 199 | 200 | Semshi's primary focus is to provide reasonably fast semantic highlighting to make code easier to read. It's meant to replace your syntax highlighter, not your refactoring tools. So, Semshi works great alongside refactoring and completion libraries like Jedi. 201 | 202 | ### Is Vim 8 supported? 203 | 204 | No. Semshi relies on Neovim's fast highlighting API to quickly update lots of highlights. Regular Vim unfortunately doesn't have an equivalent API. (If you think this can be implemented for Vim 8, let me know.) 205 | 206 | ### Is Python 2 supported? 207 | 208 | No. [Migrate your old python code!](https://pythonclock.org/) 209 | 210 | We require **Python >= 3.7**. 211 | Please note that the python interpreter being used by semshi is [the python3 rplugin provider](https://neovim.io/doc/user/provider.html#provider-python), i.e. see `g:python3_host_prog`. 212 | 213 | ### Semshi is too slow. 214 | 215 | Semshi should be snappy on reasonably-sized Python files with ordinary hardware. But some plugins hooking the same events (e.g.
[deoplete.nvim](https://github.com/Shougo/deoplete.nvim)) may cause significant delays. If you experience any performance problems, please file an issue. 216 | 217 | ### Semshi is slow together with [deoplete.nvim](https://github.com/Shougo/deoplete.nvim). 218 | 219 | Completion triggers may block Semshi from highlighting instantly. Try to increase Deoplete's `auto_complete_delay`, e.g.: 220 | 221 | ```VimL 222 | call deoplete#custom#option('auto_complete_delay', 100) 223 | ``` 224 | 225 | Or in older (<= 5.2) Deoplete versions: 226 | 227 | ```VimL 228 | let g:deoplete#auto_complete_delay = 100 229 | ``` 230 | 231 | ### There are some incorrect extra highlights. 232 | 233 | You might be using other Python syntax highlighters alongside (such as [python-syntax](https://github.com/vim-python/python-syntax)) which may interfere with Semshi. Try to disable these plugins if they cause problems. 234 | 235 | ### Sometimes highlights aren't updated. 236 | 237 | As you type code, you introduce temporary syntax errors, e.g. when opening a new bracket. Not all syntax errors can be compensated, so most of the time Semshi can only refresh highlights when the syntax becomes correct again. 238 | 239 | ## Contributing 240 | 241 | If you found a bug or have a suggestion, please don't hesitate to [file an issue](https://github.com/wookayin/semshi/issues/new). 
242 | -------------------------------------------------------------------------------- /rplugin/python3/semshi/parser.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import symtable 3 | from collections import deque 4 | from collections.abc import Iterable 5 | from functools import singledispatch 6 | from keyword import kwlist 7 | from token import INDENT, NAME, OP 8 | from tokenize import TokenError, tokenize 9 | from typing import List, Optional 10 | 11 | from .util import code_to_lines, debug_time, lines_to_code, logger 12 | from .visitor import visitor 13 | 14 | 15 | class UnparsableError(Exception): 16 | 17 | def __init__(self, error): 18 | super().__init__() 19 | self.error = error 20 | 21 | 22 | class Parser: 23 | """The parser parses Python code and generates source code nodes. For every 24 | run of `parse()` on changed source code, it returns the nodes that have 25 | been added and removed. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | exclude: Optional[List[str]] = None, 31 | fix_syntax: bool = True, 32 | ): 33 | self._excluded = exclude or [] 34 | self._fix_syntax = fix_syntax 35 | self._locations = {} 36 | self._nodes = [] 37 | self.lines = [] 38 | # Incremented after every parse call 39 | self.tick = 0 40 | # Holds the error of the current and previous run, so the buffer 41 | # handler knows if error signs need to be updated. 42 | self.syntax_errors = deque([None, None], maxlen=2) 43 | self.same_nodes = singledispatch(self.same_nodes) 44 | self.same_nodes.register(Iterable, self._same_nodes_cursor) 45 | 46 | @debug_time 47 | def parse(self, *args, **kwargs): 48 | """Wrapper for `_parse()`. 49 | 50 | Raises UnparsableError() if an unrecoverable error occurred. 
51 | """ 52 | try: 53 | return self._parse(*args, **kwargs) 54 | except (SyntaxError, RecursionError) as e: 55 | logger.debug('parsing error: %s', e) 56 | raise UnparsableError(e) from e 57 | finally: 58 | self.tick += 1 59 | 60 | @debug_time 61 | def _filter_excluded(self, nodes): 62 | return [n for n in nodes if n.hl_group not in self._excluded] 63 | 64 | def _parse(self, code, force=False): 65 | """Parse code and return tuple (`add`, `remove`) of added and removed 66 | nodes since last run. With `force`, all highlights are refreshed, even 67 | those that didn't change. 68 | """ 69 | self._locations.clear() 70 | old_lines = self.lines 71 | new_lines = code_to_lines(code) 72 | minor_change, change_lineno = self._minor_change(old_lines, new_lines) 73 | old_nodes = self._nodes 74 | new_nodes = self._make_nodes(code, new_lines, change_lineno) 75 | # Detecting minor changes keeps us from updating a lot of highlights 76 | # while the user is only editing a single line. 77 | if minor_change and not force: 78 | add, rem, keep = self._diff(old_nodes, new_nodes) 79 | self._nodes = keep + add 80 | else: 81 | add, rem = new_nodes, old_nodes 82 | self._nodes = add 83 | # Only assign new lines when nodes have been updated accordingly 84 | self.lines = new_lines 85 | logger.debug('[%d] nodes: +%d, -%d', self.tick, len(add), len(rem)) 86 | return (self._filter_excluded(add), self._filter_excluded(rem)) 87 | 88 | def _make_nodes(self, code, lines=None, change_lineno=None): 89 | """Return nodes in code. 90 | 91 | Runs AST visitor on code and produces nodes. We're passing both code 92 | *and* lines around to avoid lots of conversions. 
93 | """ 94 | if lines is None: 95 | lines = code_to_lines(code) 96 | try: 97 | ast_root, fixed_code, fixed_lines, error = \ 98 | self._fix_syntax_and_make_ast(code, lines, change_lineno) 99 | except SyntaxError as e: 100 | # Apparently, fixing syntax errors failed 101 | self.syntax_errors.append(e) 102 | raise 103 | if fixed_code is not None: 104 | code = fixed_code 105 | lines = fixed_lines 106 | try: 107 | symtable_root = self._make_symtable(code) 108 | except SyntaxError as e: 109 | # In some cases, the symtable() call raises a syntax error which 110 | # hasn't been raised earlier (such as duplicate arguments) 111 | self.syntax_errors.append(e) 112 | raise 113 | self.syntax_errors.append(error) 114 | return visitor(lines, symtable_root, ast_root) 115 | 116 | @debug_time 117 | def _fix_syntax_and_make_ast(self, code, lines, change_lineno): 118 | """Try to fix syntax errors in code (if present) and return AST, fixed 119 | code and list of fixed lines of code. 120 | 121 | Current strategy to fix syntax errors: 122 | - Try to build AST from original code. 123 | - If that fails, call _fix_line() on the line indicated by the 124 | SyntaxError exception and try to build AST again. 125 | - If that fails, do the same with the line of the last change. 126 | - If all attempts failed, raise original SyntaxError exception. 127 | """ 128 | # TODO Cache previous attempt? 129 | try: 130 | return self._make_ast(code), None, None, None 131 | except SyntaxError as e: 132 | orig_error = e 133 | error_idx = e.lineno - 1 134 | if not self._fix_syntax: 135 | # Don't even attempt to fix syntax errors. 
136 | raise orig_error 137 | new_lines = lines[:] 138 | # Save original line to restore later 139 | orig_line = new_lines[error_idx] 140 | new_lines[error_idx] = self._fix_line(orig_line) 141 | new_code = lines_to_code(new_lines) 142 | try: 143 | ast_root = self._make_ast(new_code) 144 | except SyntaxError: 145 | # Restore original line 146 | new_lines[error_idx] = orig_line 147 | # Fixing the line of the syntax error failed, so try again with the 148 | # line of last change. 149 | if change_lineno is None or change_lineno == error_idx: 150 | # Don't try to fix the changed line if it's unknown or the same 151 | # as the one we tried to fix before. 152 | raise orig_error from None 153 | new_lines[change_lineno] = self._fix_line(new_lines[change_lineno]) 154 | new_code = lines_to_code(new_lines) 155 | try: 156 | ast_root = self._make_ast(new_code) 157 | except SyntaxError: 158 | # All fixing attempts failed, so raise original syntax error. 159 | raise orig_error from None 160 | return ast_root, new_code, new_lines, orig_error 161 | 162 | @staticmethod 163 | def _fix_line(line): 164 | """Take a line of code which may have introduced a syntax error and 165 | return a modified version which is less likely to cause a syntax error. 166 | """ 167 | tokens = tokenize(iter([line.encode('utf-8')]).__next__) 168 | prev = None 169 | text = '' 170 | 171 | def add_token(token, filler): 172 | nonlocal text, prev 173 | text += (token.start[1] - len(text)) * filler + token.string 174 | prev = token 175 | 176 | try: 177 | for token in tokens: 178 | if token.type == INDENT: 179 | text += token.string 180 | elif (token.type == OP and token.string == '.' 
and prev 181 | and prev.type == NAME): 182 | add_token(token, ' ') 183 | elif token.type == NAME and token.string not in kwlist: 184 | if prev and prev.type == OP and prev.string == '.': 185 | add_token(token, ' ') 186 | else: 187 | add_token(token, '+') 188 | except TokenError as e: 189 | logger.debug('token error %s', e) 190 | if prev and prev.type == OP and prev.string == '.': 191 | # Cut superfluous dot from the end of line 192 | text = text[:-1] 193 | return text 194 | 195 | @staticmethod 196 | @debug_time 197 | def _make_ast(code): 198 | """Return AST for code.""" 199 | return ast.parse(code) 200 | 201 | @staticmethod 202 | @debug_time 203 | def _make_symtable(code): 204 | """Return symtable for code.""" 205 | return symtable.symtable(code, '?', 'exec') 206 | 207 | @staticmethod 208 | def _minor_change(old_lines, new_lines): 209 | """Determine whether a minor change between old and new lines occurred. 210 | Return (`minor_change`, `change_lineno`) where `minor_change` is True 211 | when at most one change occurred and `change_lineno` is the line number 212 | of the change. 213 | 214 | A minor change is a change in a single line while the total number of 215 | lines doesn't change. 
216 | """ 217 | if len(old_lines) != len(new_lines): 218 | # A different number of lines doesn't count as minor change 219 | return (False, None) 220 | old_iter = iter(old_lines) 221 | new_iter = iter(new_lines) 222 | diff_lineno = None 223 | lineno = 0 224 | try: 225 | while True: 226 | old_lines = next(old_iter) 227 | new_lines = next(new_iter) 228 | if old_lines != new_lines: 229 | if diff_lineno is not None: 230 | # More than one change must have happened 231 | return (False, None) 232 | diff_lineno = lineno 233 | lineno += 1 234 | except StopIteration: 235 | # We iterated through all lines with at most one change 236 | return (True, diff_lineno) 237 | 238 | @staticmethod 239 | @debug_time 240 | def _diff(old_nodes, new_nodes): 241 | """Return difference between iterables of nodes old_nodes and new_nodes 242 | as three lists of nodes to add, remove and keep. 243 | """ 244 | add_iter = iter(sorted(new_nodes)) 245 | rem_iter = iter(sorted(old_nodes)) 246 | add_nodes = [] 247 | rem_nodes = [] 248 | keep_nodes = [] 249 | try: 250 | add = rem = None 251 | while True: 252 | if add == rem: 253 | if add is not None: 254 | keep_nodes.append(add) 255 | # A new node needs to adopt the highlight ID of 256 | # corresponding currently highlighted node 257 | add.id = rem.id 258 | add = rem = None 259 | add = next(add_iter) 260 | rem = next(rem_iter) 261 | elif add < rem: 262 | add_nodes.append(add) 263 | add = None 264 | add = next(add_iter) 265 | elif rem < add: 266 | rem_nodes.append(rem) 267 | rem = None 268 | rem = next(rem_iter) 269 | except StopIteration: 270 | if add is not None: 271 | add_nodes.append(add) 272 | if rem is not None: 273 | rem_nodes.append(rem) 274 | add_nodes += list(add_iter) 275 | rem_nodes += list(rem_iter) 276 | return add_nodes, rem_nodes, keep_nodes 277 | 278 | @debug_time 279 | def node_at(self, cursor): 280 | """Return node at cursor position.""" 281 | lineno, col = cursor 282 | for node in self._nodes: 283 | if node.lineno == lineno and 
node.col <= col < node.end: 284 | return node 285 | return None 286 | 287 | # pylint: disable=method-hidden 288 | def same_nodes(self, cur_node, mark_original=True, use_target=True): 289 | """Return nodes with the same scope as cur_node. 290 | 291 | The same scope is to be understood as all nodes with the same base 292 | symtable. In some cases this can be ambiguous. 293 | """ 294 | if use_target: 295 | target = cur_node.target 296 | if target is not None: 297 | cur_node = target 298 | cur_name = cur_node.name 299 | base_table = cur_node.base_table() 300 | for node in self._nodes: 301 | if node.name != cur_name: 302 | continue 303 | if not mark_original and node is cur_node: 304 | continue 305 | if node.base_table() == base_table: 306 | yield node 307 | 308 | def _same_nodes_cursor(self, cursor, mark_original=True, use_target=True): 309 | """Return nodes with the same scope as node at the cursor position.""" 310 | cur_node = self.node_at(cursor) 311 | if cur_node is None: 312 | return [] 313 | return self.same_nodes(cur_node, mark_original, use_target) 314 | 315 | def locations_by_node_types(self, types): 316 | """Return locations of all AST nodes in code whose type is contained in 317 | `types`.""" 318 | types_set = frozenset(types) 319 | try: 320 | return self._locations[types_set] 321 | except KeyError: 322 | pass 323 | visitor = _LocationCollectionVisitor(types) 324 | try: 325 | ast_ = ast.parse(lines_to_code(self.lines)) 326 | except SyntaxError: 327 | return [] 328 | visitor.visit(ast_) 329 | locations = visitor.locations 330 | self._locations[types_set] = locations 331 | return locations 332 | 333 | def locations_by_hl_group(self, group): 334 | """Return locations of all nodes whose highlight group is `group`.""" 335 | return [n.pos for n in self._nodes if n.hl_group == group] 336 | 337 | 338 | class _LocationCollectionVisitor(ast.NodeVisitor): 339 | """Node vistor which collects the locations of all AST nodes of a given 340 | type.""" 341 | 342 | def 
def wait_for(func, cond=None, sleep=.001, tries=1000):
    """Poll `func` until its result is accepted and return that result.

    `func` is called up to `tries` times, sleeping `sleep` seconds after
    every rejected attempt. A result is accepted if `cond(result)` is truthy
    or, when no `cond` is given, if the result itself is truthy.

    Raises TimeoutError if no attempt produced an accepted result.
    """
    for _ in range(tries):
        res = func()
        accepted = res if cond is None else cond(res)
        if accepted:
            # Return the result in both modes so callers can always use it.
            return res
        time.sleep(sleep)
    raise TimeoutError('condition not met after %d tries' % tries)
@pytest.fixture(scope='function')
def start_vim(tmp_path):
    """Provide a factory which launches an embedded nvim on demand.

    The factory takes extra command line arguments (`argv`) which are
    appended to NVIM_ARGV, and an optional `file` to edit right away. An
    empty `file` (but not None) opens `foo.py` inside the test's temporary
    directory. The attached instance is returned wrapped in WrappedVim.
    """

    def launch(argv=None, file=None):
        extra_args = [] if argv is None else list(argv)
        nvim = pynvim.attach('child', argv=[*NVIM_ARGV, *extra_args])
        if file is None:
            return WrappedVim(nvim)
        target = file if file else tmp_path / 'foo.py'
        nvim.command('edit %s' % target)
        return WrappedVim(nvim)

    return launch
def test_switch_handler(vim, tmp_path):
    """Switching between buffers swaps the plugin's current handler."""

    def edit(filename):
        vim.command('edit %s' % (tmp_path / filename))

    def expect_node_names(expected):
        wait_for(
            lambda: vim.host_eval(
                '[n.name for n in plugin._cur_handler._parser._nodes]'),
            lambda names: names == expected,
        )

    def attached():
        return vim.eval('get(b:, "semshi_attached", v:false)')

    def no_current_handler():
        return vim.host_eval('plugin._cur_handler is None')

    # Two Python buffers, each with its own set of nodes
    edit('foo.py')
    vim.current.buffer[:] = ['aaa', 'bbb']
    expect_node_names(['aaa', 'bbb'])
    edit('bar.py')
    vim.current.buffer[:] = ['ccc']
    expect_node_names(['ccc'])
    vim.command('bnext')
    expect_node_names(['aaa', 'bbb'])

    # A non-Python buffer has no handler, but the existing ones are retained
    edit('bar.notpython')
    assert not attached()
    assert no_current_handler()
    assert vim.host_eval('len(plugin._handlers)') == 2

    # Going back re-activates a handler
    vim.command('bprev')
    assert attached()
    assert not no_current_handler()
def test_selected_nodes(vim, tmp_path):
    """Placing the cursor on a name selects the other node of that name."""
    vim.command('edit %s' % (tmp_path / 'foo.py'))
    vim.current.buffer[:] = ['aaa', 'aaa']

    def selected_positions():
        return vim.host_eval(
            '[n.pos for n in plugin._cur_handler._selected_nodes]')

    # (line to put the cursor on, positions expected to be selected)
    cases = (
        (1, [[2, 0]]),
        (2, [[1, 0]]),
    )
    for cursor_line, expected in cases:
        vim.call('setpos', '.', [0, cursor_line, 1])
        wait_for(selected_positions, lambda x, want=expected: x == want)
def test_option_mark_selected_nodes(start_vim):
    """g:semshi#mark_selected_nodes controls how many of three identical
    names end up selected: none for 0, two by default and all three for 2.
    """
    # (extra nvim arguments, expected number of selected nodes)
    cases = (
        (['--cmd', 'let g:semshi#mark_selected_nodes = 0'], 0),
        (None, 2),
        (['--cmd', 'let g:semshi#mark_selected_nodes = 2'], 3),
    )
    for argv, expected in cases:
        vim = start_vim(argv, file='')
        vim.current.buffer[:] = ['aaa', 'aaa', 'aaa']
        vim.wait_for_update_thread()
        num_selected = vim.host_eval(
            'len(plugin._cur_handler._selected_nodes)')
        assert num_selected == expected
def test_syntax_error_sign(start_vim):
    """The error sign is placed for a syntax error, removed once the error
    is gone, and never placed when disabled or still delayed."""
    jump_to_sign = 'exec "sign jump 314000 buffer=" . buffer_number("%")'

    def settle(nvim, lines):
        # Replace the buffer and give the plugin time to process it
        nvim.current.buffer[:] = lines
        nvim.wait_for_update_thread()
        time.sleep(SLEEP)

    def assert_no_sign(nvim):
        # Jumping to a sign which isn't placed is an error
        with pytest.raises(pynvim.api.NvimError):
            nvim.command(jump_to_sign)

    vim = start_vim(['--cmd', 'let g:semshi#error_sign_delay = 0'], file='')
    settle(vim, ['+'])
    vim.command(jump_to_sign)
    settle(vim, ['a'])
    assert_no_sign(vim)

    for option in (
            'let g:semshi#error_sign = 0',
            'let g:semshi#error_sign_delay = 1.0',
    ):
        vim = start_vim(['--cmd', option], file='')
        settle(vim, ['+'])
        assert_no_sign(vim)
def test_option_self_to_attribute(start_vim):
    """g:semshi#self_to_attribute decides what gets selected while the
    cursor is on `self` in `self.bar`."""
    buf = ['class Foo:', ' def foo(self): self.bar, self.bar']
    query = '[n.pos for n in plugin._cur_handler._selected_nodes]'
    # (extra nvim arguments, expected positions of the selected nodes)
    cases = (
        (None, [[2, 31]]),
        (['--cmd', 'let g:semshi#self_to_attribute = 0'], [[2, 9], [2, 26]]),
    )
    for argv, expected in cases:
        vim = start_vim(argv, file='')
        vim.current.buffer[:] = buf
        vim.current.window.cursor = [2, 16]
        wait_for(
            lambda vim=vim: vim.host_eval(query),
            lambda x, want=expected: x == want,
        )
def test_goto_name(start_vim):
    """`Semshi goto name next/prev` cycles through the occurrences of the
    name under the cursor, wrapping around at the ends."""
    vim = start_vim(file='')
    vim.current.buffer[:] = ['aaa, aaa, aaa']

    def cursor_is(expected):
        return tuple(vim.current.window.cursor) == expected

    # (command to run, cursor location expected afterwards)
    steps = (
        ('Semshi goto name next', (1, 5)),
        ('Semshi goto name next', (1, 10)),
        ('Semshi goto name next', (1, 0)),
        ('Semshi goto name prev', (1, 10)),
        ('Semshi goto name prev', (1, 5)),
    )
    for command, expected in steps:
        time.sleep(SLEEP)
        vim.command(command)
        wait_for(lambda want=expected: cursor_is(want))
def test_enable_disable(start_vim):
    """Disabling drops the current handler, enabling brings the nodes back;
    this works repeatedly."""
    vim = start_vim(file='')

    def has_one_node():
        num_nodes = vim.host_eval('len(plugin._cur_handler._parser._nodes)')
        return num_nodes == 1

    def handler_gone():
        return vim.host_eval('plugin._cur_handler is None')

    vim.current.buffer[:] = ['aaa']
    wait_for(has_one_node)
    for _ in range(2):
        vim.command('Semshi disable')
        wait_for(handler_gone)
        vim.command('Semshi enable')
        wait_for(has_one_node)
pytest.raises(pynvim.api.NvimError, match='.*not enabled.*'): 510 | vim.command('Semshi goto error') 511 | -------------------------------------------------------------------------------- /rplugin/python3/semshi/handler.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import threading 4 | import time 5 | from collections import defaultdict 6 | from typing import Optional 7 | 8 | import pynvim 9 | import pynvim.api 10 | from pynvim.api import Buffer, Nvim 11 | 12 | from . import plugin 13 | from .node import SELECTED, Node, hl_groups 14 | from .parser import Parser, UnparsableError 15 | from .util import debug_time, lines_to_code, logger 16 | 17 | ERROR_SIGN_ID = 314000 18 | ERROR_HL_ID = 313000 19 | 20 | 21 | class BufferHandler: 22 | """Handler for a buffer. 23 | 24 | The handler runs the parser, adds and removes highlights, keeps tracks of 25 | which highlights are visible and which ones need to be added or removed. 26 | """ 27 | 28 | def __init__(self, buf: Buffer, vim: Nvim, options: plugin.Options): 29 | self._buf = buf 30 | self._vim = vim 31 | self._options = options 32 | self._buf_num = buf.number 33 | self._parser = Parser(options.excluded_hl_groups, 34 | options.tolerate_syntax_errors) 35 | self._scheduled = False 36 | self._viewport_changed = False 37 | self._view = (0, 0) 38 | self._update_thread = None 39 | self._error_timer = None 40 | self._indicated_syntax_error = None 41 | # Nodes which are active but pending to be displayed because they are 42 | # in a currently invisible area. 43 | self._pending_nodes = [] 44 | # Nodes which are currently marked as a selected. We keep track of them 45 | # to check if they haven't changed between updates. 
def __repr__(self):
    """Return a short representation which identifies the handler by the
    number of the buffer it is responsible for."""
    # The format string needs a placeholder for the buffer number; applying
    # `%` to a string without one raises TypeError.
    return '<BufferHandler(%d)>' % self._buf_num
def _wait_for(self, func, sync=False):
    """Return `func()`. If not `sync`, run `func` in async context and
    block until result is available.

    Required for when we need the result of an API call from a thread.

    An exception raised by `func` is re-raised in the calling thread, both
    in the `sync` and in the async case.
    """
    if sync:
        return func()
    event = threading.Event()
    res = None
    error = None

    def wrapper():
        nonlocal res, error
        try:
            res = func()
        except Exception as exc:  # pylint: disable=broad-except
            # Hand the exception over to the waiting thread instead of
            # losing it in the async context.
            error = exc
        finally:
            # Always release the waiting thread. Without this, a failing
            # `func` would leave it blocked in `event.wait()` forever.
            event.set()

    self._vim.async_call(wrapper)
    event.wait()
    if error is not None:
        raise error
    return res
def _update_loop(self):
    """Run update steps until no further update has been scheduled.

    Before every step the loop sleeps for `update_delay_factor` times the
    number of parser lines, if the factor is positive. Once the loop is done,
    highlights for a viewport that changed in the meantime are added. Any
    exception is logged with its traceback and re-raised.
    """
    try:
        rerun = True
        while rerun:
            factor = self._options.update_delay_factor
            if factor > 0:
                time.sleep(factor * len(self._parser.lines))
            self._update_step(self._options.always_update_all_highlights)
            # Another update may have been requested while this one ran
            rerun = self._scheduled
            if rerun:
                self._scheduled = False
        if not self._viewport_changed:
            return
        self._add_visible_hls()
        self._viewport_changed = False
    except Exception:
        import traceback  # pylint: disable=import-outside-toplevel
        logger.error('Exception: %s', traceback.format_exc())
        raise
def _visible_and_hidden(self, nodes):
    """Split `nodes` into those inside the current view and those outside.

    Returns a tuple `(visible, hidden)` of lists. A node is visible if its
    line number lies within the view, both bounds included.
    """
    first, last = self._view
    # Indexed by the outcome of the visibility check: False -> 0, True -> 1
    buckets = ([], [])
    for node in nodes:
        buckets[first <= node.lineno <= last].append(node)
    hidden, visible = buckets
    return visible, hidden
def _schedule_update_error_sign(self):
    """Update the error indicator, either right away or after a delay.

    A timer from an earlier call is cancelled first. If an error is
    currently indicated, the indicator is updated immediately. Otherwise the
    update is deferred by `error_sign_delay` seconds.
    """
    pending = self._error_timer
    if pending is not None:
        pending.cancel()
    if self._indicated_syntax_error is None:
        # Delay update to prevent the error sign from flashing while typing.
        timer = threading.Timer(self._options.error_sign_delay,
                                self._update_error_indicator)
        self._error_timer = timer
        timer.start()
    else:
        self._update_error_indicator()
def _call_atomic_async(self, calls):
    """Submit `calls` to neovim's call_atomic in batches.

    `calls` is a list of API calls in the format `call_atomic` expects.
    Every batch is dispatched through `_wrap_async` and is skipped if the
    buffer turns out to be invalid by the time the batch is executed.
    """
    # Need to update in small batches to avoid
    # https://github.com/neovim/python-client/issues/310
    batch_size = 3000

    def _call_atomic(call_chunk, **kwargs):
        # when nvim_call_atomic is actually being executed
        # (due to an asynchronous call), the buffer might be gone.
        # To avoid 'invalid buffer id' errors, we validate the buffer.
        if not self._vim.api.buf_is_valid(self._buf):
            # Log the buffer number: `%d` can't format the buffer object.
            logger.debug('buffer %d was wiped out, skipping call_atomic',
                         self._buf_num)
            return None
        return self._vim.api.call_atomic(call_chunk, **kwargs)

    call_atomic = self._wrap_async(_call_atomic)
    for i in range(0, len(calls), batch_size):
        call_atomic(calls[i:i + batch_size], async_=True)
If `new_name` is None, prompt 322 | for new name.""" 323 | cur_node = self._parser.node_at(cursor) 324 | if cur_node is None: 325 | self._vim.out_write('Nothing to rename here.\n') 326 | return 327 | nodes = list( 328 | self._parser.same_nodes( 329 | cur_node, 330 | mark_original=True, 331 | use_target=self._options.self_to_attribute, 332 | )) 333 | num = len(nodes) 334 | if new_name is None: 335 | new_name = self._vim.eval('input("Rename %d nodes to: ")' % num) 336 | # Can't output a carriage return via out_write() 337 | self._vim.command('echo "\r"') 338 | if not new_name or new_name == cur_node.name: 339 | self._vim.out_write('Nothing renamed.\n') 340 | return 341 | lines = self._buf[:] 342 | lines_to_nodes = defaultdict(list) 343 | for node in nodes: 344 | lines_to_nodes[node.lineno].append(node) 345 | for lineno, nodes_in_line in lines_to_nodes.items(): 346 | offset = 0 347 | line = lines[lineno - 1] 348 | for node in sorted(nodes_in_line, key=lambda n: n.col): 349 | line = (line[:node.col + offset] + new_name + 350 | line[node.col + len(node.name) + offset:]) 351 | offset += len(new_name) - len(node.name) 352 | self._buf[lineno - 1] = line 353 | self._vim.out_write('%d nodes renamed.\n' % num) 354 | 355 | def goto(self, what, direction=None): 356 | """Go to next location of type `what` in direction `direction`.""" 357 | if what == 'error': 358 | self._goto_error() 359 | return 360 | # pylint: disable=import-outside-toplevel 361 | from ast import AsyncFunctionDef, ClassDef, FunctionDef 362 | here = tuple(self._vim.current.window.cursor) 363 | if what == 'name': 364 | cur_node = self._parser.node_at(here) 365 | if cur_node is None: 366 | return 367 | locs = sorted([ 368 | n.pos for n in self._parser.same_nodes( 369 | cur_node, use_target=self._options.self_to_attribute) 370 | ]) 371 | elif what == 'class': 372 | locs = self._parser.locations_by_node_types([ 373 | ClassDef, 374 | ]) 375 | elif what == 'function': 376 | locs = 
def _error_pos(self, error):
    """Return a position for the syntax error `error` which is guaranteed
    to be a valid position in the buffer.

    The position is a tuple of a 1-based line number and a 0-based column.
    The line number is clamped to the parsed lines and the column to the
    length of that line. A missing (None) line number or offset is treated
    as the first line or the first column, respectively.
    """
    lines = self._parser.lines
    if not lines:
        # No line to point into; fall back to the very first position.
        return (1, 0)
    # The reported line may be missing or lie outside the parsed lines, in
    # which case indexing with it would fail or address the wrong line.
    lineno = max(1, min(error.lineno or 1, len(lines)))
    line_length = len(lines[lineno - 1])
    # `error.offset` is 1-based; clamp to the line and convert to 0-based.
    offset = max(1, min(error.offset or 1, line_length)) - 1
    return (lineno, offset)
def next_location(here, locs, reverse=False):
    """Return the neighbour of `here` among the sorted locations `locs`.

    `here` is added to a copy of `locs` if it isn't one of them, so the
    caller's list is never modified. The neighbour is the next entry in
    sorted order, or the previous one if `reverse` is true. The search wraps
    around at both ends.
    """
    candidates = locs[:]
    if here not in candidates:
        candidates.append(here)
    ordered = sorted(candidates)
    step = -1 if reverse else 1
    neighbour = ordered.index(here) + step
    # The modulo makes the last location wrap to the first and vice versa.
    return ordered[neighbour % len(ordered)]
def advance(tokens, s=None, type=NAME):
    """Consume `tokens` up to and including the next matching token.

    A token matches if its type equals `type` and, depending on `s`:

    - `s` is None: any string is accepted,
    - `s` is a str: the token string must equal `s`,
    - otherwise: the token string must be contained in `s`.

    Returns the matching token. StopIteration propagates if the stream ends
    without a match.
    """
    if s is None:
        def matches(token):
            return token.type == type
    elif isinstance(s, str):
        def matches(token):
            return token.type == type and token.string == s
    else:
        def matches(token):
            return token.type == type and token.string in s
    # filter() is lazy, so nothing beyond the first match is consumed.
    return next(filter(matches, tokens))
94 | """ 95 | # Use type() because it's faster than the more idiomatic isinstance() 96 | type_ = type(node) 97 | if type_ is ast.Name: 98 | self._new_name(node) 99 | return 100 | if type_ in TYPE_VARS: # handle type variables (Python 3.12+) 101 | self._visit_typevar(node) 102 | return 103 | if type_ is ast.Attribute: 104 | self._add_attribute(node) 105 | self.visit(node.value) 106 | return 107 | if type_ in SKIP: 108 | return 109 | 110 | if type_ is ast.Try: 111 | self._visit_try(node) 112 | elif type_ is ast.ExceptHandler: 113 | self._visit_except(node) 114 | elif type_ in (ast.Import, ast.ImportFrom): 115 | self._visit_import(node) 116 | elif type_ is ast.arg: 117 | self._visit_arg(node) 118 | elif type_ in FUNCTION_BLOCKS: 119 | self._visit_arg_defaults(node) 120 | elif type_ in (ast.ListComp, ast.SetComp, ast.DictComp, 121 | ast.GeneratorExp): 122 | self._visit_comp(node) 123 | elif type_ in (ast.Global, ast.Nonlocal): 124 | keyword = 'global' if type_ is ast.Global else 'nonlocal' 125 | self._visit_global_nonlocal(node, keyword) 126 | elif type_ is ast.keyword: 127 | pass 128 | elif TYPE_VARS and type_ is ast.TypeAlias: # Python 3.12+ 129 | self._visit_type(node) 130 | return # scope already handled 131 | 132 | if type_ in (ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef): 133 | self._visit_class_function_definition(node) 134 | 135 | # Either make a new block scope... 136 | if type_ in BLOCKS: 137 | with self._enter_scope() as current_table: 138 | if type_ in FUNCTION_BLOCKS: 139 | current_table.unused_params = {} 140 | self._iter_node(node) 141 | # Set the hl group of all parameters that didn't appear in the 142 | # function body to "unused parameter". 
@contextlib.contextmanager
def _enter_scope(self):
    """Enter the next lexical scope and yield its symtable.

    The next pending symtable is popped from ``self._table_stack`` and
    pushed onto the environment ``self._env`` for the duration of the
    ``with`` block. Its child tables are queued on the table stack so that
    nested scopes are entered in the order they appear in the source.

    The scope is always left again when the block exits, even if visiting
    the body raised, so ``self._env`` and ``self._cur_env`` never keep a
    scope that is no longer active.
    """
    current_table = self._table_stack.pop()
    # The order of children symtables is not guaranteed and in fact
    # differs between CPython 3.13+ and prior versions. Sorting them in
    # the order they appear ensures consistency with AST visitation.
    children = sorted(current_table.get_children(),
                      key=lambda st: st.get_lineno())
    # Reversed, because the stack is popped from the end: the child that
    # appears first in the source must end up on top.
    self._table_stack += reversed(children)
    self._env.append(current_table)
    # Snapshot shared by all nodes created while this scope is active.
    self._cur_env = self._env[:]
    try:
        yield current_table
    finally:
        self._env.pop()
        self._cur_env = self._env[:]
def _visit_try(self, node):
    """Visit the clauses of a try statement in source order.

    The statements of the try body, the except handlers, the else branch
    and the finally branch are visited in that order. Each list is removed
    from the node once handled so that it is not visited a second time.
    """
    for clause in ('body', 'handlers', 'orelse', 'finalbody'):
        for child in getattr(node, clause):
            self.visit(child)
        delattr(node, clause)
def _visit_args_pre38(self, node):
    """Visit parameter annotations and the return annotation of `node`.

    Annotations are visited for the regular parameters, then the
    keyword-only ones, then ``*args`` and ``**kwargs`` (when present).
    Every annotation is deleted after being visited so that it is not
    visited again later. The same is done for ``node.returns``.
    """
    # arguments: ast.arguments
    arguments = node.args
    candidates = [
        *arguments.args,
        *arguments.kwonlyargs,
        arguments.vararg,
        arguments.kwarg,
    ]
    # vararg/kwarg are None when the function doesn't declare them
    for param in (c for c in candidates if c is not None):
        self.visit(param.annotation)
        del param.annotation
    self.visit(node.returns)
    del node.returns
not in target: 282 | guess = 'import ' + name + (' as ' + asname if asname else '') 283 | if isinstance(node, ast.ImportFrom): 284 | guess = 'from ' + (node.module or node.level * '.') + ' ' + \ 285 | guess 286 | if self._lines[line_idx] == guess: 287 | self.nodes.append(Node( 288 | target, 289 | node.lineno, 290 | len(guess.encode('utf-8')) - len(target.encode('utf-8')), 291 | self._cur_env, 292 | None, 293 | IMPORTED, 294 | )) # yapf: disable 295 | return 296 | # Guessing the line failed, so we need to use the tokenizer 297 | tokens = tokenize_lines(self._lines[i] for i in count(line_idx)) 298 | while True: 299 | # Advance to next "import" keyword 300 | token = advance(tokens, 'import') 301 | cur_line = self._lines[line_idx + token.start[0] - 1] 302 | # Determine exact byte offset. token.start[1] just holds the char 303 | # index which may give a wrong position. 304 | offset = len(cur_line[:token.start[1]].encode('utf-8')) 305 | # ...until we found the matching one. 306 | if offset >= node.col_offset: 307 | break 308 | for alias, more in zip(node.names, count(1 - len(node.names))): 309 | if alias.name == '*': 310 | continue 311 | # If it's an "as" alias import... 312 | if alias.asname is not None: 313 | # ...advance to "as" keyword. 314 | advance(tokens, 'as') 315 | token = advance(tokens) 316 | cur_line = self._lines[line_idx + token.start[0] - 1] 317 | self.nodes.append(Node( 318 | token.string, 319 | token.start[0] + line_idx, 320 | # Exact byte offset of the token 321 | len(cur_line[:token.start[1]].encode('utf-8')), 322 | self._cur_env, 323 | None, 324 | IMPORTED, 325 | )) # yapf: disable 326 | 327 | # If there are more imports in that import statement... 328 | if more: 329 | # ...they must be comma-separated, so advance to next comma. 330 | advance(tokens, ',', OP) 331 | 332 | def _visit_class_function_definition(self, node): 333 | """Visit class or function definition. 
334 | 335 | We need to use the tokenizer here for the same reason as in 336 | _visit_import (no line/col for names in class/function definitions). 337 | """ 338 | # node: ast.FunctionDef | ast.ClassDef | ast.AsyncFunctionDef 339 | decorators = node.decorator_list 340 | for decorator in decorators: 341 | self.visit(decorator) 342 | del node.decorator_list 343 | line_idx = node.lineno - 1 344 | # Guess offset of the name (length of the keyword + 1) 345 | start = node.col_offset + (6 if type(node) is ast.ClassDef else 4) 346 | stop = start + len(node.name) 347 | # If the node has no decorators and its name appears directly after the 348 | # definition keyword, we found its position and don't need to tokenize. 349 | if not decorators and self._lines[line_idx][start:stop] == node.name: 350 | lineno = node.lineno 351 | column = start 352 | else: 353 | tokens = tokenize_lines(self._lines[i] for i in count(line_idx)) 354 | advance(tokens, ('class', 'def')) 355 | token = advance(tokens) 356 | lineno = token.start[0] + line_idx 357 | column = token.start[1] 358 | self.nodes.append(Node(node.name, lineno, column, self._cur_env)) 359 | 360 | # Handling type parameters & generic syntax (Python 3.12+) 361 | # When generic type vars are present, a new scope is added 362 | _type_params = node.type_params if TYPE_VARS else None 363 | with (self._enter_scope() if _type_params # ... 364 | else contextlib.nullcontext()): 365 | if _type_params: 366 | for p in _type_params: 367 | self.visit(p) 368 | del node.type_params # Don't visit again later 369 | 370 | # Visit class meta (parent class), argument type hints, etc. 
def _visit_global_nonlocal(self, node, keyword):
    """Add a node for every name declared in a global/nonlocal statement.

    `keyword` is the statement's keyword ("global" or "nonlocal"). The
    positions of the names are first guessed from the canonical form
    ``<indent><keyword> a, b, c``. Only if the source line looks different
    is the (slower) tokenizer used to locate the names.
    """
    first_idx = node.lineno - 1
    text = self._lines[first_idx]
    leading = text[:-len(text.lstrip())]
    prefix = leading + keyword + ' '
    if text == prefix + ', '.join(node.names):
        # Fast path: the line matches the guess, so columns can be computed.
        column = len(prefix)
        for declared in node.names:
            self.nodes.append(Node(
                declared,
                node.lineno,
                column,
                self._cur_env,
            ))  # yapf: disable
            # Byte length of the name plus 2 bytes for the comma and space
            column += len(declared.encode('utf-8')) + 2
        return
    # Slow path: tokenize from the statement's line onwards.
    tokens = tokenize_lines(self._lines[i] for i in count(first_idx))
    # Skip everything up to and including the global/nonlocal keyword
    advance(tokens, keyword)
    remaining = len(node.names)
    while remaining:
        remaining -= 1
        token = advance(tokens)
        row, char_col = token.start
        source_line = self._lines[first_idx + row - 1]
        self.nodes.append(Node(
            token.string,
            row + first_idx,
            # Convert the character index into a byte offset
            len(source_line[:char_col].encode('utf-8')),
            self._cur_env,
        ))  # yapf: disable
        # Names are comma-separated, so skip the comma unless this was the
        # last declared name.
        if remaining:
            advance(tokens, ',', OP)
def _mark_self(self, node):
    """Mark self/cls argument if the current function has one.

    Determine if an argument is a self argument (the first argument of a
    method called "self" or "cls") and add a reference in the function's
    symtable.

    The first argument is the first positional parameter of the function.
    This includes positional-only parameters, so ``def f(self, /, x)`` is
    recognized as well.
    """
    arguments = node.args
    # Positional-only parameters (Python 3.8+) precede the regular ones and
    # are stored separately. The attribute doesn't exist on older versions.
    positional = getattr(arguments, 'posonlyargs', []) + arguments.args
    # The first argument...
    if not positional:
        return
    first = positional[0]
    # ...with a special name...
    if first.arg not in ('self', 'cls'):
        return
    # ...and a class as parent scope is a self_param.
    if self._env[-1].get_type() != 'class':
        return
    # Let the table for the current function scope remember the param
    self._table_stack[-1].self_param = first.arg
518 | elif value_type not in (str, int, bytes, bool): 519 | self.visit(value) 520 | 521 | 522 | if sys.version_info < (3, 8): 523 | # pylint: disable=protected-access 524 | Visitor._visit_args = Visitor._visit_args_pre38 525 | -------------------------------------------------------------------------------- /test/test_parser.py: -------------------------------------------------------------------------------- 1 | """Unit Tests for semshi.parser""" 2 | 3 | # pylint: disable=protected-access 4 | 5 | import sys 6 | from pathlib import Path 7 | from textwrap import dedent 8 | 9 | import pytest 10 | 11 | from semshi.node import ( 12 | ATTRIBUTE, 13 | BUILTIN, 14 | FREE, 15 | GLOBAL, 16 | IMPORTED, 17 | LOCAL, 18 | PARAMETER, 19 | PARAMETER_UNUSED, 20 | SELF, 21 | UNRESOLVED, 22 | Node, 23 | group, 24 | ) 25 | from semshi.parser import Parser, UnparsableError 26 | 27 | from .conftest import make_parser, make_tree, parse 28 | 29 | # top-level functions are parsed as LOCAL in python<3.7, 30 | # but as GLOBAL in Python 3.8. 
31 | MODULE_FUNC = GLOBAL if sys.version_info >= (3, 8) else LOCAL 32 | 33 | # Python 3.12: comprehensions no longer have their own variable scopes 34 | # https://peps.python.org/pep-0709/ 35 | PEP_709 = sys.version_info >= (3, 12) 36 | 37 | 38 | def test_group(): 39 | assert group('foo') == 'semshiFoo' 40 | 41 | 42 | def test_basic_name(): 43 | assert [n.name for n in parse('x = 1')] == ['x'] 44 | 45 | 46 | def test_no_names(): 47 | assert parse('') == [] 48 | assert parse('pass') == [] 49 | 50 | 51 | def test_recursion_error(): 52 | with pytest.raises(UnparsableError): 53 | parse(' + '.join(1000 * ['a'])) 54 | 55 | 56 | def test_syntax_error_fail(): 57 | """Syntax errors which can't be fixed with a single change.""" 58 | parser = Parser() 59 | with pytest.raises(UnparsableError): 60 | parser.parse('(\n(') 61 | with pytest.raises(UnparsableError): 62 | parser.parse(')\n(') 63 | # Intentionally no difference to previous one 64 | with pytest.raises(UnparsableError): 65 | parser.parse(')\n(') 66 | 67 | 68 | def test_syntax_error_fail2(): 69 | """Syntax errors which can't be fixed with a single change.""" 70 | parser = make_parser('a\nb/') 71 | with pytest.raises(UnparsableError): 72 | parser.parse('a(\nb/') 73 | 74 | 75 | def test_fixable_syntax_errors(): 76 | """Test syntax errors where we can tokenize the erroneous line.""" 77 | names = parse(r''' 78 | a a = b in 79 | c 80 | ''') 81 | assert [n.pos for n in names] == [(2, 0), (2, 3), (2, 7), (3, 0)] 82 | 83 | 84 | def test_fixable_syntax_errors2(): 85 | """Test syntax errors where we can tokenize the last modified line.""" 86 | parser = make_parser(r''' 87 | a 88 | b 89 | ''') 90 | parser.parse(dedent(r''' 91 | c( 92 | b 93 | ''')) 94 | assert {n.name for n in parser._nodes} == {'c', 'b'} 95 | 96 | 97 | @pytest.mark.xfail 98 | def test_fixable_syntax_errors3(): 99 | """Improved syntax fixing should be able to handle a bad symbol at the 100 | end of the erroneous line.""" 101 | parser = make_parser('def foo(): 
x=1-') 102 | print(parser.syntax_errors[-1].offset) 103 | assert [n.hl_group for n in parser._nodes] == [LOCAL, LOCAL] 104 | print(parser._nodes) 105 | raise NotImplementedError() 106 | 107 | 108 | def test_fixable_syntax_errors_indent(): 109 | parser = make_parser('''def foo():\n \t \tx-''') 110 | assert parser._nodes[-1].pos == (2, 4) 111 | 112 | 113 | def test_fixable_syntax_errors_misc(): 114 | fix = Parser._fix_line 115 | assert fix('') == '' 116 | assert fix('(') == '' 117 | assert fix(' (x') == ' +x' 118 | assert fix(' .x') == ' +x' 119 | # The trailing whitespace shouldn't be there, but doesn't do any harm 120 | assert fix(' a .. ') == ' a ' 121 | 122 | 123 | def test_fixable_syntax_errors_attributes(): 124 | fix = Parser._fix_line 125 | assert fix('foo bar . . baz') == \ 126 | 'foo+bar . baz' 127 | assert fix('(foo.bar . baz qux ( . baar') == \ 128 | '+foo.bar . baz++qux . baar' 129 | # Doesn't matter that we don't preserve tabs because we only want offsets 130 | assert fix('def foo.bar( + 1\t. 0 ... .1 spam . ham \t .eggs..') == \ 131 | '++++foo.bar . spam . ham .eggs' 132 | 133 | 134 | def test_syntax_error_cycle(): 135 | parser = make_parser('') 136 | assert parser.syntax_errors[-2] is None 137 | assert parser.syntax_errors[-1] is None 138 | parser.parse('1+') 139 | assert parser.syntax_errors[-2] is None 140 | assert parser.syntax_errors[-1].lineno == 1 141 | parser.parse('1+1') 142 | assert parser.syntax_errors[-2].lineno == 1 143 | assert parser.syntax_errors[-1] is None 144 | with pytest.raises(UnparsableError): 145 | parser.parse('\n+\n+') 146 | assert parser.syntax_errors[-2] is None 147 | assert parser.syntax_errors[-1].lineno == 2 148 | 149 | 150 | def test_detect_symtable_syntax_error(): 151 | """Some syntax errors (such as duplicate parameter names) aren't directly 152 | raised when compile() is called on the code, but cause problems later. 
153 | """ 154 | parser = Parser() 155 | with pytest.raises(UnparsableError): 156 | parser.parse('def foo(x, x): pass') 157 | assert parser.syntax_errors[-1].lineno == 1 158 | 159 | 160 | def test_name_len(): 161 | """Name length needs to be byte length for the correct HL offset.""" 162 | names = parse('asd + äöü') 163 | assert names[0].end - names[0].col == 3 164 | assert names[1].end - names[1].col == 6 165 | 166 | 167 | def test_comprehension_scopes(): 168 | names = parse(r''' 169 | #!/usr/bin/env python3 170 | (a for b in c) 171 | [d for e in f] 172 | {g for h in i} 173 | {j:k for l in m} 174 | ''') 175 | root = make_tree(names) 176 | groups = {n.name: n.hl_group for n in names} 177 | print(f"root = {root}") 178 | print(f"groups = {groups}") 179 | 180 | if not PEP_709: 181 | assert root['names'] == ['c', 'f', 'i', 'm'] 182 | assert root['genexpr']['names'] == ['a', 'b'] 183 | assert root['listcomp']['names'] == ['d', 'e'] 184 | assert root['setcomp']['names'] == ['g', 'h'] 185 | assert root['dictcomp']['names'] == ['j', 'k', 'l'] 186 | 187 | # generator variables b, e, h, l are local within the scope 188 | assert [name for name, group in groups.items() if group == LOCAL \ 189 | ] == ['b', 'e', 'h', 'l'] 190 | assert [name for name, group in groups.items() if group == UNRESOLVED 191 | ] == ['c', 'a', 'f', 'd', 'i', 'g', 'm', 'j', 'k'] 192 | 193 | else: 194 | # PEP-709, Python 3.12+: comprehensions do not have scope of their own. 195 | # so all the symbol is contained in the root node (ast.Module) 196 | assert root['names'] == [ 197 | # in the order nodes are visited and evaluated 198 | 'c', # generators have nested scope !!! 
199 | 'f', 'd', 'e', 200 | 'i', 'g', 'h', 201 | 'm', 'j', 'k', 'l' 202 | ] # yapf: disable 203 | # no comprehension children nodes 204 | assert list(root.keys()) == ['names', 'genexpr'] 205 | 206 | # generator variables e, h, l have the scope of the top-level module 207 | assert [name for name, group in groups.items() if group == GLOBAL 208 | ] == ['e', 'h', 'l'] # b is defined within the generator scope 209 | assert [name for name, group in groups.items() if group == UNRESOLVED 210 | ] == ['c', 'a', 'f', 'd', 'i', 'g', 'm', 'j', 'k'] 211 | 212 | 213 | def test_function_scopes(): 214 | names = parse(r''' 215 | #!/usr/bin/env python3 216 | def func(a, b, *c, d=e, f=[g for g in h], **i): 217 | pass 218 | def func2(j=k): 219 | pass 220 | func(x, y=p, **z) 221 | ''') 222 | root = make_tree(names) 223 | print(f"root = {root}") 224 | 225 | assert root['names'] == [ 226 | 'e', 'h', 227 | *(['g', 'g'] if PEP_709 else []), 228 | 'func', 'k', 'func2', 'func', 'x', 'p', 'z' 229 | ] # yapf: disable 230 | if not PEP_709: 231 | assert root['listcomp']['names'] == ['g', 'g'] 232 | assert root['func']['names'] == ['a', 'b', 'c', 'd', 'f', 'i'] 233 | assert root['func2']['names'] == ['j'] 234 | 235 | 236 | def test_class_scopes(): 237 | names = parse(r''' 238 | #!/usr/bin/env python3 239 | a = 1 240 | class A(x, y=z): 241 | a = 2 242 | def f(): 243 | a 244 | ''') 245 | root = make_tree(names) 246 | assert root['names'] == ['a', 'A', 'x', 'z'] 247 | 248 | 249 | def test_import_scopes_and_positions(): 250 | names = parse(r''' 251 | #!/usr/bin/env python3 252 | import aa 253 | import BB as cc 254 | from DD import ee 255 | from FF.GG import hh 256 | import ii.jj 257 | import kk, ll 258 | from MM import NN as oo 259 | from PP import * 260 | import qq, RR as tt, UU as vv 261 | from WW import xx, YY as zz 262 | import aaa; import bbb 263 | from CCC import (ddd, 264 | eee) 265 | import FFF.GGG as hhh 266 | from III.JJJ import KKK as lll 267 | import mmm.NNN.OOO, ppp.QQQ 268 | ''') 269 | 
root = make_tree(names) 270 | assert root['names'] == [ 271 | 'aa', 'cc', 'ee', 'hh', 'ii', 'kk', 'll', 'oo', 'qq', 'tt', 'vv', 'xx', 272 | 'zz', 'aaa', 'bbb', 'ddd', 'eee', 'hhh', 'lll', 'mmm', 'ppp' 273 | ] 274 | assert [(name.name, ) + name.pos for name in names] == [ 275 | ('aa', 3, 7), 276 | ('cc', 4, 13), 277 | ('ee', 5, 15), 278 | ('hh', 6, 18), 279 | ('ii', 7, 7), 280 | ('kk', 8, 7), 281 | ('ll', 8, 11), 282 | ('oo', 9, 21), 283 | ('qq', 11, 7), 284 | ('tt', 11, 17), 285 | ('vv', 11, 27), 286 | ('xx', 12, 15), 287 | ('zz', 12, 25), 288 | ('aaa', 13, 7), 289 | ('bbb', 13, 19), 290 | ('ddd', 14, 17), 291 | ('eee', 15, 0), 292 | ('hhh', 16, 18), 293 | ('lll', 17, 27), 294 | ('mmm', 18, 7), 295 | ('ppp', 18, 20), 296 | ] 297 | 298 | 299 | def test_multibyte_import_positions(): 300 | names = parse(r''' 301 | #!/usr/bin/env python3 302 | import aaa, bbb 303 | import äää, ööö 304 | aaa; import bbb, ccc 305 | äää; import ööö, üüü 306 | import äää; import ööö, üüü; from äää import ööö; import üüü as äää 307 | from x import ( 308 | äää, ööö 309 | ) 310 | from foo \ 311 | import äää 312 | ''') 313 | positions = [(n.col, n.end) for n in names] 314 | assert positions == [ 315 | (7, 10), (12, 15), 316 | (7, 13), (15, 21), 317 | (0, 3), (12, 15), (17, 20), 318 | (0, 6), (15, 21), (23, 29), 319 | (7, 13), (22, 28), (30, 36), (57, 63), (82, 88), 320 | (4, 10), (12, 18), 321 | (11, 17), # note the line continuation 322 | ] # yapf: disable 323 | 324 | 325 | def test_name_mangling(): 326 | """Leading double underscores can lead to a different symbol name.""" 327 | names = parse(r''' 328 | #!/usr/bin/env python3 329 | __foo = 1 330 | class A: 331 | __foo 332 | class B: 333 | __foo 334 | def f(): 335 | __foo 336 | class __C: 337 | pass 338 | class _A: 339 | def f(): 340 | __x 341 | class _A_: 342 | def f(): 343 | __x 344 | class ___A_: 345 | def f(): 346 | __x 347 | ''') 348 | root = make_tree(names) 349 | assert root['names'] == ['__foo', 'A', '_A', '_A_', '___A_'] 350 | assert 
root['A']['names'] == ['_A__foo', 'B', '_A__C'] 351 | assert root['A']['B']['names'] == ['_B__foo', 'f'] 352 | assert root['A']['B']['f']['names'] == ['_B__foo'] 353 | assert root['_A']['f']['names'] == ['_A__x'] 354 | assert root['_A_']['f']['names'] == ['_A___x'] 355 | assert root['___A_']['f']['names'] == ['_A___x'] 356 | 357 | 358 | def test_self_param(): 359 | """If self/cls appear in a class, they must have a speical group.""" 360 | names = parse(r''' 361 | #!/usr/bin/env python3 362 | self 363 | def x(self): 364 | pass 365 | class Foo: 366 | def x(self): 367 | pass 368 | def y(): 369 | self 370 | def z(self): 371 | self 372 | def a(foo, self): 373 | pass 374 | def b(foo, cls): 375 | pass 376 | def c(cls, foo): 377 | pass 378 | ''') 379 | groups = [n.hl_group for n in names if n.name in ['self', 'cls']] 380 | assert [PARAMETER if g is PARAMETER_UNUSED else g for g in groups] == [ 381 | UNRESOLVED, PARAMETER, SELF, FREE, PARAMETER, PARAMETER, PARAMETER, 382 | PARAMETER, SELF 383 | ] 384 | 385 | 386 | def test_self_with_decorator(): 387 | names = parse(r''' 388 | #!/usr/bin/env python3 389 | class Foo: 390 | @decorator(lambda k: k) 391 | def x(self): 392 | self 393 | ''') 394 | assert names[-1].hl_group == SELF 395 | 396 | 397 | def test_self_target(): 398 | """The target of a self with an attribute should be the attribute node.""" 399 | parser = make_parser(r''' 400 | #!/usr/bin/env python3 401 | self.abc 402 | class Foo: 403 | def x(self): 404 | self.abc 405 | ''') 406 | names = parser._nodes 407 | assert names[0].target is None 408 | last_self = names[-1] 409 | abc = names[-2] 410 | assert last_self.target is abc 411 | assert last_self.target.name == 'abc' 412 | assert list(parser.same_nodes(last_self)) == [abc] 413 | 414 | 415 | def test_unresolved_name(): 416 | names = parse('def foo(): a') 417 | assert names[1].hl_group == UNRESOLVED 418 | 419 | 420 | def test_imported_names(): 421 | names = parse(r''' 422 | #!/usr/bin/env python3 423 | import foo 424 | 
def test_except_as():
    """The ``as`` target of an except clause is reported at its own position.

    The second handler places its target on a backslash-continued line, so
    its expected position is on line 4, not on the ``except`` line itself.
    """
    nodes = parse('try: pass\nexcept E as a: pass\nexcept F as\\\n b: pass')

    def first_pos(identifier):
        # Position of the first node carrying this name; like the bare
        # next() it replaces, this raises StopIteration if there is none.
        return next(node.pos for node in nodes if node.name == identifier)

    assert first_pos('a') == (2, 12)
    assert first_pos('b') == (4, 1)
@pytest.mark.skipif('sys.version_info < (3, 9, 7)')
def test_fstrings_offsets():
    """Names inside an f-string are reported at their real source column."""
    # There was a Python-internal bug causing expressions with format
    # specifiers in f-strings to give wrong offsets when parsing into AST
    # (https://bugs.python.org/issue35212, https://bugs.python.org/issue44885).
    # The bug was fixed since 3.9.7+ and 3.10+ (numirias/semshi#31).
    source = "f'x{aa}{bbb:y}{cccc}'"
    # Parse the very string the expected offsets are computed from, so the
    # two can never drift apart.
    names = parse(source)
    # The first occurrence of each letter is the start column of
    # ``aa``, ``bbb`` and ``cccc`` respectively.
    offsets = [source.index(letter) for letter in 'abc']
    assert [n.col for n in names] == offsets
571 | set = 1 572 | def foo(): set, str 573 | ''') 574 | assert names[0].hl_group == BUILTIN 575 | assert names[-2].hl_group == GLOBAL 576 | assert names[-1].hl_group == BUILTIN 577 | 578 | 579 | def test_global_statement(): 580 | names = parse(r''' 581 | #!/usr/bin/env python3 582 | x = 1 583 | def foo(): 584 | global x 585 | x 586 | ''') 587 | assert names[-1].hl_group == GLOBAL 588 | 589 | 590 | def test_positions(): 591 | names = parse(r''' 592 | #!/usr/bin/env python3 593 | a = 1 # Line 3 594 | def func(x=y): # Line 4 595 | b = 2 596 | ''') 597 | assert [(name.name, ) + name.pos for name in names] == [ 598 | ('a', 3, 0), 599 | ('y', 4, 11), 600 | ('func', 4, 4), 601 | ('x', 4, 9), 602 | ('b', 5, 4), 603 | ] 604 | 605 | 606 | def test_class_and_function_positions(): 607 | # Note: did not use r''' to use literal '\t' 608 | names = parse(''' 609 | #!/usr/bin/env python3 610 | def aaa(): pass # Line 3 611 | async def bbb(): pass 612 | async def ccc(): pass 613 | class ddd(): pass 614 | class \t\f eee(): pass # Line 7 615 | class \\ 616 | \\ 617 | ggg: pass # Line 10 618 | @deco 619 | @deco2 620 | @deco3 621 | class hhh(): # Line 14 622 | def foo(): 623 | pass 624 | ''') 625 | assert [name.pos for name in names] == [ 626 | (3, 4), # aaa 627 | (4, 10), # bbb 628 | (5, 12), # ccc 629 | (6, 6), # ddd 630 | (7, 9), # eee 631 | (10, 2), # ggg 632 | (11, 1), # deco 633 | (12, 1), # deco 2 634 | (13, 1), # deco 3 635 | (14, 6), # hhh 636 | (15, 8), # foo 637 | ] 638 | 639 | 640 | def test_same_nodes(): 641 | parser = make_parser(r''' 642 | #!/usr/bin/env python3 643 | x = 1 644 | class A: 645 | x 646 | def B(): 647 | x 648 | ''') 649 | names = parser._nodes 650 | x, A, A_x, B, B_x = names 651 | same_nodes = set(parser.same_nodes(x)) 652 | assert same_nodes == {x, A_x, B_x} 653 | 654 | 655 | def test_base_scope_global(): 656 | parser = make_parser(r''' 657 | #!/usr/bin/env python3 658 | x = 1 659 | def a(): 660 | x = 2 661 | def b(): 662 | global x 663 | x 664 | ''') 665 | 
names = parser._nodes 666 | x, a, a_x, b, b_global_x, b_x = names 667 | same_nodes = set(parser.same_nodes(x)) 668 | assert same_nodes == {x, b_global_x, b_x} 669 | 670 | 671 | def test_base_scope_free(): 672 | parser = make_parser(r''' 673 | #!/usr/bin/env python3 674 | def a(): 675 | x = 1 676 | def b(): 677 | x 678 | ''') 679 | names = parser._nodes 680 | a, a_x, b, b_x = names 681 | same_nodes = set(parser.same_nodes(a_x)) 682 | assert same_nodes == {a_x, b_x} 683 | 684 | 685 | def test_base_scope_class(): 686 | parser = make_parser(r''' 687 | #!/usr/bin/env python3 688 | class A: 689 | x = 1 690 | x 691 | ''') 692 | names = parser._nodes 693 | A, x1, x2 = names 694 | same_nodes = set(parser.same_nodes(x1)) 695 | assert same_nodes == {x1, x2} 696 | 697 | 698 | def test_base_scope_class_nested(): 699 | parser = make_parser(r''' 700 | #!/usr/bin/env python3 701 | def z(): 702 | x = 1 703 | class A(): 704 | x = 2 705 | def b(): 706 | return x 707 | ''') 708 | names = parser._nodes 709 | z, z_x, A, A_x, b, b_x = names 710 | same_nodes = set(parser.same_nodes(z_x)) 711 | assert same_nodes == {z_x, b_x} 712 | 713 | 714 | def test_base_scope_nonlocal_free(): 715 | parser = make_parser(r''' 716 | #!/usr/bin/env python3 717 | def foo(): 718 | a = 1 719 | def bar(): 720 | nonlocal a 721 | a = 1 722 | ''') 723 | foo, foo_a, bar, bar_nonlocal_a, bar_a = parser._nodes 724 | assert set(parser.same_nodes(foo_a)) == {foo_a, bar_nonlocal_a, bar_a} 725 | 726 | 727 | def test_attributes(): 728 | parser = make_parser(r''' 729 | #!/usr/bin/env python3 730 | aa.bb 731 | cc.self.dd 732 | self.ee 733 | def a(self): 734 | self.ff 735 | class A: 736 | def b(self): 737 | self.gg 738 | class B: 739 | def c(self): 740 | self.gg 741 | def d(self): 742 | self.gg 743 | def e(self): 744 | self.hh 745 | def f(foo): 746 | self.gg 747 | ''') 748 | names = parser._nodes 749 | names = [n for n in names if n.hl_group == ATTRIBUTE] 750 | b_gg, c_gg, d_gg, e_hh = names 751 | same_nodes = 
def test_same_nodes_exclude_current():
    """With ``mark_original=False`` the queried node is left out of the result."""
    parser = make_parser('a, a, a')
    first, second, third = parser._nodes
    siblings = parser.same_nodes(first, mark_original=False)
    # Order is irrelevant here; only membership is compared.
    assert set(siblings) == {second, third}
def test_exclude_types_same_nodes():
    """Excluded groups are not emitted, but ``same_nodes`` still finds them."""
    parser = Parser(exclude=[UNRESOLVED])
    added, _cleared = parser.parse('a, a')
    # Both names are unresolved, hence nothing is reported for highlighting.
    assert len(added) == 0
    # Looking up by cursor position still yields both occurrences of ``a``.
    positions = [node.pos for node in parser.same_nodes((1, 0))]
    assert positions == [(1, 0), (1, 3)]
# Fails due to what seems to be an internal bug. See:
# https://stackoverflow.com/q/59066024/5765873
@pytest.mark.xfail
@pytest.mark.skipif('sys.version_info < (3, 8)')
def test_posonlyargs_with_annotation():
    """An annotated positional-only parameter is still classified correctly.

    Marked xfail: see the comment above the decorators for the suspected
    interpreter-internal cause.
    """
    nodes = parse('def f(x: y, /): pass')
    groups = [node.hl_group for node in nodes]
    # Expected order: function ``f``, annotation ``y``, parameter ``x``.
    assert groups == [MODULE_FUNC, UNRESOLVED, PARAMETER_UNUSED]
@pytest.mark.skipif('sys.version_info < (3, 8)')
def test_postponed_evaluation_of_annotations_pep563_resolution(request):
    """Additional tests for PEP 563. The code is from the PEP-563 document.

    Parses ``data/pep-0563-annotations.py`` (located next to this test
    module) and checks the node that directly follows every node named
    ``method``, i.e. the eight type annotations on ``method``.
    """
    path = Path(request.fspath.dirname) / 'data/pep-0563-annotations.py'
    names = parse(path.read_text(encoding="utf-8"))

    # Pair every node with its successor instead of indexing ``i + 1``:
    # a trailing ``method`` node then simply has no partner rather than
    # raising IndexError from inside a generator.
    annos = [
        (following.name, following.hl_group)
        for node, following in zip(names, names[1:])
        if node.name == 'method'
    ]

    # A single list comparison checks the count (eight) and every
    # (name, group) pair at once, and reports a readable diff on failure.
    assert annos == [
        ('C', GLOBAL),
        ('D', UNRESOLVED),
        ('field2', LOCAL),
        ('field', UNRESOLVED),
        ('C', GLOBAL),
        ('field', LOCAL),
        ('C', GLOBAL),
        ('D', LOCAL),
    ]
handled 1046 | ('data', PARAMETER), 1047 | *[('first', LOCAL), ('T', FREE), ('data', PARAMETER)], 1048 | ('first', LOCAL), # return ... 1049 | ] 1050 | assert [(n.name, n.hl_group) for n in names] == expected 1051 | 1052 | 1053 | @pytest.mark.skipif('sys.version_info < (3, 12)') 1054 | def test_type_statement_py312(): 1055 | # https://peps.python.org/pep-0695/ 1056 | names = parse(''' 1057 | #!/usr/bin/env python3 1058 | type IntList = list[int] # non-generic case 1059 | type MyList[T] = list[T] 1060 | # ^typevar ^ a resolved reference (treated like a closure) 1061 | 1062 | class A: 1063 | pass 1064 | 1065 | def foo(): 1066 | mylist: MyList[int] = [1, 2, 3] 1067 | # ^^^^ -> type statements used to break environment scope 1068 | assert len(mylist) == 3 1069 | ''') 1070 | expected = [ 1071 | # non-generic type statement 1072 | *[('IntList', GLOBAL), ('list', BUILTIN), ('int', BUILTIN)], 1073 | # generic type statement 1074 | *[('MyList', GLOBAL), ('T', LOCAL), ('list', BUILTIN), ('T', FREE)], 1075 | # class A: 1076 | ('A', GLOBAL), 1077 | # def foo(): 1078 | *[ 1079 | ('foo', GLOBAL), 1080 | # mylist: Mylist[int] 1081 | *[('mylist', LOCAL), ('MyList', GLOBAL), ('int', BUILTIN)], 1082 | # assert len(mylist) == 3 1083 | *[('len', BUILTIN), ('mylist', LOCAL)], 1084 | ], 1085 | ] 1086 | assert [(n.name, n.hl_group) for n in names] == expected 1087 | 1088 | 1089 | @pytest.mark.skipif('sys.version_info < (3, 13)') 1090 | def test_type_statement_py313(): 1091 | """type statement with bound (3.12+) and default (3.13+) parameters.""" 1092 | # https://peps.python.org/pep-0695/ 1093 | names = parse(''' 1094 | #!/usr/bin/env python3 1095 | type Alias1[T, P] = list[P] | set[T] 1096 | type Alias2[T, P: type[T]] = list[P] | set[T] 1097 | type Alias3[T, P = T] = list[P] | set[T] 1098 | type Alias4[T: int, P: int = bool | T] = list[P] | set[T] 1099 | 1100 | def foo(): 1101 | mylist: list[int] = [1, 2, 3] 1102 | assert len(mylist) == 3 1103 | ''') 1104 | RHS_listP_or_setT = [ 1105 | 
class TestNode:
    """Unit tests for ``Node`` that use hand-written symbol-table stubs."""

    def test_node(self):
        """Two nodes created back to back receive consecutive ids."""
        # yapf: disable
        class Symbol:
            """Stub symbol: ``is_<flag>()`` returns the value given for
            ``<flag>`` at construction, and False for any other flag."""

            def __init__(self, name, **kwargs):
                self.name = name
                for k, v in kwargs.items():
                    # Bind ``v`` as a default argument. A plain
                    # ``lambda: v`` closes over the loop variable, so every
                    # flag would report the value of the *last* kwarg.
                    setattr(self, 'is_' + k, lambda v=v: v)

            def __getattr__(self, item):
                # Only reached for attributes not set in __init__.
                if item.startswith('is_'):
                    return lambda: False
                raise AttributeError(item)

        class Table:
            """Stub table exposing only ``lookup`` and ``get_type``."""

            def __init__(self, symbols, type=None):
                self.symbols = symbols
                self.type = type or 'module'

            def lookup(self, name):
                return next(sym for sym in self.symbols if sym.name == name)

            def get_type(self):
                return self.type
        # yapf: enable

        # NOTE(review): the positional arguments look like
        # (name, line, column, scope tables) -- confirm against Node.
        a = Node('foo', 0, 0, [Table([Symbol('foo', local=True)])])
        b = Node('bar', 0, 10, [Table([Symbol('bar', local=True)])])
        assert a.id + 1 == b.id
def test_specific_grammar(request):
    """Parse the grammar sample that matches the running interpreter.

    The fixture is looked up as ``data/grammar<major><minor>.py`` next to
    this test module; the test passes as long as parsing does not raise.
    """
    major, minor = sys.version_info[:2]
    fixture_name = 'data/grammar{0}{1}.py'.format(major, minor)
    fixture = Path(request.fspath.dirname) / fixture_name
    with open(str(fixture), encoding='utf-8') as handle:
        source = handle.read()
    parse(source)
3 | 4 | from test.support import check_syntax_error 5 | import inspect 6 | import unittest 7 | import sys 8 | # testing import * 9 | from sys import * 10 | 11 | 12 | class TokenTests(unittest.TestCase): 13 | 14 | def test_backslash(self): 15 | # Backslash means line continuation: 16 | x = 1 \ 17 | + 1 18 | self.assertEqual(x, 2, 'backslash for line continuation') 19 | 20 | # Backslash does not means continuation in comments :\ 21 | x = 0 22 | self.assertEqual(x, 0, 'backslash ending comment') 23 | 24 | def test_plain_integers(self): 25 | self.assertEqual(type(000), type(0)) 26 | self.assertEqual(0xff, 255) 27 | self.assertEqual(0o377, 255) 28 | self.assertEqual(2147483647, 0o17777777777) 29 | self.assertEqual(0b1001, 9) 30 | # "0x" is not a valid literal 31 | self.assertRaises(SyntaxError, eval, "0x") 32 | from sys import maxsize 33 | if maxsize == 2147483647: 34 | self.assertEqual(-2147483647-1, -0o20000000000) 35 | # XXX -2147483648 36 | self.assertTrue(0o37777777777 > 0) 37 | self.assertTrue(0xffffffff > 0) 38 | self.assertTrue(0b1111111111111111111111111111111 > 0) 39 | for s in ('2147483648', '0o40000000000', '0x100000000', 40 | '0b10000000000000000000000000000000'): 41 | try: 42 | x = eval(s) 43 | except OverflowError: 44 | self.fail("OverflowError on huge integer literal %r" % s) 45 | elif maxsize == 9223372036854775807: 46 | self.assertEqual(-9223372036854775807-1, -0o1000000000000000000000) 47 | self.assertTrue(0o1777777777777777777777 > 0) 48 | self.assertTrue(0xffffffffffffffff > 0) 49 | self.assertTrue(0b11111111111111111111111111111111111111111111111111111111111111 > 0) 50 | for s in '9223372036854775808', '0o2000000000000000000000', \ 51 | '0x10000000000000000', \ 52 | '0b100000000000000000000000000000000000000000000000000000000000000': 53 | try: 54 | x = eval(s) 55 | except OverflowError: 56 | self.fail("OverflowError on huge integer literal %r" % s) 57 | else: 58 | self.fail('Weird maxsize value %r' % maxsize) 59 | 60 | def test_long_integers(self): 
61 | x = 0 62 | x = 0xffffffffffffffff 63 | x = 0Xffffffffffffffff 64 | x = 0o77777777777777777 65 | x = 0O77777777777777777 66 | x = 123456789012345678901234567890 67 | x = 0b100000000000000000000000000000000000000000000000000000000000000000000 68 | x = 0B111111111111111111111111111111111111111111111111111111111111111111111 69 | 70 | def test_floats(self): 71 | x = 3.14 72 | x = 314. 73 | x = 0.314 74 | # XXX x = 000.314 75 | x = .314 76 | x = 3e14 77 | x = 3E14 78 | x = 3e-14 79 | x = 3e+14 80 | x = 3.e14 81 | x = .3e14 82 | x = 3.1e4 83 | 84 | def test_float_exponent_tokenization(self): 85 | # See issue 21642. 86 | self.assertEqual(1 if 1else 0, 1) 87 | self.assertEqual(1 if 0else 0, 0) 88 | self.assertRaises(SyntaxError, eval, "0 if 1Else 0") 89 | 90 | def test_string_literals(self): 91 | x = ''; y = ""; self.assertTrue(len(x) == 0 and x == y) 92 | x = '\''; y = "'"; self.assertTrue(len(x) == 1 and x == y and ord(x) == 39) 93 | x = '"'; y = "\""; self.assertTrue(len(x) == 1 and x == y and ord(x) == 34) 94 | x = "doesn't \"shrink\" does it" 95 | y = 'doesn\'t "shrink" does it' 96 | self.assertTrue(len(x) == 24 and x == y) 97 | x = "does \"shrink\" doesn't it" 98 | y = 'does "shrink" doesn\'t it' 99 | self.assertTrue(len(x) == 24 and x == y) 100 | x = """ 101 | The "quick" 102 | brown fox 103 | jumps over 104 | the 'lazy' dog. 105 | """ 106 | y = '\nThe "quick"\nbrown fox\njumps over\nthe \'lazy\' dog.\n' 107 | self.assertEqual(x, y) 108 | y = ''' 109 | The "quick" 110 | brown fox 111 | jumps over 112 | the 'lazy' dog. 113 | ''' 114 | self.assertEqual(x, y) 115 | y = "\n\ 116 | The \"quick\"\n\ 117 | brown fox\n\ 118 | jumps over\n\ 119 | the 'lazy' dog.\n\ 120 | " 121 | self.assertEqual(x, y) 122 | y = '\n\ 123 | The \"quick\"\n\ 124 | brown fox\n\ 125 | jumps over\n\ 126 | the \'lazy\' dog.\n\ 127 | ' 128 | self.assertEqual(x, y) 129 | 130 | def test_ellipsis(self): 131 | x = ... 
132 | self.assertTrue(x is Ellipsis) 133 | self.assertRaises(SyntaxError, eval, ".. .") 134 | 135 | def test_eof_error(self): 136 | samples = ("def foo(", "\ndef foo(", "def foo(\n") 137 | for s in samples: 138 | with self.assertRaises(SyntaxError) as cm: 139 | compile(s, "", "exec") 140 | self.assertIn("unexpected EOF", str(cm.exception)) 141 | 142 | class GrammarTests(unittest.TestCase): 143 | 144 | # single_input: NEWLINE | simple_stmt | compound_stmt NEWLINE 145 | # XXX can't test in a script -- this rule is only used when interactive 146 | 147 | # file_input: (NEWLINE | stmt)* ENDMARKER 148 | # Being tested as this very moment this very module 149 | 150 | # expr_input: testlist NEWLINE 151 | # XXX Hard to test -- used only in calls to input() 152 | 153 | def test_eval_input(self): 154 | # testlist ENDMARKER 155 | x = eval('1, 0 or 1') 156 | 157 | def test_funcdef(self): 158 | ### [decorators] 'def' NAME parameters ['->' test] ':' suite 159 | ### decorator: '@' dotted_name [ '(' [arglist] ')' ] NEWLINE 160 | ### decorators: decorator+ 161 | ### parameters: '(' [typedargslist] ')' 162 | ### typedargslist: ((tfpdef ['=' test] ',')* 163 | ### ('*' [tfpdef] (',' tfpdef ['=' test])* [',' '**' tfpdef] | '**' tfpdef) 164 | ### | tfpdef ['=' test] (',' tfpdef ['=' test])* [',']) 165 | ### tfpdef: NAME [':' test] 166 | ### varargslist: ((vfpdef ['=' test] ',')* 167 | ### ('*' [vfpdef] (',' vfpdef ['=' test])* [',' '**' vfpdef] | '**' vfpdef) 168 | ### | vfpdef ['=' test] (',' vfpdef ['=' test])* [',']) 169 | ### vfpdef: NAME 170 | def f1(): pass 171 | f1() 172 | f1(*()) 173 | f1(*(), **{}) 174 | def f2(one_argument): pass 175 | def f3(two, arguments): pass 176 | self.assertEqual(f2.__code__.co_varnames, ('one_argument',)) 177 | self.assertEqual(f3.__code__.co_varnames, ('two', 'arguments')) 178 | def a1(one_arg,): pass 179 | def a2(two, args,): pass 180 | def v0(*rest): pass 181 | def v1(a, *rest): pass 182 | def v2(a, b, *rest): pass 183 | 184 | f1() 185 | f2(1) 186 | 
f2(1,) 187 | f3(1, 2) 188 | f3(1, 2,) 189 | v0() 190 | v0(1) 191 | v0(1,) 192 | v0(1,2) 193 | v0(1,2,3,4,5,6,7,8,9,0) 194 | v1(1) 195 | v1(1,) 196 | v1(1,2) 197 | v1(1,2,3) 198 | v1(1,2,3,4,5,6,7,8,9,0) 199 | v2(1,2) 200 | v2(1,2,3) 201 | v2(1,2,3,4) 202 | v2(1,2,3,4,5,6,7,8,9,0) 203 | 204 | def d01(a=1): pass 205 | d01() 206 | d01(1) 207 | d01(*(1,)) 208 | d01(*[] or [2]) 209 | d01(*() or (), *{} and (), **() or {}) 210 | d01(**{'a':2}) 211 | d01(**{'a':2} or {}) 212 | def d11(a, b=1): pass 213 | d11(1) 214 | d11(1, 2) 215 | d11(1, **{'b':2}) 216 | def d21(a, b, c=1): pass 217 | d21(1, 2) 218 | d21(1, 2, 3) 219 | d21(*(1, 2, 3)) 220 | d21(1, *(2, 3)) 221 | d21(1, 2, *(3,)) 222 | d21(1, 2, **{'c':3}) 223 | def d02(a=1, b=2): pass 224 | d02() 225 | d02(1) 226 | d02(1, 2) 227 | d02(*(1, 2)) 228 | d02(1, *(2,)) 229 | d02(1, **{'b':2}) 230 | d02(**{'a': 1, 'b': 2}) 231 | def d12(a, b=1, c=2): pass 232 | d12(1) 233 | d12(1, 2) 234 | d12(1, 2, 3) 235 | def d22(a, b, c=1, d=2): pass 236 | d22(1, 2) 237 | d22(1, 2, 3) 238 | d22(1, 2, 3, 4) 239 | def d01v(a=1, *rest): pass 240 | d01v() 241 | d01v(1) 242 | d01v(1, 2) 243 | d01v(*(1, 2, 3, 4)) 244 | d01v(*(1,)) 245 | d01v(**{'a':2}) 246 | def d11v(a, b=1, *rest): pass 247 | d11v(1) 248 | d11v(1, 2) 249 | d11v(1, 2, 3) 250 | def d21v(a, b, c=1, *rest): pass 251 | d21v(1, 2) 252 | d21v(1, 2, 3) 253 | d21v(1, 2, 3, 4) 254 | d21v(*(1, 2, 3, 4)) 255 | d21v(1, 2, **{'c': 3}) 256 | def d02v(a=1, b=2, *rest): pass 257 | d02v() 258 | d02v(1) 259 | d02v(1, 2) 260 | d02v(1, 2, 3) 261 | d02v(1, *(2, 3, 4)) 262 | d02v(**{'a': 1, 'b': 2}) 263 | def d12v(a, b=1, c=2, *rest): pass 264 | d12v(1) 265 | d12v(1, 2) 266 | d12v(1, 2, 3) 267 | d12v(1, 2, 3, 4) 268 | d12v(*(1, 2, 3, 4)) 269 | d12v(1, 2, *(3, 4, 5)) 270 | d12v(1, *(2,), **{'c': 3}) 271 | def d22v(a, b, c=1, d=2, *rest): pass 272 | d22v(1, 2) 273 | d22v(1, 2, 3) 274 | d22v(1, 2, 3, 4) 275 | d22v(1, 2, 3, 4, 5) 276 | d22v(*(1, 2, 3, 4)) 277 | d22v(1, 2, *(3, 4, 5)) 278 | d22v(1, *(2, 
3), **{'d': 4}) 279 | 280 | # keyword argument type tests 281 | try: 282 | str('x', **{b'foo':1 }) 283 | except TypeError: 284 | pass 285 | else: 286 | self.fail('Bytes should not work as keyword argument names') 287 | # keyword only argument tests 288 | def pos0key1(*, key): return key 289 | pos0key1(key=100) 290 | def pos2key2(p1, p2, *, k1, k2=100): return p1,p2,k1,k2 291 | pos2key2(1, 2, k1=100) 292 | pos2key2(1, 2, k1=100, k2=200) 293 | pos2key2(1, 2, k2=100, k1=200) 294 | def pos2key2dict(p1, p2, *, k1=100, k2, **kwarg): return p1,p2,k1,k2,kwarg 295 | pos2key2dict(1,2,k2=100,tokwarg1=100,tokwarg2=200) 296 | pos2key2dict(1,2,tokwarg1=100,tokwarg2=200, k2=100) 297 | 298 | # keyword arguments after *arglist 299 | def f(*args, **kwargs): 300 | return args, kwargs 301 | self.assertEqual(f(1, x=2, *[3, 4], y=5), ((1, 3, 4), 302 | {'x':2, 'y':5})) 303 | self.assertEqual(f(1, *(2,3), 4), ((1, 2, 3, 4), {})) 304 | self.assertRaises(SyntaxError, eval, "f(1, x=2, *(3,4), x=5)") 305 | self.assertEqual(f(**{'eggs':'scrambled', 'spam':'fried'}), 306 | ((), {'eggs':'scrambled', 'spam':'fried'})) 307 | self.assertEqual(f(spam='fried', **{'eggs':'scrambled'}), 308 | ((), {'eggs':'scrambled', 'spam':'fried'})) 309 | 310 | # argument annotation tests 311 | def f(x) -> list: pass 312 | self.assertEqual(f.__annotations__, {'return': list}) 313 | def f(x: int): pass 314 | self.assertEqual(f.__annotations__, {'x': int}) 315 | def f(*x: str): pass 316 | self.assertEqual(f.__annotations__, {'x': str}) 317 | def f(**x: float): pass 318 | self.assertEqual(f.__annotations__, {'x': float}) 319 | def f(x, y: 1+2): pass 320 | self.assertEqual(f.__annotations__, {'y': 3}) 321 | def f(a, b: 1, c: 2, d): pass 322 | self.assertEqual(f.__annotations__, {'b': 1, 'c': 2}) 323 | def f(a, b: 1, c: 2, d, e: 3 = 4, f=5, *g: 6): pass 324 | self.assertEqual(f.__annotations__, 325 | {'b': 1, 'c': 2, 'e': 3, 'g': 6}) 326 | def f(a, b: 1, c: 2, d, e: 3 = 4, f=5, *g: 6, h: 7, i=8, j: 9 = 10, 327 | **k: 11) 
-> 12: pass 328 | self.assertEqual(f.__annotations__, 329 | {'b': 1, 'c': 2, 'e': 3, 'g': 6, 'h': 7, 'j': 9, 330 | 'k': 11, 'return': 12}) 331 | # Check for issue #20625 -- annotations mangling 332 | class Spam: 333 | def f(self, *, __kw: 1): 334 | pass 335 | class Ham(Spam): pass 336 | self.assertEqual(Spam.f.__annotations__, {'_Spam__kw': 1}) 337 | self.assertEqual(Ham.f.__annotations__, {'_Spam__kw': 1}) 338 | # Check for SF Bug #1697248 - mixing decorators and a return annotation 339 | def null(x): return x 340 | @null 341 | def f(x) -> list: pass 342 | self.assertEqual(f.__annotations__, {'return': list}) 343 | 344 | # test MAKE_CLOSURE with a variety of oparg's 345 | closure = 1 346 | def f(): return closure 347 | def f(x=1): return closure 348 | def f(*, k=1): return closure 349 | def f() -> int: return closure 350 | 351 | # Check ast errors in *args and *kwargs 352 | check_syntax_error(self, "f(*g(1=2))") 353 | check_syntax_error(self, "f(**g(1=2))") 354 | 355 | def test_lambdef(self): 356 | ### lambdef: 'lambda' [varargslist] ':' test 357 | l1 = lambda : 0 358 | self.assertEqual(l1(), 0) 359 | l2 = lambda : a[d] # XXX just testing the expression 360 | l3 = lambda : [2 < x for x in [-1, 3, 0]] 361 | self.assertEqual(l3(), [0, 1, 0]) 362 | l4 = lambda x = lambda y = lambda z=1 : z : y() : x() 363 | self.assertEqual(l4(), 1) 364 | l5 = lambda x, y, z=2: x + y + z 365 | self.assertEqual(l5(1, 2), 5) 366 | self.assertEqual(l5(1, 2, 3), 6) 367 | check_syntax_error(self, "lambda x: x = 2") 368 | check_syntax_error(self, "lambda (None,): None") 369 | l6 = lambda x, y, *, k=20: x+y+k 370 | self.assertEqual(l6(1,2), 1+2+20) 371 | self.assertEqual(l6(1,2,k=10), 1+2+10) 372 | 373 | 374 | ### stmt: simple_stmt | compound_stmt 375 | # Tested below 376 | 377 | def test_simple_stmt(self): 378 | ### simple_stmt: small_stmt (';' small_stmt)* [';'] 379 | x = 1; pass; del x 380 | def foo(): 381 | # verify statements that end with semi-colons 382 | x = 1; pass; del x; 383 | 
foo() 384 | 385 | ### small_stmt: expr_stmt | pass_stmt | del_stmt | flow_stmt | import_stmt | global_stmt | access_stmt 386 | # Tested below 387 | 388 | def test_expr_stmt(self): 389 | # (exprlist '=')* exprlist 390 | 1 391 | 1, 2, 3 392 | x = 1 393 | x = 1, 2, 3 394 | x = y = z = 1, 2, 3 395 | x, y, z = 1, 2, 3 396 | abc = a, b, c = x, y, z = xyz = 1, 2, (3, 4) 397 | 398 | check_syntax_error(self, "x + 1 = 1") 399 | check_syntax_error(self, "a + 1 = b + 2") 400 | 401 | # Check the heuristic for print & exec covers significant cases 402 | # As well as placing some limits on false positives 403 | def test_former_statements_refer_to_builtins(self): 404 | keywords = "print", "exec" 405 | # Cases where we want the custom error 406 | cases = [ 407 | "{} foo", 408 | "{} {{1:foo}}", 409 | "if 1: {} foo", 410 | "if 1: {} {{1:foo}}", 411 | "if 1:\n {} foo", 412 | "if 1:\n {} {{1:foo}}", 413 | ] 414 | for keyword in keywords: 415 | custom_msg = "call to '{}'".format(keyword) 416 | for case in cases: 417 | source = case.format(keyword) 418 | with self.subTest(source=source): 419 | with self.assertRaisesRegex(SyntaxError, custom_msg): 420 | exec(source) 421 | source = source.replace("foo", "(foo.)") 422 | with self.subTest(source=source): 423 | with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 424 | exec(source) 425 | 426 | def test_del_stmt(self): 427 | # 'del' exprlist 428 | abc = [1,2,3] 429 | x, y, z = abc 430 | xyz = x, y, z 431 | 432 | del abc 433 | del x, y, (z, xyz) 434 | 435 | def test_pass_stmt(self): 436 | # 'pass' 437 | pass 438 | 439 | # flow_stmt: break_stmt | continue_stmt | return_stmt | raise_stmt 440 | # Tested below 441 | 442 | def test_break_stmt(self): 443 | # 'break' 444 | while 1: break 445 | 446 | def test_continue_stmt(self): 447 | # 'continue' 448 | i = 1 449 | while i: i = 0; continue 450 | 451 | msg = "" 452 | while not msg: 453 | msg = "ok" 454 | try: 455 | continue 456 | msg = "continue failed to continue inside try" 457 | except: 458 | 
msg = "continue inside try called except block" 459 | if msg != "ok": 460 | self.fail(msg) 461 | 462 | msg = "" 463 | while not msg: 464 | msg = "finally block not called" 465 | try: 466 | continue 467 | finally: 468 | msg = "ok" 469 | if msg != "ok": 470 | self.fail(msg) 471 | 472 | def test_break_continue_loop(self): 473 | # This test warrants an explanation. It is a test specifically for SF bugs 474 | # #463359 and #462937. The bug is that a 'break' statement executed or 475 | # exception raised inside a try/except inside a loop, *after* a continue 476 | # statement has been executed in that loop, will cause the wrong number of 477 | # arguments to be popped off the stack and the instruction pointer reset to 478 | # a very small number (usually 0.) Because of this, the following test 479 | # *must* be written as a function, and the tracking vars *must* be function 480 | # arguments with default values. Otherwise, the test will loop and loop. 481 | 482 | def test_inner(extra_burning_oil = 1, count=0): 483 | big_hippo = 2 484 | while big_hippo: 485 | count += 1 486 | try: 487 | if extra_burning_oil and big_hippo == 1: 488 | extra_burning_oil -= 1 489 | break 490 | big_hippo -= 1 491 | continue 492 | except: 493 | raise 494 | if count > 2 or big_hippo != 1: 495 | self.fail("continue then break in try/except in loop broken!") 496 | test_inner() 497 | 498 | def test_return(self): 499 | # 'return' [testlist] 500 | def g1(): return 501 | def g2(): return 1 502 | g1() 503 | x = g2() 504 | check_syntax_error(self, "class foo:return 1") 505 | 506 | def test_yield(self): 507 | # Allowed as standalone statement 508 | def g(): yield 1 509 | def g(): yield from () 510 | # Allowed as RHS of assignment 511 | def g(): x = yield 1 512 | def g(): x = yield from () 513 | # Ordinary yield accepts implicit tuples 514 | def g(): yield 1, 1 515 | def g(): x = yield 1, 1 516 | # 'yield from' does not 517 | check_syntax_error(self, "def g(): yield from (), 1") 518 | check_syntax_error(self,
"def g(): x = yield from (), 1") 519 | # Requires parentheses as subexpression 520 | def g(): 1, (yield 1) 521 | def g(): 1, (yield from ()) 522 | check_syntax_error(self, "def g(): 1, yield 1") 523 | check_syntax_error(self, "def g(): 1, yield from ()") 524 | # Requires parentheses as call argument 525 | def g(): f((yield 1)) 526 | def g(): f((yield 1), 1) 527 | def g(): f((yield from ())) 528 | def g(): f((yield from ()), 1) 529 | check_syntax_error(self, "def g(): f(yield 1)") 530 | check_syntax_error(self, "def g(): f(yield 1, 1)") 531 | check_syntax_error(self, "def g(): f(yield from ())") 532 | check_syntax_error(self, "def g(): f(yield from (), 1)") 533 | # Not allowed at top level 534 | check_syntax_error(self, "yield") 535 | check_syntax_error(self, "yield from") 536 | # Not allowed at class scope 537 | check_syntax_error(self, "class foo:yield 1") 538 | check_syntax_error(self, "class foo:yield from ()") 539 | # Check annotation refleak on SyntaxError 540 | check_syntax_error(self, "def g(a:(yield)): pass") 541 | 542 | def test_raise(self): 543 | # 'raise' test [',' test] 544 | try: raise RuntimeError('just testing') 545 | except RuntimeError: pass 546 | try: raise KeyboardInterrupt 547 | except KeyboardInterrupt: pass 548 | 549 | def test_import(self): 550 | # 'import' dotted_as_names 551 | import sys 552 | import time, sys 553 | # 'from' dotted_name 'import' ('*' | '(' import_as_names ')' | import_as_names) 554 | from time import time 555 | from time import (time) 556 | # not testable inside a function, but already done at top of the module 557 | # from sys import * 558 | from sys import path, argv 559 | from sys import (path, argv) 560 | from sys import (path, argv,) 561 | 562 | def test_global(self): 563 | # 'global' NAME (',' NAME)* 564 | global a 565 | global a, b 566 | global one, two, three, four, five, six, seven, eight, nine, ten 567 | 568 | def test_nonlocal(self): 569 | # 'nonlocal' NAME (',' NAME)* 570 | x = 0 571 | y = 0 572 | def f(): 573 | 
nonlocal x 574 | nonlocal x, y 575 | 576 | def test_assert(self): 577 | # assert_stmt: 'assert' test [',' test] 578 | assert 1 579 | assert 1, 1 580 | assert lambda x:x 581 | assert 1, lambda x:x+1 582 | 583 | try: 584 | assert True 585 | except AssertionError as e: 586 | self.fail("'assert True' should not have raised an AssertionError") 587 | 588 | try: 589 | assert True, 'this should always pass' 590 | except AssertionError as e: 591 | self.fail("'assert True, msg' should not have " 592 | "raised an AssertionError") 593 | 594 | # these tests fail if python is run with -O, so check __debug__ 595 | @unittest.skipUnless(__debug__, "Won't work if __debug__ is False") 596 | def testAssert2(self): 597 | try: 598 | assert 0, "msg" 599 | except AssertionError as e: 600 | self.assertEqual(e.args[0], "msg") 601 | else: 602 | self.fail("AssertionError not raised by assert 0") 603 | 604 | try: 605 | assert False 606 | except AssertionError as e: 607 | self.assertEqual(len(e.args), 0) 608 | else: 609 | self.fail("AssertionError not raised by 'assert False'") 610 | 611 | 612 | ### compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | funcdef | classdef 613 | # Tested below 614 | 615 | def test_if(self): 616 | # 'if' test ':' suite ('elif' test ':' suite)* ['else' ':' suite] 617 | if 1: pass 618 | if 1: pass 619 | else: pass 620 | if 0: pass 621 | elif 0: pass 622 | if 0: pass 623 | elif 0: pass 624 | elif 0: pass 625 | elif 0: pass 626 | else: pass 627 | 628 | def test_while(self): 629 | # 'while' test ':' suite ['else' ':' suite] 630 | while 0: pass 631 | while 0: pass 632 | else: pass 633 | 634 | # Issue1920: "while 0" is optimized away, 635 | # ensure that the "else" clause is still present.
636 | x = 0 637 | while 0: 638 | x = 1 639 | else: 640 | x = 2 641 | self.assertEqual(x, 2) 642 | 643 | def test_for(self): 644 | # 'for' exprlist 'in' exprlist ':' suite ['else' ':' suite] 645 | for i in 1, 2, 3: pass 646 | for i, j, k in (): pass 647 | else: pass 648 | class Squares: 649 | def __init__(self, max): 650 | self.max = max 651 | self.sofar = [] 652 | def __len__(self): return len(self.sofar) 653 | def __getitem__(self, i): 654 | if not 0 <= i < self.max: raise IndexError 655 | n = len(self.sofar) 656 | while n <= i: 657 | self.sofar.append(n*n) 658 | n = n+1 659 | return self.sofar[i] 660 | n = 0 661 | for x in Squares(10): n = n+x 662 | if n != 285: 663 | self.fail('for over growing sequence') 664 | 665 | result = [] 666 | for x, in [(1,), (2,), (3,)]: 667 | result.append(x) 668 | self.assertEqual(result, [1, 2, 3]) 669 | 670 | def test_try(self): 671 | ### try_stmt: 'try' ':' suite (except_clause ':' suite)+ ['else' ':' suite] 672 | ### | 'try' ':' suite 'finally' ':' suite 673 | ### except_clause: 'except' [expr ['as' expr]] 674 | try: 675 | 1/0 676 | except ZeroDivisionError: 677 | pass 678 | else: 679 | pass 680 | try: 1/0 681 | except EOFError: pass 682 | except TypeError as msg: pass 683 | except RuntimeError as msg: pass 684 | except: pass 685 | else: pass 686 | try: 1/0 687 | except (EOFError, TypeError, ZeroDivisionError): pass 688 | try: 1/0 689 | except (EOFError, TypeError, ZeroDivisionError) as msg: pass 690 | try: pass 691 | finally: pass 692 | 693 | def test_suite(self): 694 | # simple_stmt | NEWLINE INDENT NEWLINE* (stmt NEWLINE*)+ DEDENT 695 | if 1: pass 696 | if 1: 697 | pass 698 | if 1: 699 | # 700 | # 701 | # 702 | pass 703 | pass 704 | # 705 | pass 706 | # 707 | 708 | def test_test(self): 709 | ### and_test ('or' and_test)* 710 | ### and_test: not_test ('and' not_test)* 711 | ### not_test: 'not' not_test | comparison 712 | if not 1: pass 713 | if 1 and 1: pass 714 | if 1 or 1: pass 715 | if not not not 1: pass 716 | if not 1 and 
1 and 1: pass 717 | if 1 and 1 or 1 and 1 and 1 or not 1 and 1: pass 718 | 719 | def test_comparison(self): 720 | ### comparison: expr (comp_op expr)* 721 | ### comp_op: '<'|'>'|'=='|'>='|'<='|'!='|'in'|'not' 'in'|'is'|'is' 'not' 722 | if 1: pass 723 | x = (1 == 1) 724 | if 1 == 1: pass 725 | if 1 != 1: pass 726 | if 1 < 1: pass 727 | if 1 > 1: pass 728 | if 1 <= 1: pass 729 | if 1 >= 1: pass 730 | if 1 is 1: pass 731 | if 1 is not 1: pass 732 | if 1 in (): pass 733 | if 1 not in (): pass 734 | if 1 < 1 > 1 == 1 >= 1 <= 1 != 1 in 1 not in 1 is 1 is not 1: pass 735 | 736 | def test_binary_mask_ops(self): 737 | x = 1 & 1 738 | x = 1 ^ 1 739 | x = 1 | 1 740 | 741 | def test_shift_ops(self): 742 | x = 1 << 1 743 | x = 1 >> 1 744 | x = 1 << 1 >> 1 745 | 746 | def test_additive_ops(self): 747 | x = 1 748 | x = 1 + 1 749 | x = 1 - 1 - 1 750 | x = 1 - 1 + 1 - 1 + 1 751 | 752 | def test_multiplicative_ops(self): 753 | x = 1 * 1 754 | x = 1 / 1 755 | x = 1 % 1 756 | x = 1 / 1 * 1 % 1 757 | 758 | def test_unary_ops(self): 759 | x = +1 760 | x = -1 761 | x = ~1 762 | x = ~1 ^ 1 & 1 | 1 & 1 ^ -1 763 | x = -1*1/1 + 1*1 - ---1*1 764 | 765 | def test_selectors(self): 766 | ### trailer: '(' [testlist] ')' | '[' subscript ']' | '.' NAME 767 | ### subscript: expr | [expr] ':' [expr] 768 | 769 | import sys, time 770 | c = sys.path[0] 771 | x = time.time() 772 | x = sys.modules['time'].time() 773 | a = '01234' 774 | c = a[0] 775 | c = a[-1] 776 | s = a[0:5] 777 | s = a[:5] 778 | s = a[0:] 779 | s = a[:] 780 | s = a[-5:] 781 | s = a[:-1] 782 | s = a[-4:-3] 783 | # A rough test of SF bug 1333982. http://python.org/sf/1333982 784 | # The testing here is fairly incomplete. 
785 | # Test cases should include: commas with 1 and 2 colons 786 | d = {} 787 | d[1] = 1 788 | d[1,] = 2 789 | d[1,2] = 3 790 | d[1,2,3] = 4 791 | L = list(d) 792 | L.sort(key=lambda x: x if isinstance(x, tuple) else ()) 793 | self.assertEqual(str(L), '[1, (1,), (1, 2), (1, 2, 3)]') 794 | 795 | def test_atoms(self): 796 | ### atom: '(' [testlist] ')' | '[' [testlist] ']' | '{' [dictsetmaker] '}' | NAME | NUMBER | STRING 797 | ### dictsetmaker: (test ':' test (',' test ':' test)* [',']) | (test (',' test)* [',']) 798 | 799 | x = (1) 800 | x = (1 or 2 or 3) 801 | x = (1 or 2 or 3, 2, 3) 802 | 803 | x = [] 804 | x = [1] 805 | x = [1 or 2 or 3] 806 | x = [1 or 2 or 3, 2, 3] 807 | x = [] 808 | 809 | x = {} 810 | x = {'one': 1} 811 | x = {'one': 1,} 812 | x = {'one' or 'two': 1 or 2} 813 | x = {'one': 1, 'two': 2} 814 | x = {'one': 1, 'two': 2,} 815 | x = {'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5, 'six': 6} 816 | 817 | x = {'one'} 818 | x = {'one', 1,} 819 | x = {'one', 'two', 'three'} 820 | x = {2, 3, 4,} 821 | 822 | x = x 823 | x = 'x' 824 | x = 123 825 | 826 | ### exprlist: expr (',' expr)* [','] 827 | ### testlist: test (',' test)* [','] 828 | # These have been exercised enough above 829 | 830 | def test_classdef(self): 831 | # 'class' NAME ['(' [testlist] ')'] ':' suite 832 | class B: pass 833 | class B2(): pass 834 | class C1(B): pass 835 | class C2(B): pass 836 | class D(C1, C2, B): pass 837 | class C: 838 | def meth1(self): pass 839 | def meth2(self, arg): pass 840 | def meth3(self, a1, a2): pass 841 | 842 | # decorator: '@' dotted_name [ '(' [arglist] ')' ] NEWLINE 843 | # decorators: decorator+ 844 | # decorated: decorators (classdef | funcdef) 845 | def class_decorator(x): return x 846 | @class_decorator 847 | class G: pass 848 | 849 | def test_dictcomps(self): 850 | # dictorsetmaker: ( (test ':' test (comp_for | 851 | # (',' test ':' test)* [','])) | 852 | # (test (comp_for | (',' test)* [','])) ) 853 | nums = [1, 2, 3] 854 | 
self.assertEqual({i:i+1 for i in nums}, {1: 2, 2: 3, 3: 4}) 855 | 856 | def test_listcomps(self): 857 | # list comprehension tests 858 | nums = [1, 2, 3, 4, 5] 859 | strs = ["Apple", "Banana", "Coconut"] 860 | spcs = [" Apple", " Banana ", "Coco nut "] 861 | 862 | self.assertEqual([s.strip() for s in spcs], ['Apple', 'Banana', 'Coco nut']) 863 | self.assertEqual([3 * x for x in nums], [3, 6, 9, 12, 15]) 864 | self.assertEqual([x for x in nums if x > 2], [3, 4, 5]) 865 | self.assertEqual([(i, s) for i in nums for s in strs], 866 | [(1, 'Apple'), (1, 'Banana'), (1, 'Coconut'), 867 | (2, 'Apple'), (2, 'Banana'), (2, 'Coconut'), 868 | (3, 'Apple'), (3, 'Banana'), (3, 'Coconut'), 869 | (4, 'Apple'), (4, 'Banana'), (4, 'Coconut'), 870 | (5, 'Apple'), (5, 'Banana'), (5, 'Coconut')]) 871 | self.assertEqual([(i, s) for i in nums for s in [f for f in strs if "n" in f]], 872 | [(1, 'Banana'), (1, 'Coconut'), (2, 'Banana'), (2, 'Coconut'), 873 | (3, 'Banana'), (3, 'Coconut'), (4, 'Banana'), (4, 'Coconut'), 874 | (5, 'Banana'), (5, 'Coconut')]) 875 | self.assertEqual([(lambda a:[a**i for i in range(a+1)])(j) for j in range(5)], 876 | [[1], [1, 1], [1, 2, 4], [1, 3, 9, 27], [1, 4, 16, 64, 256]]) 877 | 878 | def test_in_func(l): 879 | return [0 < x < 3 for x in l if x > 2] 880 | 881 | self.assertEqual(test_in_func(nums), [False, False, False]) 882 | 883 | def test_nested_front(): 884 | self.assertEqual([[y for y in [x, x + 1]] for x in [1,3,5]], 885 | [[1, 2], [3, 4], [5, 6]]) 886 | 887 | test_nested_front() 888 | 889 | check_syntax_error(self, "[i, s for i in nums for s in strs]") 890 | check_syntax_error(self, "[x if y]") 891 | 892 | suppliers = [ 893 | (1, "Boeing"), 894 | (2, "Ford"), 895 | (3, "Macdonalds") 896 | ] 897 | 898 | parts = [ 899 | (10, "Airliner"), 900 | (20, "Engine"), 901 | (30, "Cheeseburger") 902 | ] 903 | 904 | suppart = [ 905 | (1, 10), (1, 20), (2, 20), (3, 30) 906 | ] 907 | 908 | x = [ 909 | (sname, pname) 910 | for (sno, sname) in suppliers 911 | for 
(pno, pname) in parts 912 | for (sp_sno, sp_pno) in suppart 913 | if sno == sp_sno and pno == sp_pno 914 | ] 915 | 916 | self.assertEqual(x, [('Boeing', 'Airliner'), ('Boeing', 'Engine'), ('Ford', 'Engine'), 917 | ('Macdonalds', 'Cheeseburger')]) 918 | 919 | def test_genexps(self): 920 | # generator expression tests 921 | g = ([x for x in range(10)] for x in range(1)) 922 | self.assertEqual(next(g), [x for x in range(10)]) 923 | try: 924 | next(g) 925 | self.fail('should produce StopIteration exception') 926 | except StopIteration: 927 | pass 928 | 929 | a = 1 930 | try: 931 | g = (a for d in a) 932 | next(g) 933 | self.fail('should produce TypeError') 934 | except TypeError: 935 | pass 936 | 937 | self.assertEqual(list((x, y) for x in 'abcd' for y in 'abcd'), [(x, y) for x in 'abcd' for y in 'abcd']) 938 | self.assertEqual(list((x, y) for x in 'ab' for y in 'xy'), [(x, y) for x in 'ab' for y in 'xy']) 939 | 940 | a = [x for x in range(10)] 941 | b = (x for x in (y for y in a)) 942 | self.assertEqual(sum(b), sum([x for x in range(10)])) 943 | 944 | self.assertEqual(sum(x**2 for x in range(10)), sum([x**2 for x in range(10)])) 945 | self.assertEqual(sum(x*x for x in range(10) if x%2), sum([x*x for x in range(10) if x%2])) 946 | self.assertEqual(sum(x for x in (y for y in range(10))), sum([x for x in range(10)])) 947 | self.assertEqual(sum(x for x in (y for y in (z for z in range(10)))), sum([x for x in range(10)])) 948 | self.assertEqual(sum(x for x in [y for y in (z for z in range(10))]), sum([x for x in range(10)])) 949 | self.assertEqual(sum(x for x in (y for y in (z for z in range(10) if True)) if True), sum([x for x in range(10)])) 950 | self.assertEqual(sum(x for x in (y for y in (z for z in range(10) if True) if False) if True), 0) 951 | check_syntax_error(self, "foo(x for x in range(10), 100)") 952 | check_syntax_error(self, "foo(100, x for x in range(10))") 953 | 954 | def test_comprehension_specials(self): 955 | # test for outmost iterable precomputation 
956 | x = 10; g = (i for i in range(x)); x = 5 957 | self.assertEqual(len(list(g)), 10) 958 | 959 | # This should hold, since we're only precomputing outmost iterable. 960 | x = 10; t = False; g = ((i,j) for i in range(x) if t for j in range(x)) 961 | x = 5; t = True; 962 | self.assertEqual([(i,j) for i in range(10) for j in range(5)], list(g)) 963 | 964 | # Grammar allows multiple adjacent 'if's in listcomps and genexps, 965 | # even though it's silly. Make sure it works (ifelse broke this.) 966 | self.assertEqual([ x for x in range(10) if x % 2 if x % 3 ], [1, 5, 7]) 967 | self.assertEqual(list(x for x in range(10) if x % 2 if x % 3), [1, 5, 7]) 968 | 969 | # verify unpacking single element tuples in listcomp/genexp. 970 | self.assertEqual([x for x, in [(4,), (5,), (6,)]], [4, 5, 6]) 971 | self.assertEqual(list(x for x, in [(7,), (8,), (9,)]), [7, 8, 9]) 972 | 973 | def test_with_statement(self): 974 | class manager(object): 975 | def __enter__(self): 976 | return (1, 2) 977 | def __exit__(self, *args): 978 | pass 979 | 980 | with manager(): 981 | pass 982 | with manager() as x: 983 | pass 984 | with manager() as (x, y): 985 | pass 986 | with manager(), manager(): 987 | pass 988 | with manager() as x, manager() as y: 989 | pass 990 | with manager() as x, manager(): 991 | pass 992 | 993 | def test_if_else_expr(self): 994 | # Test ifelse expressions in various cases 995 | def _checkeval(msg, ret): 996 | "helper to check that evaluation of expressions is done correctly" 997 | print(msg) 998 | return ret 999 | 1000 | # the next line is not allowed anymore 1001 | #self.assertEqual([ x() for x in lambda: True, lambda: False if x() ], [True]) 1002 | self.assertEqual([ x() for x in (lambda: True, lambda: False) if x() ], [True]) 1003 | self.assertEqual([ x(False) for x in (lambda x: False if x else True, lambda x: True if x else False) if x(False) ], [True]) 1004 | self.assertEqual((5 if 1 else _checkeval("check 1", 0)), 5) 1005 | self.assertEqual((_checkeval("check 2", 
0) if 0 else 5), 5) 1006 | self.assertEqual((5 and 6 if 0 else 1), 1) 1007 | self.assertEqual(((5 and 6) if 0 else 1), 1) 1008 | self.assertEqual((5 and (6 if 1 else 1)), 6) 1009 | self.assertEqual((0 or _checkeval("check 3", 2) if 0 else 3), 3) 1010 | self.assertEqual((1 or _checkeval("check 4", 2) if 1 else _checkeval("check 5", 3)), 1) 1011 | self.assertEqual((0 or 5 if 1 else _checkeval("check 6", 3)), 5) 1012 | self.assertEqual((not 5 if 1 else 1), False) 1013 | self.assertEqual((not 5 if 0 else 1), 1) 1014 | self.assertEqual((6 + 1 if 1 else 2), 7) 1015 | self.assertEqual((6 - 1 if 1 else 2), 5) 1016 | self.assertEqual((6 * 2 if 1 else 4), 12) 1017 | self.assertEqual((6 / 2 if 1 else 3), 3) 1018 | self.assertEqual((6 < 4 if 0 else 2), 2) 1019 | 1020 | def test_paren_evaluation(self): 1021 | self.assertEqual(16 // (4 // 2), 8) 1022 | self.assertEqual((16 // 4) // 2, 2) 1023 | self.assertEqual(16 // 4 // 2, 2) 1024 | self.assertTrue(False is (2 is 3)) 1025 | self.assertFalse((False is 2) is 3) 1026 | self.assertFalse(False is 2 is 3) 1027 | 1028 | def test_matrix_mul(self): 1029 | # This is not intended to be a comprehensive test, rather just to be a few 1030 | # samples of the @ operator in test_grammar.py.
1031 | class M: 1032 | def __matmul__(self, o): 1033 | return 4 1034 | def __imatmul__(self, o): 1035 | self.other = o 1036 | return self 1037 | m = M() 1038 | self.assertEqual(m @ m, 4) 1039 | m @= 42 1040 | self.assertEqual(m.other, 42) 1041 | 1042 | def test_async_await(self): 1043 | async = 1 1044 | await = 2 1045 | self.assertEqual(async, 1) 1046 | 1047 | def async(): 1048 | nonlocal await 1049 | await = 10 1050 | async() 1051 | self.assertEqual(await, 10) 1052 | 1053 | self.assertFalse(bool(async.__code__.co_flags & inspect.CO_COROUTINE)) 1054 | 1055 | async def test(): 1056 | def sum(): 1057 | pass 1058 | if 1: 1059 | await someobj() 1060 | 1061 | self.assertEqual(test.__name__, 'test') 1062 | self.assertTrue(bool(test.__code__.co_flags & inspect.CO_COROUTINE)) 1063 | 1064 | def decorator(func): 1065 | setattr(func, '_marked', True) 1066 | return func 1067 | 1068 | @decorator 1069 | async def test2(): 1070 | return 22 1071 | self.assertTrue(test2._marked) 1072 | self.assertEqual(test2.__name__, 'test2') 1073 | self.assertTrue(bool(test2.__code__.co_flags & inspect.CO_COROUTINE)) 1074 | 1075 | def test_async_for(self): 1076 | class Done(Exception): pass 1077 | 1078 | class AIter: 1079 | def __aiter__(self): 1080 | return self 1081 | async def __anext__(self): 1082 | raise StopAsyncIteration 1083 | 1084 | async def foo(): 1085 | async for i in AIter(): 1086 | pass 1087 | async for i, j in AIter(): 1088 | pass 1089 | async for i in AIter(): 1090 | pass 1091 | else: 1092 | pass 1093 | raise Done 1094 | 1095 | with self.assertRaises(Done): 1096 | foo().send(None) 1097 | 1098 | def test_async_with(self): 1099 | class Done(Exception): pass 1100 | 1101 | class manager: 1102 | async def __aenter__(self): 1103 | return (1, 2) 1104 | async def __aexit__(self, *exc): 1105 | return False 1106 | 1107 | async def foo(): 1108 | async with manager(): 1109 | pass 1110 | async with manager() as x: 1111 | pass 1112 | async with manager() as (x, y): 1113 | pass 1114 | async with 
manager(), manager(): 1115 | pass 1116 | async with manager() as x, manager() as y: 1117 | pass 1118 | async with manager() as x, manager(): 1119 | pass 1120 | raise Done 1121 | 1122 | with self.assertRaises(Done): 1123 | foo().send(None) 1124 | 1125 | 1126 | if __name__ == '__main__': 1127 | unittest.main() 1128 | --------------------------------------------------------------------------------