├── .github └── workflows │ ├── publish.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── docsmith.py ├── pyproject.toml └── tests └── test_docsmith.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | test: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: ["3.12", "3.13"] 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | cache: pip 23 | cache-dependency-path: pyproject.toml 24 | - name: Install dependencies 25 | run: | 26 | pip install -e '.[test]' 27 | - name: Run tests 28 | run: | 29 | python -m pytest 30 | deploy: 31 | runs-on: ubuntu-latest 32 | needs: [test] 33 | environment: pypi 34 | permissions: 35 | id-token: write 36 | steps: 37 | - uses: actions/checkout@v4 38 | - name: Set up Python 39 | uses: actions/setup-python@v5 40 | with: 41 | python-version: "3.13" 42 | cache: pip 43 | cache-dependency-path: pyproject.toml 44 | - name: Install dependencies 45 | run: | 46 | pip install setuptools wheel build 47 | - name: Build 48 | run: | 49 | python -m build 50 | - name: Publish 51 | uses: pypa/gh-action-pypi-publish@release/v1 52 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: [pull_request] 4 | 5 | permissions: 6 | contents: read 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ["3.12", "3.13"] 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | cache: pip 21 | cache-dependency-path: pyproject.toml 22 | - name: Install dependencies 23 | run: | 24 | pip install -e '.[test]' 25 | - name: Run tests 26 | run: | 27 | python -m pytest 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llm-docsmith 2 | 3 | Generate Python docstrings automatically with LLM and syntax trees. 4 | 5 | ## Installation 6 | 7 | Install this plugin in the same environment as [LLM](https://llm.datasette.io/en/stable/). 8 | 9 | ```bash 10 | llm install llm-docsmith 11 | ``` 12 | 13 | ## Usage 14 | 15 | Pass a Python file as argument to `llm docsmith`: 16 | 17 | ```bash 18 | llm docsmith ./scripts/main.py 19 | ``` 20 | 21 | The file will be edited to include the generated docstrings. 22 | 23 | Options: 24 | 25 | - `-m/--model`: Use a model other than the configured LLM default model 26 | - `-o/--output`: Only show the modified code, without modifying the file 27 | - `-v/--verbose`: Verbose output of prompt and response 28 | - `--git`: Only update docstrings for functions and classes that have been changed since the last commit 29 | - `--git-base`: Git reference to compare against (default: HEAD) 30 | - `--only-missing`: Only add docstrings to functions and classes that don't have them yet 31 | 32 | ## Only Missing Mode 33 | 34 | The `--only-missing` flag tells docsmith to only generate docstrings for functions and classes that don't already have one. This is useful when you want to preserve existing docstrings and only add documentation where it's missing: 35 | 36 | ```bash 37 | # Only add docstrings where they are missing 38 | llm docsmith ./scripts/main.py --only-missing 39 | ``` 40 | 41 | ## Git Integration 42 | 43 | The `--git` flag enables Git integration, which will only update docstrings for functions and classes that have been changed since the last commit. 44 | This is useful for large codebases where you only want to update docstrings for modified code. 45 | 46 | ```bash 47 | # Update docstrings only for changed functions and classes 48 | llm docsmith ./scripts/main.py --git 49 | ``` 50 | 51 | It's also possible to compare against a different git reference to find changed, for instance the main branch: 52 | 53 | ```bash 54 | # Compare against a specific Git reference 55 | llm docsmith ./scripts/main.py --git --git-base main 56 | ``` 57 | -------------------------------------------------------------------------------- /docsmith.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import os 3 | import re 4 | import subprocess 5 | from dataclasses import dataclass, field 6 | from functools import partial 7 | from typing import ( 8 | Literal, 9 | Protocol, 10 | runtime_checkable, 11 | ) 12 | 13 | import click 14 | import libcst as cst 15 | import llm 16 | from pydantic import BaseModel 17 | 18 | SYSTEM_PROMPT = """ 19 | You are a coding assistant whose task is to generate docstrings for existing Python code. 20 | You will receive code without any docstrings. 21 | Generate the appropiate docstrings for each function, class or method. 22 | 23 | Do not return any code. Use the context only to learn about the code. 24 | Write documentation only for the code provided as input code. 25 | 26 | The docstring for a function or method should summarize its behavior, side effects, exceptions raised, 27 | and restrictions on when it can be called (all if applicable). 28 | Only mention exceptions if there is at least one _explicitly_ raised or reraised exception inside the function or method. 29 | The docstring prescribes the function or method's effect as a command, not as a description; e.g. don't write "Returns the pathname ...". 30 | Do not explain implementation details, do not include information about arguments and return here. 31 | If the docstring is multiline, the first line should be a very short summary, followed by a blank line and a more ellaborate description. 32 | Write single-line docstrings if the function is simple. 33 | The docstring for a class should summarize its behavior and list the public methods (one by line) and instance variables. 34 | 35 | In the Argument object, describe each argument. In the return object, describe the returned values of the function, if any. 36 | 37 | You will receive a JSON template. Fill the slots marked with with the appropriate description. Return as JSON. 38 | """ 39 | 40 | PROMPT_TEMPLATE = """ 41 | {CONTEXT} 42 | 43 | Input code: 44 | 45 | ```python 46 | {CODE} 47 | ``` 48 | 49 | Output template: 50 | 51 | ```json 52 | {TEMPLATE} 53 | ``` 54 | """ 55 | 56 | 57 | INDENT = " " 58 | 59 | 60 | class Argument(BaseModel): 61 | name: str 62 | description: str 63 | annotation: str | None = None 64 | default: str | None = None 65 | 66 | 67 | class Return(BaseModel): 68 | description: str 69 | annotation: str | None 70 | 71 | 72 | class Docstring(BaseModel): 73 | node_type: Literal["class", "function"] 74 | name: str 75 | docstring: str 76 | args: list[Argument] | None = None 77 | ret: Return | None = None 78 | 79 | 80 | class Documentation(BaseModel): 81 | entries: list[Docstring] 82 | 83 | 84 | class DocstringGenerator(Protocol): 85 | def __call__( 86 | self, input_code: str, context: str, template: Documentation 87 | ) -> Documentation: ... 88 | 89 | 90 | def create_docstring_node(docstring_text: str, indent: str) -> cst.BaseStatement: 91 | lines = docstring_text.strip().split("\n") 92 | 93 | indented_lines = [] 94 | for line in lines: 95 | indented_lines.append(indent + line if line.strip() else line) 96 | 97 | return cst.SimpleStatementLine( 98 | body=[ 99 | cst.Expr( 100 | value=cst.SimpleString( 101 | value=f'"""\n{"\n".join(indented_lines)}\n{indent}"""' 102 | ) 103 | ) 104 | ] 105 | ) 106 | 107 | 108 | @dataclass 109 | class ChangedEntities: 110 | functions: set[str] = field(default_factory=set) 111 | classes: set[str] = field(default_factory=set) 112 | methods: set[str] = field(default_factory=set) 113 | 114 | 115 | def has_docstring(node: cst.CSTNode) -> bool: 116 | """ 117 | Check if a node has a docstring. 118 | 119 | A docstring is the first statement in a module, function or class body and must be a string literal. 120 | The node can have different types of bodies (IndentedBlock or SimpleStatementSuite) depending on 121 | whether it's a compound statement or a simple one-liner. 122 | """ 123 | # Handle simple one-liner functions/classes that use SimpleStatementSuite 124 | if isinstance(node.body, cst.SimpleStatementSuite): 125 | return False # One-liners can't have docstrings 126 | 127 | # Handle regular functions/classes with IndentedBlock 128 | if isinstance(node.body, cst.IndentedBlock): 129 | body_statements = node.body.body 130 | else: 131 | body_statements = node.body 132 | 133 | if not body_statements: 134 | return False 135 | 136 | first_stmt = body_statements[0] 137 | if not isinstance(first_stmt, cst.SimpleStatementLine): 138 | return False 139 | 140 | if not first_stmt.body: 141 | return False 142 | 143 | first_expr = first_stmt.body[0] 144 | if not isinstance(first_expr, cst.Expr): 145 | return False 146 | 147 | return isinstance(first_expr.value, (cst.SimpleString, cst.ConcatenatedString)) 148 | 149 | 150 | class DocstringTransformer(cst.CSTTransformer): 151 | def __init__( 152 | self, 153 | docstring_generator: DocstringGenerator, 154 | module: cst.Module, 155 | changed_entities: ChangedEntities | None = None, 156 | only_missing: bool = False, 157 | ): 158 | self._class_stack: list[str] = [] 159 | self._doc: Documentation | None = None 160 | self.module: cst.Module = module 161 | self.docstring_gen = docstring_generator 162 | self.indentation_level = 0 163 | self.changed_entities = changed_entities 164 | self.only_missing = only_missing 165 | 166 | @property 167 | def _current_class(self) -> str | None: 168 | """Get the current class name from the top of the stack.""" 169 | return self._class_stack[-1] if self._class_stack else None 170 | 171 | def visit_Module(self, node): 172 | self.module = node 173 | return True 174 | 175 | def visit_FunctionDef(self, node): 176 | self.indentation_level += 1 177 | 178 | def visit_ClassDef(self, node) -> bool | None: 179 | self.indentation_level += 1 180 | self._class_stack.append(node.name.value) 181 | 182 | if ( 183 | self.changed_entities is None 184 | or node.name.value in self.changed_entities.classes 185 | ): 186 | source_lines = cst.Module([node]).code 187 | template = extract_signatures(self.module, node) 188 | context = get_context(self.module, node) 189 | doc = self.docstring_gen(source_lines, context, template) 190 | self._doc = doc 191 | 192 | return super().visit_ClassDef(node) 193 | 194 | def _modify_docstring(self, body, new_docstring): 195 | # If body is an IndentedBlock, extract its body 196 | if isinstance(body, cst.IndentedBlock): 197 | body_statements = list(body.body) 198 | elif not isinstance(body, list): 199 | # Create an IndentedBlock if body is not already one 200 | indent = INDENT * (self.indentation_level + 1) 201 | new_docstring_node = create_docstring_node(new_docstring, indent) 202 | return cst.IndentedBlock(body=[new_docstring_node, body]) 203 | else: 204 | body_statements = list(body) 205 | 206 | indent = INDENT * (self.indentation_level + 1) 207 | # Check if first statement is a docstring 208 | if ( 209 | body_statements 210 | and isinstance(body_statements[0], cst.SimpleStatementLine) 211 | and isinstance(body_statements[0].body[0], cst.Expr) 212 | and isinstance(body_statements[0].body[0].value, cst.SimpleString) 213 | ): 214 | # Replace existing docstring 215 | new_docstring_node = create_docstring_node(new_docstring, indent) 216 | body_statements[0] = new_docstring_node 217 | 218 | # No existing docstring - add new one if provided 219 | elif new_docstring: 220 | new_docstring_node = create_docstring_node(new_docstring, indent) 221 | body_statements.insert(0, new_docstring_node) 222 | 223 | # Reconstruct the body 224 | if isinstance(body, cst.IndentedBlock): 225 | return body.with_changes(body=tuple(body_statements)) 226 | return tuple(body_statements) 227 | 228 | def leave_FunctionDef(self, original_node, updated_node): 229 | self.indentation_level -= 1 230 | 231 | if self.changed_entities is not None: 232 | if ( 233 | self._current_class 234 | and f"{self._current_class}.{updated_node.name.value}" 235 | not in self.changed_entities.methods 236 | ): 237 | return updated_node 238 | if ( 239 | not self._current_class 240 | and updated_node.name.value not in self.changed_entities.functions 241 | ): 242 | return updated_node 243 | 244 | if self.only_missing and has_docstring(updated_node): 245 | return updated_node 246 | 247 | source_lines = cst.Module([updated_node]).code 248 | name = updated_node.name.value 249 | 250 | doc = None 251 | if self._current_class is None: 252 | template = extract_signatures(self.module, updated_node) 253 | context = get_context(self.module, updated_node) 254 | doc = self.docstring_gen(source_lines, context, template) 255 | elif self._doc is not None: 256 | doc = self._doc 257 | else: 258 | return updated_node 259 | 260 | new_docstring = find_docstring_by_name(doc, name) 261 | if new_docstring is None: 262 | return updated_node 263 | 264 | new_body = self._modify_docstring( 265 | updated_node.body, docstring_to_str(new_docstring) 266 | ) 267 | 268 | return updated_node.with_changes(body=new_body) 269 | 270 | def leave_ClassDef(self, original_node, updated_node): 271 | self.indentation_level -= 1 272 | self._class_stack.pop() 273 | 274 | if ( 275 | self.changed_entities is not None 276 | and updated_node.name.value not in self.changed_entities.classes 277 | ): 278 | return updated_node 279 | 280 | if self.only_missing and has_docstring(updated_node): 281 | return updated_node 282 | 283 | if self._doc is None: 284 | return updated_node 285 | 286 | new_docstring = find_docstring_by_name(self._doc, updated_node.name.value) 287 | 288 | if new_docstring is None: 289 | return updated_node 290 | 291 | new_body = self._modify_docstring( 292 | updated_node.body, docstring_to_str(new_docstring) 293 | ) 294 | 295 | return updated_node.with_changes(body=new_body) 296 | 297 | 298 | def find_function_definitions(tree) -> list[ast.FunctionDef | ast.AsyncFunctionDef]: 299 | function_defs = [] 300 | 301 | for node in ast.walk(tree): 302 | if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): 303 | function_defs.append(node) 304 | 305 | return function_defs 306 | 307 | 308 | def find_class_definitions(tree) -> list[ast.ClassDef]: 309 | function_defs = [] 310 | 311 | for node in ast.walk(tree): 312 | if isinstance(node, ast.ClassDef): 313 | function_defs.append(node) 314 | 315 | return function_defs 316 | 317 | 318 | def find_top_level_definitions( 319 | tree, 320 | ) -> dict[str, ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef]: 321 | definitions = {} 322 | for node in tree.body: 323 | if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): 324 | definitions[node.name] = node 325 | return definitions 326 | 327 | 328 | def collect_entities( 329 | node, 330 | definitions: dict[str, ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef], 331 | ) -> list[ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef]: 332 | entities = set() 333 | 334 | for node in ast.walk(node): 335 | match node: 336 | case ast.Call(func=ast.Name(name)): 337 | entities.add(definitions.get(name)) 338 | case ( 339 | ast.AnnAssign(annotation=ast.Name(name)) 340 | | ast.arg(annotation=ast.Name(name)) 341 | ): 342 | entities.add(definitions.get(name)) 343 | case ( 344 | ast.AnnAssign( 345 | annotation=ast.Subscript( 346 | value=ast.Name(subs_name), slice=ast.Name(name) 347 | ) 348 | ) 349 | | ast.arg( 350 | annotation=ast.Subscript( 351 | value=ast.Name(subs_name), slice=ast.Name(name) 352 | ) 353 | ) 354 | ): 355 | entities.add(definitions.get(name)) 356 | entities.add(definitions.get(subs_name)) 357 | 358 | return list(e for e in entities if e is not None) 359 | 360 | 361 | def get_context(module: cst.Module, node: cst.CSTNode) -> str: 362 | source = module.code 363 | 364 | tree = ast.parse(source) 365 | definitions = find_top_level_definitions(tree) 366 | 367 | node_source = module.code_for_node(node) 368 | node_tree = ast.parse(node_source) 369 | referenced_functions = collect_entities(node_tree, definitions) 370 | 371 | out = "\n".join(ast.unparse(func) for func in referenced_functions) 372 | return out 373 | 374 | 375 | def has_return_stmt(node): 376 | return any( 377 | isinstance(n, ast.Return) and n.value is not None for n in ast.walk(node) 378 | ) 379 | 380 | 381 | def extract_signatures(module: cst.Module, node: cst.CSTNode) -> Documentation: 382 | source = module.code_for_node(node) 383 | 384 | tree = ast.parse(source) 385 | function_defs = find_function_definitions(tree) 386 | # TODO argument 387 | function_defs = filter(lambda x: not is_private(x), function_defs) 388 | function_defs = filter(lambda x: not is_dunder(x), function_defs) 389 | 390 | class_defs = find_class_definitions(tree) 391 | class_defs = filter(lambda x: not is_private(x), class_defs) 392 | 393 | function_entries = [extract_signature(node) for node in function_defs] 394 | class_entries = [ 395 | Docstring(node_type="class", name=node.name, docstring="") 396 | for node in class_defs 397 | ] 398 | 399 | return Documentation(entries=[*class_entries, *function_entries]) 400 | 401 | 402 | def is_private(node): 403 | name = node.name 404 | return name.startswith("_") and not is_dunder(node) 405 | 406 | 407 | def is_dunder(node): 408 | name = node.name 409 | return name.startswith("__") and name.endswith("__") 410 | 411 | 412 | def extract_signature(function_node: ast.FunctionDef | ast.AsyncFunctionDef): 413 | function_name = function_node.name 414 | 415 | arguments = [] 416 | for arg in function_node.args.args: 417 | arg_name = arg.arg 418 | 419 | if arg_name in {"self", "cls"}: 420 | continue 421 | 422 | arg_type = ast.unparse(arg.annotation) if arg.annotation else None 423 | 424 | default_value = None 425 | if function_node.args.defaults: 426 | num_defaults = len(function_node.args.defaults) 427 | 428 | # Align defaults with arguments 429 | default_index = len(function_node.args.args) - num_defaults 430 | if function_node.args.args.index(arg) >= default_index: 431 | default_value = ast.unparse( 432 | function_node.args.defaults[ 433 | function_node.args.args.index(arg) - default_index 434 | ] 435 | ) 436 | 437 | arguments.append( 438 | Argument( 439 | name=arg_name, 440 | annotation=arg_type, 441 | default=default_value, 442 | description="", 443 | ) 444 | ) 445 | 446 | # Handle *args 447 | if function_node.args.vararg: 448 | arguments.append( 449 | Argument( 450 | name=f"*{function_node.args.vararg.arg}", 451 | annotation=ast.unparse(function_node.args.vararg.annotation) 452 | if function_node.args.vararg.annotation 453 | else None, 454 | description="", 455 | ) 456 | ) 457 | 458 | # Handle **kwargs 459 | if function_node.args.kwarg: 460 | arguments.append( 461 | Argument( 462 | name=f"**{function_node.args.kwarg.arg}", 463 | annotation=ast.unparse(function_node.args.kwarg.annotation) 464 | if function_node.args.kwarg.annotation 465 | else None, 466 | description="", 467 | ) 468 | ) 469 | 470 | # Extract return type 471 | ret = None 472 | if has_return_stmt(function_node): 473 | return_type = ( 474 | ast.unparse(function_node.returns) if function_node.returns else None 475 | ) 476 | ret = Return(description="", annotation=return_type) 477 | 478 | return Docstring( 479 | node_type="function", 480 | name=function_name, 481 | docstring="", 482 | args=arguments, 483 | ret=ret, 484 | ) 485 | 486 | 487 | def find_docstring_by_name(doc: Documentation, name: str) -> Docstring | None: 488 | entries = [entry for entry in doc.entries if entry.name == name] 489 | return entries[0] if entries else None 490 | 491 | 492 | def wrap_text( 493 | text: str, indent: str = "", initial_indent: str = "", max_width: int = 88 494 | ) -> str: 495 | """Wrap text to max_width, respecting indentation and breaking only between words.""" 496 | # Split by newlines first to preserve them 497 | text = text.replace("\\n", "\n") 498 | paragraphs = text.split("\n") 499 | result = [] 500 | 501 | for paragraph in paragraphs: 502 | words = paragraph.split() 503 | if not words: 504 | # Empty line, preserve it 505 | result.append("") 506 | continue 507 | 508 | lines = [] 509 | current_line = initial_indent 510 | 511 | for word in words: 512 | # Check if adding this word would exceed max_width 513 | if ( 514 | len(current_line) + len(word) + 1 <= max_width 515 | or not current_line.strip() 516 | ): 517 | # Add word with a space if not the first word on the line 518 | if current_line.strip(): 519 | current_line += " " + word 520 | else: 521 | current_line += word 522 | else: 523 | # Start a new line 524 | lines.append(current_line) 525 | current_line = indent + word 526 | 527 | # Add the last line if it's not empty 528 | if current_line: 529 | lines.append(current_line) 530 | 531 | result.append("\n".join(lines)) 532 | 533 | # Join all paragraphs with newlines 534 | return "\n".join(result) 535 | 536 | 537 | def docstring_to_str(docstring: Docstring) -> str: 538 | wrapped_docstring = wrap_text(docstring.docstring.strip()) 539 | string = f"{wrapped_docstring}\n" 540 | 541 | args_strings = [] 542 | for arg in docstring.args or []: 543 | if arg.annotation is not None: 544 | prefix = f" - {arg.name} ({arg.annotation}):" 545 | else: 546 | prefix = f" - {arg.name}:" 547 | 548 | description = arg.description 549 | if arg.default is not None: 550 | description += f" (default {arg.default})" 551 | 552 | # Wrap the argument description with proper indentation 553 | wrapped_arg = wrap_text( 554 | description.strip(), indent=" " * 6, initial_indent=prefix 555 | ) 556 | args_strings.append(wrapped_arg) 557 | 558 | if args_strings: 559 | string += f"""\nParameters: 560 | ----------- 561 | 562 | {"\n".join(args_strings)} 563 | """ 564 | 565 | # Process return value 566 | if docstring.ret is not None and ( 567 | docstring.ret.description or docstring.ret.annotation 568 | ): 569 | if docstring.ret.annotation: 570 | prefix = f" - {docstring.ret.annotation}:" 571 | description = docstring.ret.description 572 | indent = " " * 6 573 | else: 574 | prefix = " " 575 | description = docstring.ret.description 576 | indent = prefix 577 | 578 | # Wrap the return description with proper indentation 579 | wrapped_return = wrap_text(description, indent=indent, initial_indent=prefix) 580 | 581 | string += f"""\nReturns: 582 | -------- 583 | 584 | {wrapped_return} 585 | """ 586 | return string 587 | 588 | 589 | def llm_docstring_generator( 590 | input_code: str, context: str, template: Documentation, model_id: str, verbose: bool 591 | ) -> Documentation: 592 | context = f"Important context:\n\n```python\n{context}\n```" if context else "" 593 | model = llm.get_model(model_id) 594 | if not model.supports_schema: 595 | raise ValueError( 596 | ( 597 | f"The model {model_id} does not support structured outputs." 598 | " Choose a model with structured output support." 599 | ) 600 | ) 601 | prompt = PROMPT_TEMPLATE.strip().format( 602 | CONTEXT=context, 603 | CODE=input_code, 604 | TEMPLATE=template.model_dump_json(), 605 | ) 606 | 607 | if verbose: 608 | click.echo( 609 | click.style(f"System:\n{SYSTEM_PROMPT}", fg="yellow", bold=True), err=True 610 | ) 611 | click.echo(click.style(f"Prompt:\n{prompt}", fg="yellow", bold=True), err=True) 612 | 613 | response = model.prompt( 614 | prompt=prompt, schema=Documentation, system=SYSTEM_PROMPT.strip() 615 | ) 616 | if verbose: 617 | click.echo(click.style(response, fg="green", bold=True), err=True) 618 | 619 | return Documentation.model_validate_json(response.text()) 620 | 621 | 622 | def read_source(fpath: str): 623 | with open(fpath, "r", encoding="utf-8") as f: 624 | source = f.read() 625 | return source 626 | 627 | 628 | def modify_docstring( 629 | source_code, 630 | docstring_generator: DocstringGenerator, 631 | changed_entities: ChangedEntities | None = None, 632 | only_missing: bool = False, 633 | ): 634 | module = cst.parse_module(source_code) 635 | modified_module = module.visit( 636 | DocstringTransformer( 637 | docstring_generator, module, changed_entities, only_missing 638 | ) 639 | ) 640 | return modified_module.code 641 | 642 | 643 | def get_changed_lines(file_path: str, git_base: str = "HEAD"): 644 | abs_file_path = os.path.abspath(file_path) 645 | file_dir = os.path.dirname(abs_file_path) 646 | 647 | result = subprocess.run( 648 | ["git", "-C", file_dir, "diff", "-U0", git_base, "--", file_path], 649 | stdout=subprocess.PIPE, 650 | text=True, 651 | ) 652 | 653 | lines = result.stdout.splitlines() 654 | line_change_regex = re.compile(r"^@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@") 655 | 656 | modified_lines = [] 657 | 658 | for line in lines: 659 | match = line_change_regex.match(line) 660 | if match: 661 | start_line = int(match.group(1)) 662 | num_lines = int(match.group(2) or "1") 663 | 664 | # Collect all affected line numbers 665 | for i in range(num_lines): 666 | modified_lines.append(start_line + i) 667 | 668 | return modified_lines 669 | 670 | 671 | class ParentNodeVisitor(ast.NodeVisitor): 672 | """ 673 | Custom AST node visitor that tracks parent-child relationships. 674 | """ 675 | 676 | def __init__(self): 677 | self.parent_map = {} 678 | 679 | def visit(self, node): 680 | for child in ast.iter_child_nodes(node): 681 | self.parent_map[child] = node 682 | super().visit(node) 683 | 684 | 685 | @runtime_checkable 686 | class ASTNodeWithLines(Protocol): 687 | lineno: int 688 | end_lineno: int | None 689 | 690 | 691 | def get_node_line_range(node: ASTNodeWithLines) -> tuple[int, int]: 692 | """ 693 | Get the line range (start_line, end_line) for an AST node. 694 | 695 | Args: 696 | node: The AST node to get the line range for 697 | 698 | Returns: 699 | A tuple containing the start and end line numbers 700 | """ 701 | start_line = node.lineno 702 | end_line = getattr(node, "end_lineno", start_line) 703 | return start_line, end_line 704 | 705 | 706 | def is_node_in_lines(node: ast.AST, changed_lines: list[int]) -> bool: 707 | """ 708 | Check if an AST node has any lines that were changed. 709 | 710 | Args: 711 | node: The AST node to check 712 | changed_lines: List of line numbers that were changed 713 | 714 | Returns: 715 | True if any line in the node was changed, False otherwise 716 | """ 717 | if isinstance(node, ASTNodeWithLines): 718 | start_line, end_line = get_node_line_range(node) 719 | return any(start_line <= line <= end_line for line in changed_lines) 720 | return False 721 | 722 | 723 | def get_parent_class( 724 | node: ast.AST, parent_map: dict[ast.AST, ast.AST] 725 | ) -> ast.ClassDef | None: 726 | """ 727 | Get the parent class of a node if it exists. 728 | 729 | Args: 730 | node: The AST node to check 731 | parent_map: Dictionary mapping nodes to their parents 732 | 733 | Returns: 734 | The parent ClassDef node if the node is a method, None otherwise 735 | """ 736 | parent = parent_map.get(node) 737 | if parent and isinstance(parent, ast.ClassDef): 738 | return parent 739 | return None 740 | 741 | 742 | def get_changed_entities(file_path: str, git_base: str = "HEAD") -> ChangedEntities: 743 | """ 744 | Get a dictionary of changed entities (functions, methods, classes) in a file. 745 | 746 | Args: 747 | file_path: Path to the Python file 748 | git_base: Git reference to compare against (default: HEAD) 749 | 750 | Returns: 751 | ChangedEntities containing sets of changed entity names 752 | """ 753 | changed_lines = get_changed_lines(file_path, git_base) 754 | 755 | if not changed_lines: 756 | return ChangedEntities() 757 | 758 | source = read_source(file_path) 759 | tree = ast.parse(source) 760 | 761 | visitor = ParentNodeVisitor() 762 | visitor.visit(tree) 763 | parent_map = visitor.parent_map 764 | 765 | changed_functions = set() 766 | changed_classes = set() 767 | changed_methods = set() 768 | 769 | classes_with_changed_methods = set() 770 | 771 | for node in ast.walk(tree): 772 | if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): 773 | if is_node_in_lines(node, changed_lines): 774 | parent_class = get_parent_class(node, parent_map) 775 | 776 | if parent_class: 777 | method_name = f"{parent_class.name}.{node.name}" 778 | changed_methods.add(method_name) 779 | classes_with_changed_methods.add(parent_class.name) 780 | else: 781 | changed_functions.add(node.name) 782 | 783 | elif isinstance(node, ast.ClassDef): 784 | if is_node_in_lines(node, changed_lines): 785 | changed_classes.add(node.name) 786 | 787 | changed_classes.update(classes_with_changed_methods) 788 | 789 | return ChangedEntities( 790 | functions=changed_functions, 791 | classes=changed_classes, 792 | methods=changed_methods, 793 | ) 794 | 795 | 796 | @llm.hookimpl 797 | def register_commands(cli): 798 | @cli.command() 799 | @click.argument("file_path") 800 | @click.option("model_id", "-m", "--model", help="Model to use") 801 | @click.option( 802 | "-o", 803 | "--output", 804 | help="Only show the modified code, without modifying the file", 805 | is_flag=True, 806 | ) 807 | @click.option( 808 | "-v", "--verbose", help="Verbose output of prompt and response", is_flag=True 809 | ) 810 | @click.option( 811 | "--git", 812 | help="Only update docstrings for functions and classes that have been changed since the last commit", 813 | is_flag=True, 814 | ) 815 | @click.option( 816 | "--git-base", 817 | help="Git reference to compare against (default: HEAD)", 818 | default="HEAD", 819 | ) 820 | @click.option( 821 | "--only-missing", 822 | help="Only add docstrings to entities that don't have them", 823 | is_flag=True, 824 | ) 825 | def docsmith(file_path, model_id, output, verbose, git, git_base, only_missing): 826 | """Generate and write docstrings to a Python file. 827 | 828 | Example usage: 829 | 830 | llm docsmith ./scripts/main.py 831 | llm docsmith ./scripts/main.py --git 832 | llm docsmith ./scripts/main.py --git --git-base HEAD~1 833 | llm docsmith ./scripts/main.py --only-missing 834 | """ 835 | source = read_source(file_path) 836 | docstring_generator = partial( 837 | llm_docstring_generator, model_id=model_id, verbose=verbose 838 | ) 839 | 840 | changed_entities = None 841 | if git: 842 | changed_entities = get_changed_entities(file_path, git_base) 843 | if verbose: 844 | click.echo(f"Changed functions: {changed_entities.functions}") 845 | click.echo(f"Changed classes: {changed_entities.classes}") 846 | click.echo(f"Changed methods: {changed_entities.methods}") 847 | 848 | modified_source = modify_docstring( 849 | source, docstring_generator, changed_entities, only_missing 850 | ) 851 | 852 | if output: 853 | click.echo(modified_source) 854 | return 855 | 856 | with open(file_path, "w", encoding="utf-8") as f: 857 | f.write(modified_source) 858 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "llm-docsmith" 3 | version = "0.3.1" 4 | description = "Generate Python docstrings automatically with LLM and syntax trees." 5 | readme = "README.md" 6 | authors = [{name = "Matheus Pedroni"}] 7 | license-files = ["LICENSE"] 8 | classifiers = [ 9 | "Programming Language :: Python :: 3", 10 | "Operating System :: OS Independent", 11 | ] 12 | dependencies = [ 13 | "libcst>=1.7.0", 14 | "llm>=0.23", 15 | "pydantic>=2.10.6", 16 | ] 17 | requires-python = ">=3.12" 18 | 19 | 20 | [build-system] 21 | requires = ["setuptools"] 22 | build-backend = "setuptools.build_meta" 23 | 24 | [dependency-groups] 25 | dev = [ 26 | "pytest>=8.3.5", 27 | ] 28 | 29 | [project.urls] 30 | Homepage = "https://github.com/mathpn/llm-docsmith" 31 | Changelog = "https://github.com/mathpn/llm-docsmith/releases" 32 | Issues = "https://github.com/mathpn/llm-docsmith/issues" 33 | 34 | [project.entry-points.llm] 35 | docsmith = "docsmith" 36 | 37 | [project.optional-dependencies] 38 | test = ["pytest"] 39 | -------------------------------------------------------------------------------- /tests/test_docsmith.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | from unittest.mock import Mock, patch 3 | 4 | import libcst as cst 5 | import pytest 6 | 7 | from docsmith import ( 8 | Argument, 9 | ChangedEntities, 10 | Docstring, 11 | DocstringTransformer, 12 | Documentation, 13 | Return, 14 | docstring_to_str, 15 | extract_signatures, 16 | find_docstring_by_name, 17 | get_changed_entities, 18 | get_changed_lines, 19 | get_context, 20 | has_docstring, 21 | llm_docstring_generator, 22 | modify_docstring, 23 | wrap_text, 24 | ) 25 | 26 | 27 | @pytest.fixture 28 | def sample_python_code(): 29 | return textwrap.dedent( 30 | """ 31 | def greet(name: str, times: int = 1) -> str: 32 | return "Hello " * times + name 33 | 34 | 35 | class Calculator: 36 | def add(self, a: int, b: int) -> int: 37 | return a + b 38 | 39 | def subtract(self, a: int, b: int) -> int: 40 | return a - b 41 | """ 42 | ).strip() 43 | 44 | 45 | @pytest.fixture 46 | def sample_docstring(): 47 | return Docstring( 48 | node_type="function", 49 | name="greet", 50 | docstring="Greets someone multiple times.", 51 | args=[ 52 | Argument(name="name", description="The name to greet", annotation="str"), 53 | Argument( 54 | name="times", 55 | description="Number of times to repeat greeting", 56 | annotation="int", 57 | default="1", 58 | ), 59 | ], 60 | ret=Return(description="The greeting message", annotation="str"), 61 | ) 62 | 63 | 64 | def test_wrap_text_basic(): 65 | text = "This is a long text that should be wrapped at a specific width to ensure readability." 66 | wrapped = wrap_text(text, max_width=20) 67 | lines = wrapped.split("\n") 68 | assert all(len(line) <= 20 for line in lines) 69 | assert wrapped.replace("\n", " ") == text 70 | 71 | 72 | def test_wrap_text_with_indentation(): 73 | text = "This is an indented text." 74 | indent = " " 75 | wrapped = wrap_text(text, indent=indent, max_width=20) 76 | lines = wrapped.split("\n") 77 | assert all([line.startswith(indent) for line in lines[1:] if line]) 78 | assert all(len(line) <= 20 for line in lines) 79 | 80 | 81 | def test_wrap_text_with_newlines(): 82 | text = "First paragraph.\n\nSecond paragraph." 83 | wrapped = wrap_text(text, max_width=20) 84 | assert len(wrapped.split("\n\n")) == 2 85 | 86 | 87 | def test_docstring_to_str_basic(sample_docstring): 88 | result = docstring_to_str(sample_docstring) 89 | assert "Greets someone multiple times." in result 90 | assert "Parameters:" in result 91 | assert "name (str):" in result 92 | assert "times (int):" in result 93 | assert "(default 1)" in result 94 | assert "Returns:" in result 95 | assert "str:" in result 96 | 97 | 98 | def test_docstring_to_str_no_args(): 99 | docstring = Docstring( 100 | node_type="function", 101 | name="simple_func", 102 | docstring="A simple function.", 103 | args=None, 104 | ret=None, 105 | ) 106 | result = docstring_to_str(docstring) 107 | assert "Parameters:" not in result 108 | assert "Returns:" not in result 109 | assert result.strip() == "A simple function." 110 | 111 | 112 | def test_docstring_to_str_with_long_descriptions(): 113 | docstring = Docstring( 114 | node_type="function", 115 | name="func", 116 | docstring="A function with a very long description that should be wrapped properly to maintain readability and formatting.", 117 | args=[ 118 | Argument( 119 | name="param", 120 | description="A parameter with a very long description that should also be wrapped properly to maintain readability.", 121 | annotation="str", 122 | ) 123 | ], 124 | ret=Return( 125 | description="A return value with a very long description that should be wrapped as well.", 126 | annotation="str", 127 | ), 128 | ) 129 | result = docstring_to_str(docstring) 130 | assert all(len(line) <= 88 for line in result.split("\n")) 131 | 132 | 133 | def test_find_docstring_by_name_basic(sample_docstring): 134 | doc = Documentation(entries=[sample_docstring]) 135 | found = find_docstring_by_name(doc, "greet") 136 | assert found == sample_docstring 137 | 138 | 139 | def test_find_docstring_by_name_not_found(sample_docstring): 140 | doc = Documentation(entries=[sample_docstring]) 141 | found = find_docstring_by_name(doc, "nonexistent") 142 | assert found is None 143 | 144 | 145 | def test_find_docstring_by_name_multiple_entries(): 146 | docstrings = [ 147 | Docstring(node_type="function", name="func1", docstring="First function"), 148 | Docstring(node_type="function", name="func2", docstring="Second function"), 149 | Docstring( 150 | node_type="function", name="func1", docstring="Another first function" 151 | ), 152 | ] 153 | doc = Documentation(entries=docstrings) 154 | found = find_docstring_by_name(doc, "func1") 155 | assert found == docstrings[0] 156 | 157 | 158 | def test_extract_signatures_basic(sample_python_code): 159 | import libcst as cst 160 | 161 | module = cst.parse_module(sample_python_code) 162 | doc = extract_signatures(module, module) 163 | 164 | assert len(doc.entries) == 4 165 | 166 | # Check class 167 | class_entry = next(e for e in doc.entries if e.name == "Calculator") 168 | assert class_entry.node_type == "class" 169 | assert class_entry.docstring == "" 170 | 171 | # Check standalone function 172 | greet_entry = next(e for e in doc.entries if e.name == "greet") 173 | assert greet_entry.node_type == "function" 174 | assert greet_entry.args is not None 175 | assert len(greet_entry.args) == 2 176 | assert greet_entry.args[0].name == "name" 177 | assert greet_entry.args[0].annotation == "str" 178 | assert greet_entry.args[1].name == "times" 179 | assert greet_entry.args[1].annotation == "int" 180 | assert greet_entry.args[1].default == "1" 181 | assert greet_entry.ret is not None 182 | assert greet_entry.ret.annotation == "str" 183 | 184 | 185 | def test_extract_signatures_incomplete_type_hints(): 186 | import libcst as cst 187 | 188 | module = cst.parse_module( 189 | textwrap.dedent( 190 | """ 191 | def foo(a: int, b, c: str="foo", d=3, e=None) -> bool: 192 | return False 193 | """ 194 | ) 195 | ) 196 | doc = extract_signatures(module, module) 197 | foo_entry = next(e for e in doc.entries if e.name == "foo") 198 | assert foo_entry.args is not None 199 | assert foo_entry.args[0].name == "a" 200 | assert foo_entry.args[0].annotation == "int" 201 | assert foo_entry.args[1].name == "b" 202 | assert foo_entry.args[1].annotation is None 203 | assert foo_entry.args[2].name == "c" 204 | assert foo_entry.args[2].annotation == "str" 205 | assert foo_entry.args[2].default == "'foo'" 206 | assert foo_entry.args[3].name == "d" 207 | assert foo_entry.args[3].annotation is None 208 | assert foo_entry.args[3].default == "3" 209 | assert foo_entry.args[4].name == "e" 210 | assert foo_entry.args[4].annotation is None 211 | assert foo_entry.args[4].default == "None" 212 | 213 | 214 | def test_extract_signatures_with_complex_types(): 215 | code = textwrap.dedent( 216 | """ 217 | from typing import List, Dict, Optional 218 | 219 | def complex_func( 220 | items: List[int], 221 | mapping: Dict[str, Optional[int]], 222 | *args: str, 223 | **kwargs: int 224 | ) -> Optional[List[Dict[str, int]]]: 225 | return None 226 | """ 227 | ).strip() 228 | import libcst as cst 229 | 230 | module = cst.parse_module(code) 231 | doc = extract_signatures(module, module) 232 | print(doc) 233 | 234 | func = doc.entries[0] 235 | assert func.name == "complex_func" 236 | assert func.args is not None 237 | assert len(func.args) == 4 238 | assert func.args[0].annotation == "List[int]" 239 | assert func.args[1].annotation == "Dict[str, Optional[int]]" 240 | assert func.args[2].name == "*args" 241 | assert func.args[2].annotation == "str" 242 | assert func.args[3].name == "**kwargs" 243 | assert func.args[3].annotation == "int" 244 | assert func.ret is not None 245 | assert func.ret.annotation == "Optional[List[Dict[str, int]]]" 246 | 247 | 248 | def test_extract_signatures_private_methods(): 249 | code = textwrap.dedent( 250 | """ 251 | class Test: 252 | def _private(self): 253 | pass 254 | 255 | def __dunder__(self): 256 | pass 257 | 258 | def public(self): 259 | pass 260 | """ 261 | ).strip() 262 | import libcst as cst 263 | 264 | module = cst.parse_module(code) 265 | doc = extract_signatures(module, module) 266 | 267 | # Should only include public methods 268 | assert len(doc.entries) == 2 # Test class + public method 269 | assert all(not entry.name.startswith("_") for entry in doc.entries) 270 | 271 | 272 | def test_has_docstring_regular_function(): 273 | """Test a regular function with a docstring.""" 274 | code = ''' 275 | def foo(): 276 | """This is a docstring.""" 277 | pass 278 | ''' 279 | node = cst.parse_module(code).body[0] 280 | assert has_docstring(node) is True 281 | 282 | 283 | def test_has_docstring_no_docstring(): 284 | """Test a function without a docstring.""" 285 | code = """ 286 | def foo(): 287 | pass 288 | """ 289 | node = cst.parse_module(code).body[0] 290 | assert has_docstring(node) is False 291 | 292 | 293 | def test_has_docstring_empty_function(): 294 | """Test an empty function.""" 295 | code = """ 296 | def foo(): 297 | pass 298 | """ 299 | node = cst.parse_module(code).body[0] 300 | assert has_docstring(node) is False 301 | 302 | 303 | def test_has_docstring_oneliner(): 304 | """Test a one-liner function.""" 305 | code = "def foo(): return 42" 306 | node = cst.parse_module(code).body[0] 307 | assert has_docstring(node) is False 308 | 309 | 310 | def test_has_docstring_class(): 311 | """Test a class with a docstring.""" 312 | code = ''' 313 | class Foo: 314 | """This is a class docstring.""" 315 | pass 316 | ''' 317 | node = cst.parse_module(code).body[0] 318 | assert has_docstring(node) is True 319 | 320 | 321 | def test_has_docstring_class_no_docstring(): 322 | """Test a class without a docstring.""" 323 | code = """ 324 | class Foo: 325 | pass 326 | """ 327 | node = cst.parse_module(code).body[0] 328 | assert has_docstring(node) is False 329 | 330 | 331 | def test_has_docstring_concatenated_string(): 332 | """Test a function with a concatenated string docstring.""" 333 | code = ''' 334 | def foo(): 335 | """This is a """ """concatenated docstring.""" 336 | pass 337 | ''' 338 | node = cst.parse_module(code).body[0] 339 | assert has_docstring(node) is True 340 | 341 | 342 | def test_has_docstring_first_statement_not_string(): 343 | """Test a function where first statement is not a string.""" 344 | code = ''' 345 | def foo(): 346 | x = 42 347 | """This is not a docstring.""" 348 | pass 349 | ''' 350 | node = cst.parse_module(code).body[0] 351 | assert has_docstring(node) is False 352 | 353 | 354 | def test_has_docstring_empty_class(): 355 | """Test an empty class.""" 356 | code = """ 357 | class Foo: 358 | pass 359 | """ 360 | node = cst.parse_module(code).body[0] 361 | assert has_docstring(node) is False 362 | 363 | 364 | def test_has_docstring_oneliner_class(): 365 | """Test a one-liner class.""" 366 | code = "class Foo: pass" 367 | node = cst.parse_module(code).body[0] 368 | assert has_docstring(node) is False 369 | 370 | 371 | def test_docstring_transformer_without_changed_entities(): 372 | source = textwrap.dedent(""" 373 | def greet(): 374 | return "hello" 375 | 376 | class MyClass: 377 | def my_method(self): 378 | pass 379 | """) 380 | 381 | def mock_generator(input_code, context, template): 382 | return Documentation( 383 | entries=[ 384 | Docstring( 385 | node_type="function", name="greet", docstring="A test function." 386 | ), 387 | Docstring(node_type="class", name="MyClass", docstring="A test class."), 388 | Docstring( 389 | node_type="function", name="my_method", docstring="A test method." 390 | ), 391 | ] 392 | ) 393 | 394 | module = cst.parse_module(source) 395 | transformer = DocstringTransformer(mock_generator, module) 396 | modified = module.visit(transformer) 397 | 398 | assert "A test function." in modified.code 399 | assert "A test class." in modified.code 400 | assert "A test method." in modified.code 401 | 402 | 403 | def test_docstring_transformer_with_changed_entities(): 404 | source = textwrap.dedent(""" 405 | def changed_function(): 406 | return True 407 | 408 | def unchanged_function(): 409 | return False 410 | 411 | class ChangedClass: 412 | def changed_method(self): 413 | pass 414 | 415 | def unchanged_method(self): 416 | pass 417 | 418 | class UnchangedClass: 419 | def another_method(self): 420 | pass 421 | """) 422 | 423 | def mock_generator(input_code, context, template): 424 | return Documentation( 425 | entries=[ 426 | Docstring( 427 | node_type="function", 428 | name="changed_function", 429 | docstring="A changed function.", 430 | ), 431 | Docstring( 432 | node_type="function", 433 | name="unchanged_function", 434 | docstring="An unchanged function.", 435 | ), 436 | Docstring( 437 | node_type="class", name="ChangedClass", docstring="A changed class." 438 | ), 439 | Docstring( 440 | node_type="function", 441 | name="changed_method", 442 | docstring="A changed method.", 443 | ), 444 | Docstring( 445 | node_type="function", 446 | name="unchanged_method", 447 | docstring="An unchanged method.", 448 | ), 449 | ] 450 | ) 451 | 452 | changed_entities = ChangedEntities( 453 | functions={"changed_function"}, 454 | classes={"ChangedClass"}, 455 | methods={"ChangedClass.changed_method"}, 456 | ) 457 | 458 | module = cst.parse_module(source) 459 | transformer = DocstringTransformer(mock_generator, module, changed_entities) 460 | modified = module.code_for_node(module.visit(transformer)) 461 | 462 | assert "A changed function." in modified 463 | assert "A changed class." in modified 464 | assert "A changed method." in modified 465 | assert "An unchanged function." not in modified 466 | assert "An unchanged method." not in modified 467 | 468 | 469 | def test_docstring_transformer_nested_classes(): 470 | source = textwrap.dedent(""" 471 | class OuterClass: 472 | class InnerClass: 473 | def inner_method(self): 474 | pass 475 | 476 | def outer_method(self): 477 | pass 478 | """) 479 | 480 | def mock_generator(input_code, context, template): 481 | return Documentation( 482 | entries=[ 483 | Docstring( 484 | node_type="class", name="OuterClass", docstring="Outer class." 485 | ), 486 | Docstring( 487 | node_type="class", name="InnerClass", docstring="Inner class." 488 | ), 489 | Docstring( 490 | node_type="function", name="inner_method", docstring="Inner method." 491 | ), 492 | Docstring( 493 | node_type="function", name="outer_method", docstring="Outer method." 494 | ), 495 | ] 496 | ) 497 | 498 | module = cst.parse_module(source) 499 | transformer = DocstringTransformer(mock_generator, module) 500 | modified = module.visit(transformer) 501 | 502 | assert "Outer class." in modified.code 503 | assert "Inner class." in modified.code 504 | assert "Inner method." in modified.code 505 | assert "Outer method." in modified.code 506 | 507 | changed_entities = ChangedEntities( 508 | classes={"OuterClass", "InnerClass"}, methods={"OuterClass.outer_method"} 509 | ) 510 | 511 | transformer = DocstringTransformer(mock_generator, module, changed_entities) 512 | modified = module.visit(transformer) 513 | 514 | assert "Outer class." in modified.code 515 | assert "Inner class." in modified.code 516 | assert "Outer method." in modified.code 517 | assert "Inner method." not in modified.code 518 | 519 | 520 | def test_only_missing_option(): 521 | """Test that only_missing option only adds docstrings to entities without them.""" 522 | source = textwrap.dedent(''' 523 | class WithDocstring: 524 | """This is an existing docstring.""" 525 | def method_with_doc(self): 526 | """This method already has a doc.""" 527 | pass 528 | 529 | def method_without_doc(self): 530 | pass 531 | 532 | class WithoutDocstring: 533 | def method_without_doc(self): 534 | pass 535 | 536 | def func_with_doc(): 537 | """This function has a doc.""" 538 | return True 539 | 540 | def func_without_doc(): 541 | return False 542 | ''') 543 | 544 | def mock_docstring_gen(input_code, context, template): 545 | return Documentation( 546 | entries=[ 547 | Docstring( 548 | node_type="class", 549 | name="WithDocstring", 550 | docstring="This should not replace existing docstring", 551 | ), 552 | Docstring( 553 | node_type="function", 554 | name="method_with_doc", 555 | docstring="This should not replace existing method docstring", 556 | ), 557 | Docstring( 558 | node_type="function", 559 | name="method_without_doc", 560 | docstring="New method docstring", 561 | ), 562 | Docstring( 563 | node_type="class", 564 | name="WithoutDocstring", 565 | docstring="New class docstring", 566 | ), 567 | Docstring( 568 | node_type="function", 569 | name="func_with_doc", 570 | docstring="This should not replace existing function docstring", 571 | ), 572 | Docstring( 573 | node_type="function", 574 | name="func_without_doc", 575 | docstring="New function docstring", 576 | ), 577 | ] 578 | ) 579 | 580 | modified_source = modify_docstring(source, mock_docstring_gen, only_missing=True) 581 | 582 | assert "This is an existing docstring." in modified_source 583 | assert "This method already has a doc." in modified_source 584 | assert "This function has a doc." in modified_source 585 | 586 | assert "New method docstring" in modified_source 587 | assert "New class docstring" in modified_source 588 | assert "New function docstring" in modified_source 589 | 590 | assert "This should not replace existing" not in modified_source 591 | 592 | 593 | def test_only_missing_with_git_changes(): 594 | """Test that only_missing works correctly with git changes tracking.""" 595 | source = textwrap.dedent(''' 596 | class ChangedClass: 597 | """Existing docstring.""" 598 | def changed_method_with_doc(self): 599 | """Existing method doc.""" 600 | pass 601 | 602 | def changed_method_without_doc(self): 603 | pass 604 | 605 | class UnchangedClass: 606 | """Existing docstring.""" 607 | def unchanged_method(self): 608 | """Existing doc.""" 609 | pass 610 | ''') 611 | 612 | def mock_docstring_gen(input_code, context, template): 613 | return Documentation( 614 | entries=[ 615 | Docstring( 616 | node_type="class", 617 | name="ChangedClass", 618 | docstring="This should not replace existing class docstring", 619 | ), 620 | Docstring( 621 | node_type="function", 622 | name="changed_method_with_doc", 623 | docstring="This should not replace existing method docstring", 624 | ), 625 | Docstring( 626 | node_type="function", 627 | name="changed_method_without_doc", 628 | docstring="New method docstring", 629 | ), 630 | Docstring( 631 | node_type="class", 632 | name="UnchangedClass", 633 | docstring="This should be ignored - class unchanged", 634 | ), 635 | Docstring( 636 | node_type="function", 637 | name="unchanged_method", 638 | docstring="This should be ignored - method unchanged", 639 | ), 640 | ] 641 | ) 642 | 643 | changed_entities = ChangedEntities( 644 | classes={"ChangedClass"}, 645 | methods={ 646 | "ChangedClass.changed_method_with_doc", 647 | "ChangedClass.changed_method_without_doc", 648 | }, 649 | ) 650 | 651 | modified_source = modify_docstring( 652 | source, mock_docstring_gen, changed_entities=changed_entities, only_missing=True 653 | ) 654 | 655 | assert "Existing docstring." in modified_source 656 | assert "Existing doc." in modified_source 657 | assert "Existing method doc." in modified_source 658 | 659 | assert "New method docstring" in modified_source 660 | 661 | assert "This should not replace existing" not in modified_source 662 | assert "This should be ignored" not in modified_source 663 | 664 | 665 | @patch("subprocess.run") 666 | def test_get_changed_lines_basic(mock_run): 667 | mock_run.return_value = Mock( 668 | stdout=textwrap.dedent( 669 | """ 670 | @@ -1,3 +1,4 @@ 671 | +def new_function(): 672 | def old_function(): 673 | pass 674 | - return None 675 | + return True 676 | """ 677 | ).strip() 678 | ) 679 | 680 | lines = get_changed_lines("test.py") 681 | assert lines == [1, 2, 3, 4] 682 | mock_run.assert_called_once() 683 | 684 | 685 | @patch("subprocess.run") 686 | def test_get_changed_lines_with_git_base(mock_run): 687 | mock_run.return_value = Mock(stdout="@@ -1 +1 @@\n-old\n+new\n") 688 | get_changed_lines("test.py", git_base="main") 689 | assert mock_run.call_args[0][0][5] == "main" 690 | 691 | 692 | @patch("subprocess.run") 693 | def test_get_changed_lines_no_changes(mock_run): 694 | mock_run.return_value = Mock(stdout="") 695 | lines = get_changed_lines("test.py") 696 | assert lines == [] 697 | 698 | 699 | @patch("subprocess.run") 700 | def test_get_changed_entities_basic(mock_run, sample_python_code): 701 | mock_run.return_value = Mock( 702 | stdout=textwrap.dedent( 703 | """ 704 | @@ -7,2 +7 @@ class Calculator: 705 | - add = a + b 706 | - return add 707 | + return a + b 708 | """ 709 | ).strip() 710 | ) 711 | 712 | with patch("docsmith.read_source", return_value=sample_python_code): 713 | entities = get_changed_entities("test.py") 714 | 715 | assert "Calculator" in entities.classes 716 | assert "Calculator.add" in entities.methods 717 | assert "Calculator.subtract" not in entities.methods 718 | assert not entities.functions 719 | 720 | 721 | @patch("subprocess.run") 722 | def test_get_changed_entities_multiple_changes(mock_run): 723 | code = textwrap.dedent( 724 | """ 725 | def func1(): 726 | print("changed") 727 | return True 728 | 729 | 730 | class TestClass: 731 | def method1(self): 732 | print("changed") 733 | 734 | def method2(self): 735 | pass 736 | 737 | 738 | def func2(): 739 | pass 740 | """ 741 | ).strip() 742 | 743 | mock_run.return_value = Mock( 744 | stdout=textwrap.dedent( 745 | """ 746 | @@ -2 +2,2 @@ def func1(): 747 | - pass 748 | + print("changed") 749 | + return True 750 | @@ -7 +8 @@ class TestClass: 751 | - pass 752 | + print("changed") 753 | """ 754 | ).strip() 755 | ) 756 | 757 | with patch("docsmith.read_source", return_value=code): 758 | entities = get_changed_entities("test.py") 759 | 760 | assert "func1" in entities.functions 761 | assert "TestClass" in entities.classes 762 | assert "TestClass.method1" in entities.methods 763 | assert "func2" not in entities.functions 764 | assert "TestClass.method2" not in entities.methods 765 | 766 | 767 | @patch("llm.get_model") 768 | def test_llm_docstring_generator(mock_get_model): 769 | mock_model = Mock() 770 | mock_model.prompt.return_value = Mock( 771 | text=lambda: '{"entries": [{"node_type": "function", "name": "test", "docstring": "Test function"}]}' 772 | ) 773 | mock_get_model.return_value = mock_model 774 | 775 | result = llm_docstring_generator( 776 | "def test(): pass", 777 | "", 778 | Documentation(entries=[]), 779 | model_id="test-model", 780 | verbose=False, 781 | ) 782 | 783 | assert len(result.entries) == 1 784 | assert result.entries[0].name == "test" 785 | assert result.entries[0].docstring == "Test function" 786 | 787 | 788 | def test_modify_docstring_basic(): 789 | code = "def test(): pass" 790 | mock_generator = Mock( 791 | return_value=Documentation( 792 | entries=[ 793 | Docstring( 794 | node_type="function", 795 | name="test", 796 | docstring="Test function", 797 | args=None, 798 | ret=None, 799 | ) 800 | ] 801 | ) 802 | ) 803 | 804 | result = modify_docstring(code, mock_generator) 805 | print(result) 806 | assert "Test function" in result 807 | assert "def test():" in result 808 | 809 | 810 | def test_modify_docstring_with_existing_docstring(): 811 | code = textwrap.dedent( 812 | ''' 813 | def test(): 814 | """Old docstring.""" 815 | pass 816 | ''' 817 | ).strip() 818 | mock_generator = Mock( 819 | return_value=Documentation( 820 | entries=[ 821 | Docstring( 822 | node_type="function", 823 | name="test", 824 | docstring="New docstring", 825 | args=None, 826 | ret=None, 827 | ) 828 | ] 829 | ) 830 | ) 831 | 832 | result = modify_docstring(code, mock_generator) 833 | assert "New docstring" in result 834 | assert "Old docstring." not in result 835 | 836 | 837 | def test_modify_docstring_with_changed_entities(): 838 | code = textwrap.dedent( 839 | """ 840 | def func1(): 841 | pass 842 | 843 | def func2(): 844 | pass 845 | """ 846 | ).strip() 847 | mock_generator = Mock( 848 | return_value=Documentation( 849 | entries=[ 850 | Docstring( 851 | node_type="function", 852 | name="func1", 853 | docstring="Updated func1", 854 | args=None, 855 | ret=None, 856 | ), 857 | Docstring( 858 | node_type="function", 859 | name="func2", 860 | docstring="Updated func2", 861 | args=None, 862 | ret=None, 863 | ), 864 | ] 865 | ) 866 | ) 867 | 868 | changed_entities = ChangedEntities(functions={"func1"}) 869 | result = modify_docstring(code, mock_generator, changed_entities) 870 | 871 | assert "Updated func1" in result 872 | assert "Updated func2" not in result 873 | 874 | 875 | @pytest.mark.parametrize( 876 | "code,expected_context", 877 | [ 878 | ( 879 | textwrap.dedent( 880 | """ 881 | def helper(x: int) -> str: 882 | return str(x) 883 | 884 | def main(y: int) -> str: 885 | return helper(y) 886 | """ 887 | ).strip(), 888 | "def helper(x: int) -> str:\n return str(x)", 889 | ), 890 | ( 891 | textwrap.dedent( 892 | """ 893 | def unused() -> None: 894 | pass 895 | 896 | def main() -> None: 897 | pass 898 | """ 899 | ).strip(), 900 | "", 901 | ), 902 | ], 903 | ) 904 | def test_get_context_function_references(code, expected_context): 905 | import libcst as cst 906 | 907 | module = cst.parse_module(code) 908 | main_func = next( 909 | node 910 | for node in module.body 911 | if isinstance(node, cst.FunctionDef) and node.name.value == "main" 912 | ) 913 | context = get_context(module, main_func) 914 | assert context.strip() == expected_context.strip() 915 | --------------------------------------------------------------------------------