├── .gitignore ├── LICENSE ├── README.md ├── motleycoder ├── __init__.py ├── codemap │ ├── __init__.py │ ├── builtins_by_lang.json │ ├── file_group.py │ ├── graph.py │ ├── map_args.py │ ├── parse.py │ ├── rank.py │ ├── render.py │ ├── repomap.py │ └── tag.py ├── linter.py ├── prompts.py ├── queries │ ├── README.md │ ├── tree-sitter-c-tags.scm │ ├── tree-sitter-c_sharp-tags.scm │ ├── tree-sitter-cpp-tags.scm │ ├── tree-sitter-elisp-tags.scm │ ├── tree-sitter-elixir-tags.scm │ ├── tree-sitter-elm-tags.scm │ ├── tree-sitter-go-tags.scm │ ├── tree-sitter-java-tags.scm │ ├── tree-sitter-javascript-tags.scm │ ├── tree-sitter-ocaml-tags.scm │ ├── tree-sitter-php-tags.scm │ ├── tree-sitter-python-tags.scm │ ├── tree-sitter-ql-tags.scm │ ├── tree-sitter-r-tags.scm │ ├── tree-sitter-ruby-tags.scm │ ├── tree-sitter-rust-tags.scm │ └── tree-sitter-typescript-tags.scm ├── repo.py ├── tools │ ├── __init__.py │ ├── add_files_tool.py │ ├── file_edit_tool.py │ ├── get_full_text_tool.py │ ├── get_modifiable_files_tool.py │ ├── inspect_entity_tool.py │ └── return_to_user_tool.py └── user_interface.py ├── motleycoder_demo.ipynb ├── poetry.lock └── pyproject.toml /.gitignore: -------------------------------------------------------------------------------- 1 | # Editors 2 | .vscode/ 3 | .idea/ 4 | 5 | # Vagrant 6 | .vagrant/ 7 | 8 | # Mac/OSX 9 | .DS_Store 10 | 11 | # Windows 12 | Thumbs.db 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # IPython 69 | profile_default/ 70 | ipython_config.py 71 | 72 | # Environments 73 | .env 74 | .venv 75 | env/ 76 | venv/ 77 | ENV/ 78 | env.bak/ 79 | venv.bak/ 80 | .python-version 81 | 82 | # mypy 83 | .mypy_cache/ 84 | .dmypy.json 85 | dmypy.json 86 | *.pyc 87 | 88 | scripts/ 89 | 90 | # docs-related stuff 91 | docs/source/_autosummary 92 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ShoggothAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## motleycoder: a code editor for AI agents 2 | 3 | This is a collection of tools and utilities that help agents write code. 4 | 5 | MotleyCoder uses a combination of static code analysis and retrieval techniques for building a map of the codebase and navigating it. 6 | 7 | MotleyCoder consists of the following main elements: 8 | - `RepoMap`: provides the agent with an initial overview of the parts of the codebase relevant to the current task. This uses tree-sitter to build formal syntax trees of the code, and then builds a graph of relationships between entities in the code. 9 | - `InspectEntityTool`: a tool given to the agent so it can inspect and navigate the codebase, read the code of specific entities or files, and list directories. This uses the graph built by RepoMap to enrich the information about each entity with a summary of the entities it references. This way, an agent using this tool can gradually walk the dependency graph of the code. 10 | - `FileEditTool`: a tool that allows editing code in a way an LLM can comprehend. This comes with a built-in linter, so only syntactically valid edits are accepted, else the tool returns a description of the linting errors so the LLM can try again. 11 | 12 | Please check out the [demo notebook](https://github.com/ShoggothAI/motleycoder/blob/main/motleycoder_demo.ipynb) to see how it all works. 13 | 14 | MotleyCoder was originally designed for use with our [motleycrew](https://github.com/ShoggothAI/motleycrew) library, but its flexible nature allows using it in other contexts. 15 | 16 | Credits to [aider](https://github.com/paul-gauthier/aider) for the original idea of the RepoMap and of using tree-sitter for code parsing. 17 | 18 | ## Installation 19 | 20 | ``` 21 | pip install motleycoder 22 | ``` 23 | -------------------------------------------------------------------------------- /motleycoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShoggothAI/motleycoder/92ad30914c9df4806bc00cfb3ecc3d70af85dde6/motleycoder/__init__.py -------------------------------------------------------------------------------- /motleycoder/codemap/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShoggothAI/motleycoder/92ad30914c9df4806bc00cfb3ecc3d70af85dde6/motleycoder/codemap/__init__.py -------------------------------------------------------------------------------- /motleycoder/codemap/builtins_by_lang.json: -------------------------------------------------------------------------------- 1 | { 2 | "python": [ 3 | "abs", 4 | "all", 5 | "any", 6 | "ascii", 7 | "bin", 8 | "bool", 9 | "bytearray", 10 | "bytes", 11 | "callable", 12 | "chr", 13 | "classmethod", 14 | "compile", 15 | "complex", 16 | "delattr", 17 | "dict", 18 | "dir", 19 | "divmod", 20 | "enumerate", 21 | "eval", 22 | "exec", 23 | "filter", 24 | "float", 25 | "format", 26 | "frozenset", 27 | "getattr", 28 | "globals", 29 | "hasattr", 30 | "hash", 31 | "help", 32 | "hex", 33 | "id", 34 | "input", 35 | "int", 36 | "isinstance", 37 | "issubclass", 38 | "iter", 39 | "len", 40 | "list", 41 | "locals", 42 | "map", 43 | "max", 44 | "memoryview", 45 | "min", 46 | "next", 47 | "object", 48 | "oct", 49 | "open", 50 | "ord", 51 | "pow", 52 | "print", 53 | "property", 54 | "range", 55 | "repr", 56 | "reversed", 57 | "round", 58 | "set", 59 | "setattr", 60 | "slice", 61 | "sorted", 62 | "staticmethod", 63 | "str", 64 | "sum", 65 | "super", 66 | "tuple", 67 | "type", 68 | "vars", 69 | "zip", 70 | "False", 71 | "None", 72 | "True", 73 | "and", 74 | "as", 75 | "assert", 76 | "async", 77 | "await", 78 | "break", 79 | "class", 80 | "continue", 81 | "def", 82 | "del", 83 | "elif", 84 | "else", 85 | "except", 86 | "finally", 87 | "for", 88 | "from", 89 | "global", 90 | "if", 91 | "import", 92 | "in", 93 | "is", 94 | "lambda", 95 | "nonlocal", 96 | "not", 97 | "or", 98 | "pass", 99 | "raise", 100 | "return", 101 | "try", 102 | "while", 103 | "with", 104 | "yield", 105 | "__abs__", 106 | "__add__", 107 | "__and__", 108 | "__bool__", 109 | "__bytes__", 110 | "__call__", 111 | "__ceil__", 112 | "__cmp__", 113 | "__coerce__", 114 | "__complex__", 115 | "__contains__", 116 | "__del__", 117 | "__delattr__", 118 | "__delete__", 119 | "__delitem__", 120 | "__delslice__", 121 | "__dir__", 122 | "__div__", 123 | "__divmod__", 124 | "__enter__", 125 | "__eq__", 126 | "__exit__", 127 | "__float__", 128 | "__floor__", 129 | "__floordiv__", 130 | "__ge__", 131 | "__get__", 132 | "__getattr__", 133 | "__getattribute__", 134 | "__getitem__", 135 | "__getslice__", 136 | "__gt__", 137 | "__hash__", 138 | "__iadd__", 139 | "__iand__", 140 | "__ifloordiv__", 141 | "__ilshift__", 142 | "__imod__", 143 | "__import__", 144 | "__imul__", 145 | "__index__", 146 | "__init__", 147 | "__instancecheck__", 148 | "__int__", 149 | "__invert__", 150 | "__ior__", 151 | "__ipow__", 152 | "__irshift__", 153 | "__isub__", 154 | "__iter__", 155 | "__itruediv__", 156 | "__ixor__", 157 | "__le__", 158 | "__len__", 159 | "__long__", 160 | "__lshift__", 161 | "__lt__", 162 | "__mod__", 163 | "__mul__", 164 | "__ne__", 165 | "__neg__", 166 | "__new__", 167 | "__nonzero__", 168 | "__oct__", 169 | "__or__", 170 | "__pos__", 171 | "__pow__", 172 | "__radd__", 173 | "__rand__", 174 | "__rdiv__", 175 | "__rdivmod__", 176 | "__repr__", 177 | "__reversed__", 178 | "__rfloordiv__", 179 | "__rlshift__", 180 | "__rmod__", 181 | "__rmul__", 182 | "__ror__", 183 | "__round__", 184 | "__rpow__", 185 | "__rrshift__", 186 | "__rshift__", 187 | "__rsub__", 188 | "__rtruediv__", 189 | "__rxor__", 190 | "__set__", 191 | "__setattr__", 192 | "__setitem__", 193 | "__setslice__", 194 | "__str__", 195 | "__sub__", 196 | "__truediv__", 197 | "__xor__", 198 | "__subclasscheck__", 199 | "__subclasses__", 200 | "__format__", 201 | "__sizeof__", 202 | "__dir__", 203 | "__class__", 204 | "__doc__", 205 | "__enter__", 206 | "__exit__", 207 | "__annotations__", 208 | "__kwdefaults__", 209 | "__code__", 210 | "__defaults__", 211 | "__globals__", 212 | "__closure__", 213 | "__doc__", 214 | "__name__", 215 | "__qualname__", 216 | "__module__", 217 | "__defaults__", 218 | "__kwdefaults__", 219 | "__code__", 220 | "__globals__", 221 | "__closure__", 222 | "__get__", 223 | "__set__", 224 | "__delete__", 225 | "__slots__", 226 | "__weakref__", 227 | "__dict__", 228 | "__bases__", 229 | "__class__", 230 | "__mro__", 231 | "__subclasses__", 232 | "__init_subclass__", 233 | "__prepare__", 234 | "__instancecheck__", 235 | "__subclasscheck__", 236 | "__class_getitem__" 237 | ] 238 | } 239 | -------------------------------------------------------------------------------- /motleycoder/codemap/file_group.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import os.path 4 | import re 5 | from collections import defaultdict 6 | from difflib import SequenceMatcher 7 | from pathlib import Path 8 | from typing import Callable, List, Set, Optional 9 | 10 | from diskcache import Cache 11 | 12 | from motleycoder.repo import GitRepo 13 | 14 | 15 | def python_file_filter(fname: str, with_tests: bool = False) -> bool: 16 | return fname.endswith(".py") and (with_tests or not "test_" in fname) 17 | 18 | 19 | class FileGroup: 20 | """ 21 | A FileGroup is a collection of files that we are parsing and monitoring for changes. 22 | This might be a git repo or a directory. If new files appear in it, 23 | we will see that as well using the get_all_filenames method. 24 | """ 25 | 26 | CACHE_VERSION = 4 27 | TAGS_CACHE_DIR = f".aider.tags.cache.v{CACHE_VERSION}" 28 | 29 | def __init__(self, repo: GitRepo | None = None, root: str | None = None, filename_filter=None): 30 | # TODO: support other kinds of locations 31 | self.repo = repo 32 | if self.repo is None: 33 | if os.path.isdir(root): 34 | self.root = root 35 | else: 36 | raise ValueError("Must supply either a GitRepo or a valid root directory") 37 | else: 38 | self.root = self.repo.root 39 | 40 | if filename_filter is None: 41 | self.filename_filter = python_file_filter 42 | else: 43 | self.filename_filter = filename_filter 44 | 45 | self.load_tags_cache() 46 | self.warned_files = set() 47 | 48 | self.files_for_modification = set() 49 | self.edited_files = set() 50 | 51 | def abs_root_path(self, path): 52 | "Gives an abs path, which safely returns a full (not 8.3) windows path" 53 | res = Path(self.root) / path 54 | res = Path(res).resolve() 55 | return str(res) 56 | 57 | def get_all_filenames(self, with_tests: bool = False): 58 | """ 59 | Get all the filenames in the group, including new files. 60 | :return: List of unique absolute file paths 61 | """ 62 | if self.repo: 63 | files = self.repo.get_tracked_files() 64 | files = [self.abs_root_path(fname) for fname in files] 65 | files = [str(fname) for fname in files if os.path.isfile(str(fname))] 66 | 67 | else: 68 | files = [str(f) for f in Path(self.root).rglob("*") if f.is_file()] 69 | 70 | files = [ 71 | str(f).replace("\\", "/") 72 | for f in files 73 | if self.filename_filter(f, with_tests=with_tests) 74 | ] 75 | 76 | return sorted(set(files)) 77 | 78 | def validate_fnames(self, fnames: List[str], with_tests: bool = False) -> List[str]: 79 | cleaned_fnames = [] 80 | for fname in fnames: 81 | if not self.filename_filter(str(fname), with_tests=with_tests): 82 | continue 83 | if Path(fname).is_file(): 84 | cleaned_fnames.append(str(fname)) 85 | else: 86 | if fname not in self.warned_files: 87 | if Path(fname).exists(): 88 | logging.error(f"Repo-map can't include {fname}, it is not a normal file") 89 | else: 90 | logging.error( 91 | f"Repo-map can't include {fname}, it doesn't exist (anymore?)" 92 | ) 93 | 94 | self.warned_files.add(fname) 95 | 96 | return cleaned_fnames 97 | 98 | def load_tags_cache(self): 99 | path = Path(self.root) / self.TAGS_CACHE_DIR 100 | if not path.exists(): 101 | logging.warning(f"Tags cache not found, creating: {path}") 102 | self.TAGS_CACHE = Cache(str(path)) 103 | 104 | def add_for_modification(self, rel_fname): 105 | self.files_for_modification.add(self.abs_root_path(rel_fname)) 106 | 107 | def get_rel_fname(self, fname): 108 | return os.path.relpath(fname, self.root).replace("\\", "/") 109 | 110 | def save_tags_cache(self): 111 | pass 112 | 113 | def get_mtime(self, fname): 114 | try: 115 | return os.path.getmtime(fname) 116 | except FileNotFoundError: 117 | logging.error(f"File not found error: {fname}") 118 | 119 | def cached_function_call(self, fname: str, function: Callable, key: str | None = None): 120 | """ 121 | Cache the result of a function call, refresh the cache if the file has changed. 122 | :param fname: the file to monitor for changes 123 | :param function: the function to apply to the file 124 | :param key: the key to use in the cache, if None, the function name is used 125 | :return: the function's result 126 | """ 127 | # Check if the file is in the cache and if the modification time has not changed 128 | # TODO: this should be a decorator? 129 | file_mtime = self.get_mtime(fname) 130 | if file_mtime is None: 131 | return [] 132 | 133 | cache_key = fname + "::" + (key or function.__name__) 134 | if cache_key in self.TAGS_CACHE and self.TAGS_CACHE[cache_key]["mtime"] == file_mtime: 135 | return self.TAGS_CACHE[cache_key]["data"] 136 | 137 | # miss! 138 | data = function(fname) 139 | 140 | # Update the cache 141 | self.TAGS_CACHE[cache_key] = {"mtime": file_mtime, "data": data} 142 | self.save_tags_cache() 143 | return data 144 | 145 | def get_file_mentions(self, content): 146 | words = set(word for word in content.split()) 147 | 148 | # drop sentence punctuation from the end 149 | words = set(word.rstrip(",.!;:") for word in words) 150 | 151 | # strip away all kinds of quotes 152 | quotes = "".join(['"', "'", "`"]) 153 | words = set(word.strip(quotes) for word in words) 154 | 155 | all_files = self.get_all_filenames() 156 | other_files = set(all_files) - set(self.files_for_modification) 157 | addable_rel_fnames = [self.get_rel_fname(f) for f in other_files] 158 | 159 | mentioned_rel_fnames = set() 160 | fname_to_rel_fnames = {} 161 | for rel_fname in addable_rel_fnames: 162 | if rel_fname in words: 163 | mentioned_rel_fnames.add(str(rel_fname)) 164 | 165 | fname = os.path.basename(rel_fname) 166 | 167 | # Don't add basenames that could be plain words like "run" or "make" 168 | if "/" in fname or "." in fname or "_" in fname or "-" in fname: 169 | if fname not in fname_to_rel_fnames: 170 | fname_to_rel_fnames[fname] = [] 171 | fname_to_rel_fnames[fname].append(rel_fname) 172 | 173 | for fname, rel_fnames in fname_to_rel_fnames.items(): 174 | if len(rel_fnames) == 1 and fname in words: 175 | mentioned_rel_fnames.add(rel_fnames[0]) 176 | 177 | return self.clean_mentioned_filenames(mentioned_rel_fnames) 178 | 179 | def clean_mentioned_filenames(self, mentioned_filenames: Set[str]) -> Set[str]: 180 | all_files = self.get_all_filenames() 181 | clean_mentioned_filenames = [] 182 | for mentioned_name in mentioned_filenames: 183 | for name in all_files: 184 | if mentioned_name in name: 185 | clean_mentioned_filenames.append(name) 186 | break 187 | return set(clean_mentioned_filenames) 188 | 189 | def get_rel_fnames_in_directory( 190 | self, 191 | abs_dir: str, 192 | level: Optional[int] = 1, 193 | with_tests: bool = False, 194 | ) -> List[str] | None: 195 | abs_dir = abs_dir.replace("\\", "/").rstrip("/") 196 | all_abs_files = self.get_all_filenames(with_tests=with_tests) 197 | # List all of the above files that are in abs_dir, but not in subdirectories of abs_dir 198 | matches = [ 199 | f 200 | for f in all_abs_files 201 | if f.startswith(abs_dir) and (not level or f.count("/") == abs_dir.count("/") + level) 202 | ] 203 | rel_matches = [str(self.get_rel_fname(f)) for f in matches] 204 | return rel_matches 205 | 206 | def edit_file(self, file_path: str, search: str, replace: str): 207 | abs_path = self.abs_root_path(file_path) 208 | abs_path = Path(abs_path) 209 | 210 | if not abs_path.exists() and not search.strip(): 211 | abs_path.touch() 212 | 213 | file_content = abs_path.read_text() 214 | 215 | new_content = replace_part(file_content, search, replace) 216 | 217 | if new_content and new_content != file_content: 218 | abs_path.write_text(new_content) 219 | return True, None 220 | else: 221 | close_match = find_similar_lines(search, file_content) 222 | return False, close_match 223 | 224 | 225 | def prepare_content_and_lines(content): 226 | if content and not content.endswith("\n"): 227 | content += "\n" 228 | lines = content.splitlines(keepends=True) 229 | 230 | lines_without_numbers = [re.sub(r"^\d+\s*│", "", line) for line in lines] 231 | return "".join(lines_without_numbers), lines_without_numbers 232 | 233 | 234 | def perfect_replace_part(orig_content, search_content, replace_content): 235 | if search_content in orig_content: 236 | return orig_content.replace(search_content, replace_content) 237 | 238 | 239 | def match_but_for_leading_whitespace(whole_lines, part_lines): 240 | num = len(whole_lines) 241 | 242 | # does the non-whitespace all agree? 243 | if not all(whole_lines[i].lstrip() == part_lines[i].lstrip() for i in range(num)): 244 | return 245 | 246 | # are they all offset the same? 247 | add = set( 248 | whole_lines[i][: len(whole_lines[i]) - len(part_lines[i])] 249 | for i in range(num) 250 | if whole_lines[i].strip() 251 | ) 252 | 253 | if len(add) != 1: 254 | return 255 | 256 | return add.pop() 257 | 258 | 259 | def replace_part_with_missing_leading_whitespace(orig_lines, search_lines, replace_lines): 260 | # GPT often messes up leading whitespace. 261 | # It usually does it uniformly across the ORIG and UPD blocks. 262 | # Either omitting all leading whitespace, or including only some of it. 263 | 264 | # Outdent everything in part_lines and replace_lines by the max fixed amount possible 265 | leading = [len(p) - len(p.lstrip()) for p in search_lines if p.strip()] + [ 266 | len(p) - len(p.lstrip()) for p in replace_lines if p.strip() 267 | ] 268 | 269 | if leading and min(leading): 270 | num_leading = min(leading) 271 | search_lines = [p[num_leading:] if p.strip() else p for p in search_lines] 272 | replace_lines = [p[num_leading:] if p.strip() else p for p in replace_lines] 273 | 274 | # can we find an exact match not including the leading whitespace 275 | num_search_lines = len(search_lines) 276 | 277 | for i in range(len(orig_lines) - num_search_lines + 1): 278 | add_leading = match_but_for_leading_whitespace( 279 | orig_lines[i : i + num_search_lines], search_lines 280 | ) 281 | 282 | if add_leading is None: 283 | continue 284 | 285 | replace_lines = [add_leading + rline if rline.strip() else rline for rline in replace_lines] 286 | orig_lines = orig_lines[:i] + replace_lines + orig_lines[i + num_search_lines :] 287 | return "".join(orig_lines) 288 | 289 | 290 | def replace_with_dotdotdots(orig, search, replace): 291 | """ 292 | See if the edit block has ... lines. 293 | If not, return none. 294 | 295 | If yes, try and do a perfect edit with the ... chunks. 296 | If there's a mismatch or otherwise imperfect edit, raise ValueError. 297 | 298 | If perfect edit succeeds, return the updated whole. 299 | """ 300 | 301 | dots_re = re.compile(r"(^\s*\.\.\.\n)", re.MULTILINE | re.DOTALL) 302 | 303 | search_pieces = re.split(dots_re, search) 304 | replace_pieces = re.split(dots_re, replace) 305 | 306 | if len(search_pieces) != len(replace_pieces): 307 | raise ValueError("Unpaired ... in SEARCH/REPLACE block") 308 | 309 | if len(search_pieces) == 1: 310 | # no dots in this edit block, just return None 311 | return 312 | 313 | # Compare odd strings in search_pieces and replace_pieces 314 | all_dots_match = all( 315 | search_pieces[i] == replace_pieces[i] for i in range(1, len(search_pieces), 2) 316 | ) 317 | 318 | if not all_dots_match: 319 | raise ValueError("Unmatched ... in SEARCH/REPLACE block") 320 | 321 | search_pieces = [search_pieces[i] for i in range(0, len(search_pieces), 2)] 322 | replace_pieces = [replace_pieces[i] for i in range(0, len(replace_pieces), 2)] 323 | 324 | pairs = zip(search_pieces, replace_pieces) 325 | for search, replace in pairs: 326 | if not search and not replace: 327 | continue 328 | 329 | if not search and replace: 330 | if not orig.endswith("\n"): 331 | orig += "\n" 332 | orig += replace 333 | continue 334 | 335 | if orig.count(search) == 0: 336 | raise ValueError 337 | if orig.count(search) > 1: 338 | raise ValueError 339 | 340 | orig = orig.replace(search, replace, 1) 341 | 342 | return orig 343 | 344 | 345 | def replace_part(text, search, replace): 346 | if not text: 347 | text = "" 348 | 349 | if not search or search[-1] != "\n": 350 | search += "\n" 351 | if not replace or replace[-1] != "\n": 352 | replace += "\n" 353 | 354 | if not search.strip(): 355 | return text + replace 356 | 357 | orig_content, orig_lines = prepare_content_and_lines(text) 358 | search_content, search_lines = prepare_content_and_lines(search) 359 | replace_content, replace_lines = prepare_content_and_lines(replace) 360 | 361 | result = perfect_replace_part(orig_content, search_content, replace_content) 362 | if result: 363 | return result 364 | 365 | result = replace_part_with_missing_leading_whitespace(orig_lines, search_lines, replace_lines) 366 | if result: 367 | return result 368 | 369 | try: 370 | return replace_with_dotdotdots(orig_content, search_content, replace_content) 371 | except ValueError: 372 | return None 373 | 374 | 375 | def find_similar_lines(search_lines, content_lines, threshold=0.6): 376 | search_lines = search_lines.splitlines() 377 | content_lines = content_lines.splitlines() 378 | 379 | best_ratio = 0 380 | best_match = None 381 | 382 | for i in range(len(content_lines) - len(search_lines) + 1): 383 | chunk = content_lines[i : i + len(search_lines)] 384 | ratio = SequenceMatcher(None, search_lines, chunk).ratio() 385 | if ratio > best_ratio: 386 | best_ratio = ratio 387 | best_match = chunk 388 | best_match_i = i 389 | 390 | if best_ratio < threshold: 391 | return "" 392 | 393 | if best_match[0] == search_lines[0] and best_match[-1] == search_lines[-1]: 394 | return "\n".join(best_match) 395 | 396 | N = 5 397 | best_match_end = min(len(content_lines), best_match_i + len(search_lines) + N) 398 | best_match_i = max(0, best_match_i - N) 399 | 400 | best = content_lines[best_match_i:best_match_end] 401 | return "\n".join(best) 402 | 403 | 404 | def get_ident_filename_matches(idents, all_rel_fnames: List[str], max_ident_len=2): 405 | all_fnames = defaultdict(set) 406 | for fname in all_rel_fnames: 407 | base = Path(fname).with_suffix("").name.lower() 408 | if len(base) >= max_ident_len: 409 | all_fnames[base].add(fname) 410 | 411 | matches = set() 412 | for ident in idents: 413 | if len(ident) < max_ident_len: 414 | continue 415 | matches.update(all_fnames[ident.lower()]) 416 | 417 | return matches 418 | 419 | 420 | def get_ident_mentions(text): 421 | # Split the string on any character that is not alphanumeric 422 | # \W+ matches one or more non-word characters (equivalent to [^a-zA-Z0-9_]+) 423 | words = set(re.split(r"\W+", text)) 424 | return words 425 | 426 | 427 | if __name__ == "__main__": 428 | text = """ 429 | 517│ def _convert(self, val: Any) -> Any: 430 | 518│ '''Convert `val` to an appropriate type for the element's VR.''' 431 | 519│ # If the value is a byte string and has a VR that can only be encoded 432 | 520│ # using the default character repertoire, we convert it to a string 433 | 521│ # here to allow for byte string input in these cases 434 | 522│ if _is_bytes(val) and self.VR in ( 435 | 523│ 'AE', 'AS', 'CS', 'DA', 'DS', 'DT', 'IS', 'TM', 'UI', 'UR'): 436 | 524│ val = val.decode() 437 | 525│ 438 | 526│ if self.VR == 'IS': 439 | 527│ return pydicom.valuerep.IS(val) 440 | 528│ elif self.VR == 'DA' and config.datetime_conversion: 441 | 529│ return pydicom.valuerep.DA(val) 442 | 530│ elif self.VR == 'DS': 443 | 531│ return pydicom.valuerep.DS(val) 444 | 532│ elif self.VR == 'DT' and config.datetime_conversion: 445 | 533│ return pydicom.valuerep.DT(val) 446 | 534│ elif self.VR == 'TM' and config.datetime_conversion: 447 | 535│ return pydicom.valuerep.TM(val) 448 | 536│ elif self.VR == "UI": 449 | 537│ return UID(val) if val is not None else None 450 | 538│ elif self.VR == "PN": 451 | 539│ return PersonName(val) 452 | 540│ elif self.VR == "AT" and (val == 0 or val): 453 | 541│ return val if isinstance(val, BaseTag) else Tag(val) 454 | 542│ # Later may need this for PersonName as for UI, 455 | 543│ # but needs more thought 456 | 544│ # elif self.VR == "PN": 457 | 545│ # return PersonName(val) 458 | 546│ else: # is either a string or a type 2 optionally blank string 459 | 547│ return val # this means a "numeric" value could be empty string "" 460 | 548│ # except TypeError: 461 | 549│ # print "Could not convert value '%s' to VR '%s' in tag %s" \ 462 | 550│ # % (repr(val), self.VR, self.tag) 463 | 551│ # except ValueError: 464 | 552│ # print "Could not convert value '%s' to VR '%s' in tag %s" \ 465 | 553│ # % (repr(val), self.VR, self.tag) 466 | """ 467 | search = """else: # is either a string or a type 2 optionally blank string""" 468 | replace = """elif self.VR == "OL": 469 | return bytes(val) 470 | else: # is either a string or a type 2 optionally blank string 471 | """ 472 | print(replace_part(text, search, replace)) 473 | -------------------------------------------------------------------------------- /motleycoder/codemap/graph.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | from typing import List, Optional, Dict 5 | 6 | import networkx as nx 7 | 8 | from motleycrew.common import logger 9 | from .render import RenderCode 10 | from .tag import Tag 11 | 12 | BUILTINS_BY_LANG_FILE = "builtins_by_lang.json" 13 | 14 | 15 | class TagGraph(nx.MultiDiGraph): 16 | def __init__(self): 17 | super().__init__() 18 | self.code_renderer = RenderCode() 19 | 20 | @property 21 | def filenames(self): 22 | return set([tag.fname for tag in self.nodes]) 23 | 24 | def successors_with_attribute(self, node, attr_name, attr_value): 25 | """ 26 | Get all neighbors of a node in a MultiDiGraph where the relationship has a specific attribute. 27 | 28 | Parameters: 29 | G (networkx.MultiDiGraph): The graph 30 | node (node): The node for which to find neighbors 31 | attr_name (str): The attribute name to filter edges 32 | attr_value: The attribute value to filter edges 33 | 34 | Returns: 35 | list: A list of neighbors where the relationship has the specified attribute 36 | """ 37 | neighbors = [] 38 | for successor in self.successors(node): 39 | # Check all edges from node to successor 40 | edges = self[node][successor] 41 | for key in edges: 42 | if edges[key].get(attr_name) == attr_value: 43 | neighbors.append(successor) 44 | break # Found a valid edge, no need to check further 45 | return neighbors 46 | 47 | def get_parents(self, tag: Tag) -> List[Tag] | str: 48 | """ 49 | Get the parent tags of a tag in the same file, eg the class def for a method name 50 | :param tag: 51 | :return: list of parent tags 52 | """ 53 | if not len(tag.parent_names): 54 | return [] 55 | 56 | parents = [t for t in self.predecessors(tag) if t.kind == "def"] 57 | if not parents: 58 | logger.warning(f"No parent found for {tag} with nonempty parent names!") 59 | return ".".join(tag.parent_names) + "." + tag.name + ":" 60 | parent = parents[0] 61 | 62 | if len(parent.parent_names): 63 | predecessors = self.get_parents(parent) 64 | else: 65 | predecessors = [] 66 | 67 | return predecessors + [parent] 68 | 69 | def get_tag_representation( 70 | self, tag: Tag, parent_details: bool = False, max_lines=200, force_include_full_text=False 71 | ) -> str: 72 | if tag is None: 73 | return None 74 | if tag not in self.nodes: 75 | raise ValueError(f"The tag {tag} is not in the tag graph") 76 | 77 | builtins_by_lang = load_builtins_by_lang() 78 | 79 | tag_repr = [tag.rel_fname + ":"] 80 | if not parent_details: 81 | if len(tag.parent_names): 82 | tag_repr.append(".".join(tag.parent_names) + "." + tag.name + ":") 83 | else: 84 | parents = self.get_parents(tag) 85 | if parents: 86 | if isinstance(parents, str): 87 | tag_repr.append(parents) 88 | else: 89 | # if there are parents, this will include the filename 90 | tag_repr = [self.code_renderer.to_tree(parents)] 91 | 92 | tag_repr.append(RenderCode.text_with_line_numbers(tag)) 93 | tag_repr = "\n".join(tag_repr) 94 | 95 | n_lines = len(tag_repr.split("\n")) 96 | 97 | if force_include_full_text or n_lines <= max_lines: 98 | # if the full text hast at most 200 lines, put it all in the summary 99 | children = [] 100 | for e, c, data in self.out_edges(tag, data=True): 101 | if ( # If the child is included in the parent's full text anyway, skip it 102 | c.fname == tag.fname 103 | and c.byte_range[0] >= tag.byte_range[0] 104 | and c.byte_range[1] <= tag.byte_range[1] 105 | ) or not data.get("include_in_summary"): 106 | continue 107 | if c.name in builtins_by_lang.get(c.language, []): 108 | continue # skip built-ins 109 | children.append(c) 110 | 111 | out = [tag_repr] 112 | if children: 113 | chlidren_summary = self.code_renderer.to_tree(children) 114 | if n_lines + len(chlidren_summary.split("\n")) < max_lines: 115 | out.extend( 116 | [ 117 | "Referenced entities summary:", 118 | chlidren_summary, 119 | ] 120 | ) 121 | return "\n".join(out) 122 | else: 123 | # if the full text is too long, send a summary of it and its children 124 | children = list( 125 | self.successors_with_attribute(tag, attr_name="include_in_summary", attr_value=True) 126 | ) 127 | tag_repr = self.code_renderer.to_tree( 128 | [tag] + [c for c in children if c.name not in builtins_by_lang.get(c.language, [])] 129 | ) 130 | return tag_repr 131 | 132 | def search_line_in_tags(self, tags: List[Tag], line: int) -> Tag | None: 133 | """ 134 | Search for a line in a list of tags, assuming the tags belong to the same file 135 | :param tags: The tags to search 136 | :param line: The line number to search for 137 | :return: The tag that contains the line, or None if not found 138 | """ 139 | if not tags: 140 | return None 141 | 142 | filename = tags[0].rel_fname 143 | for tag in tags: 144 | assert tag.rel_fname == filename, "Tags must belong to the same file" 145 | 146 | if tag.line <= line <= tag.end_line: 147 | return tag 148 | return None 149 | 150 | def get_file_representation(self, file_name: str, file_content: str, max_lines=500) -> str: 151 | """ 152 | Get a representation of a file, with a maximum number of lines 153 | :param file_name: The file name 154 | :param max_lines: The maximum number of lines to include 155 | :return: A string representation of the file 156 | """ 157 | tags = [t for t in self.nodes if t.fname == file_name] 158 | if not tags: 159 | if not file_content: 160 | raise ValueError(f"No tags found for file {file_name} and no content provided") 161 | 162 | file_lines = file_content.split("\n") 163 | file_repr = "\n".join( 164 | [ 165 | self.code_renderer.render_line(line, i + 1) 166 | for i, line in enumerate(file_content.split("\n")[:max_lines]) 167 | ] 168 | ) 169 | if len(file_lines) > max_lines + 1: 170 | return file_repr + f"\n... and {len(file_lines) - max_lines} more lines" 171 | 172 | return file_repr 173 | 174 | root_tags = [t for t in tags if not t.parent_names] 175 | file_lines = file_content.split("\n") 176 | line_nums_to_display = [] 177 | 178 | i = 0 179 | while i < len(file_lines): 180 | tag = self.search_line_in_tags(root_tags, i) 181 | if tag is not None: 182 | i = tag.end_line + 1 183 | else: 184 | line_nums_to_display.append(i) 185 | i += 1 186 | 187 | return self.code_renderer.to_tree( 188 | tags, additional_lines={tags[0].rel_fname: line_nums_to_display} 189 | ) 190 | 191 | def get_tag_from_filename_lineno( 192 | self, fname: str, line_no: int, try_next_line=True 193 | ) -> Tag | None: 194 | files = [f for f in self.filenames if fname in f] 195 | if not files: 196 | raise ValueError(f"File {fname} not found in the file group") 197 | this_file_nodes = [node for node in self.nodes if node.fname in files] 198 | if not this_file_nodes: 199 | raise ValueError(f"File {fname} not found in the tag graph") 200 | for node in this_file_nodes: 201 | if node.line == line_no - 1: 202 | return node 203 | # If we got this far, we didn't find the tag 204 | # Let's look in the next line, sometimes that works 205 | if try_next_line: 206 | return self.get_tag_from_filename_lineno(fname, line_no + 1, try_next_line=False) 207 | 208 | return None 209 | 210 | def get_tags_from_entity_name( 211 | self, entity_name: Optional[str] = None, file_name: Optional[str] = None 212 | ) -> List[Tag]: 213 | 214 | if entity_name is None: 215 | assert file_name is not None, "Must supply at least one of entity_name, file_name" 216 | return [t for t in self.nodes if file_name in t.fname] 217 | 218 | min_entity_name = entity_name.split(".")[-1] 219 | 220 | # Composite, like `file.py:method_name` 221 | if file_name is not None: 222 | preselection: List[Tag] = [t for t in self.nodes if file_name in t.fname] 223 | 224 | test = [t for t in preselection if t.name == min_entity_name and t.kind == "def"] 225 | if not test: 226 | logger.warning( 227 | f"Definition of entity {entity_name} not found in file {file_name}, searching globally" 228 | ) 229 | preselection: List[Tag] = list(self.nodes) 230 | else: 231 | preselection: List[Tag] = list(self.nodes) 232 | 233 | orig_tags: List[Tag] = [ 234 | t for t in preselection if t.name == min_entity_name and t.kind == "def" 235 | ] 236 | 237 | # do fancier name resolution 238 | re_tags = [t for t in orig_tags if match_entity_name(entity_name, t)] 239 | 240 | if len(re_tags) > 1: 241 | logger.warning(f"Multiple definitions found for {entity_name}: {re_tags}") 242 | return re_tags 243 | 244 | 245 | def load_builtins_by_lang() -> Dict[str, List[str]]: 246 | here = os.path.dirname(__file__) 247 | builtins_by_lang_path = os.path.realpath(os.path.join(here, BUILTINS_BY_LANG_FILE)) 248 | if not os.path.exists(builtins_by_lang_path): 249 | raise Exception(f"Builtins by lang file not found at {builtins_by_lang_path}") 250 | 251 | with open(builtins_by_lang_path, "r") as file: 252 | builtins_by_lang = json.load(file) 253 | 254 | return builtins_by_lang 255 | 256 | 257 | def match_entity_name(entity_name: str, tag: Tag) -> bool: 258 | entity_name = entity_name.split(".") 259 | if entity_name[-1] != tag.name: 260 | return False 261 | 262 | # Simple reference, with no dots, and names are the same 263 | # or the tag has no parent names, and the dots are just package names 264 | if len(entity_name) == 1 or len(tag.parent_names) == 0: 265 | return True 266 | 267 | # Check if the parent names match if they exist 268 | if tag.parent_names == tuple((entity_name[:-1])[-len(tag.parent_names) :]): 269 | return True 270 | 271 | # TODO: do fancier resolution here, potentially returning match scores to rank matches 272 | 273 | # If entity name includes package name, check that 274 | fn_parts = tag.fname.split("/") 275 | fn_parts[-1] = fn_parts[-1].replace(".py", "") 276 | 277 | potential_parents = fn_parts + list(tag.parent_names) 278 | clipped_parents = potential_parents[-len(entity_name) - 1 :] 279 | 280 | if tuple(clipped_parents) == tuple(entity_name[:-1]): 281 | return True 282 | 283 | return False 284 | 285 | 286 | def build_tag_graph(tags: List[Tag], code_map: Dict[str, str]) -> TagGraph: 287 | """ 288 | Build a graph of tags, with edges from references to definitions 289 | And with edges from parent definitions to child definitions in the same file 290 | :param tags: 291 | :return: 292 | """ 293 | # Build a map from entity names to their definitions 294 | # There may be multiple definitions for a single name in different scopes, 295 | # for now we don't bother resolving them 296 | G = TagGraph() 297 | G.code_renderer.code_map = code_map 298 | 299 | def_map = defaultdict(set) 300 | 301 | for tag in tags: 302 | if tag.kind == "def": 303 | def_map[tag.name].add(tag) 304 | elif tag.kind == "file": 305 | # Just add all the parsed files to the graph 306 | G.add_node(tag, kind=tag.kind) 307 | 308 | # Add all tags as nodes 309 | # Add edges from references to definitions 310 | for tag in tags: 311 | G.add_node(tag, kind=tag.kind) 312 | if tag.kind == "def": 313 | # Look for any references to other entities inside that definition 314 | for ref_tag in tags: 315 | if ref_tag.kind == "ref" and ref_tag.fname == tag.fname: 316 | if ( 317 | ref_tag.byte_range[0] >= tag.byte_range[0] 318 | and ref_tag.byte_range[1] <= tag.byte_range[1] 319 | ): 320 | G.add_edge(tag, ref_tag) 321 | 322 | elif tag.kind == "ref": 323 | G.add_node(tag, kind=tag.kind) 324 | for def_tag in def_map[tag.name]: 325 | # point to any definitions that might have been meant 326 | # would probably need a language server for unique resolution, 327 | # don't bother with that here 328 | G.add_edge(tag, def_tag) 329 | tag.n_defs += 1 330 | 331 | # Build up definition hierarchy 332 | # A parent definition for a tag must: 333 | # - be in the same file 334 | # - Have matching tail of parent names 335 | if len(tag.parent_names): 336 | parent_name = tag.parent_names[-1] 337 | candidates = [ 338 | c 339 | for c in def_map[parent_name] 340 | if c.fname == tag.fname and c.parent_names == tag.parent_names[:-1] 341 | ] 342 | for c in candidates: 343 | G.add_edge(c, tag) 344 | 345 | return G 346 | 347 | 348 | def only_defs(tag_graph: TagGraph) -> TagGraph: 349 | """ 350 | Return a graph with only the def nodes and the edges between them 351 | If a def node has a reference node as a child, add an edge to the reference node's definition 352 | :param tag_graph: A graph generated by build_tag_graph 353 | :return: A graph with only def nodes and edges between them 354 | """ 355 | 356 | G = TagGraph() 357 | G.code_renderer.code_map = tag_graph.code_renderer.code_map 358 | 359 | for tag in tag_graph.nodes: 360 | if tag.kind == "def": 361 | G.add_node(tag) 362 | for u, v, data in tag_graph.edges(data=True): 363 | if u.kind == "def" and v.kind == "def": 364 | data["include_in_summary"] = True 365 | G.add_edge(u, v, **data) 366 | # Also add edges betweend defs and their two-hop descendant defs 367 | # TODO: should we look for more than two-hop def descendants? Can these ever happen? 368 | for u, v, data in tag_graph.edges(data=True): 369 | if u.kind == "def" and v.kind != "def": 370 | for _, v_desc in tag_graph.out_edges(v): 371 | if v_desc.kind == "def" and v_desc != u: 372 | data["include_in_summary"] = ( 373 | v.n_defs <= 2 374 | ) # Skip entries with more than 2 definition candidates 375 | G.add_edge(u, v_desc, **data) 376 | return G 377 | -------------------------------------------------------------------------------- /motleycoder/codemap/map_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Set 3 | 4 | 5 | @dataclass 6 | class RepoMapArgs: 7 | chat_fnames: Set[str] = field(default_factory=set) 8 | other_fnames: Set[str] = field(default_factory=set) 9 | mentioned_fnames: Set[str] = field(default_factory=set) 10 | mentioned_idents: Set[str] = field(default_factory=set) 11 | mentioned_entities: Set[str] = field(default_factory=set) 12 | search_terms: Set[str] = field(default_factory=set) 13 | add_prefix: bool = True 14 | -------------------------------------------------------------------------------- /motleycoder/codemap/parse.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | from grep_ast import filename_to_lang 5 | from pygments.lexers import guess_lexer_for_filename 6 | from pygments.token import Token 7 | from pygments.util import ClassNotFound 8 | from tree_sitter import Tree, Query, Node 9 | from tree_sitter_languages import get_language 10 | from tree_sitter_languages import get_parser # noqa: E402 11 | 12 | from motleycrew.common import logger 13 | from .tag import Tag 14 | 15 | 16 | def get_query(lang: str) -> Query | None: 17 | language = get_language(lang) 18 | # Load the tags queries 19 | here = os.path.dirname(__file__) 20 | scm_fname = os.path.realpath(os.path.join(here, "../queries", f"tree-sitter-{lang}-tags.scm")) 21 | if not os.path.exists(scm_fname): 22 | return None 23 | 24 | with open(scm_fname, "r") as file: 25 | query_scm = file.read() 26 | 27 | # Run the tags queries 28 | query = language.query(query_scm) 29 | return query 30 | 31 | 32 | def ast_to_tags( 33 | full_file_code: str, 34 | tree: Tree, 35 | query: Query, 36 | rel_fname: str, 37 | fname: str, 38 | language: str | None = None, 39 | ) -> List[Tag]: 40 | # TODO: extract docstrings and comments to do RAG on 41 | captures = list(query.captures(tree.root_node)) 42 | defs = [] 43 | refs = [] 44 | names = [] 45 | 46 | for node, tag in captures: 47 | if tag.startswith("name"): 48 | names.append(node) 49 | elif tag.startswith("reference"): 50 | refs.append((node, "ref")) 51 | elif tag.startswith("definition"): 52 | defs.append((node, "def")) 53 | else: 54 | continue 55 | 56 | out = [] 57 | for node, kind in defs + refs: 58 | name_node = node2namenode(node, names) 59 | if name_node is None: 60 | # logging.warning(f"Could not find name node for {node}") 61 | # TODO: should we populate these anyway, eg by parsing the text? 62 | continue 63 | 64 | parent_defs = get_def_parents(node, [d[0] for d in defs]) 65 | parent_names = tuple([namenode2name(node2namenode(d, names)) for d in parent_defs]) 66 | 67 | out.append( 68 | Tag( 69 | rel_fname=rel_fname.replace("\\", "/"), 70 | fname=fname.replace("\\", "/"), 71 | name=namenode2name(name_node), 72 | parent_names=parent_names, 73 | kind=kind, 74 | docstring=node2docstring(node, language) if kind == "def" else "", 75 | line=name_node.start_point[0], 76 | end_line=node.end_point[0], 77 | text=node.text.decode("utf-8"), 78 | byte_range=node.byte_range, 79 | language=language, 80 | ) 81 | ) 82 | 83 | return out 84 | 85 | 86 | def node2docstring(node: Node, language: str) -> str: 87 | if language == "python": 88 | docstring = extract_python_docstring(node.text.decode("utf-8")) 89 | # TODO: check for more kinds of docstring-like comments 90 | if docstring is None: 91 | cmt = [(i, n) for i, n in enumerate(node.children) if n.type == "comment"] 92 | if len(cmt): 93 | docstring = cmt[0][1].text.decode("utf-8") 94 | for j, (i, n) in enumerate(cmt[1:]): 95 | # only look for adjacent comments 96 | if i - j - 1 == cmt[0][0]: 97 | docstring += "\n" + n.text.decode("utf-8") 98 | return docstring if docstring is not None else "" 99 | else: 100 | logger.warning(f"Docstrings not yet implemented for {language}") 101 | return "" 102 | 103 | 104 | def extract_python_docstring(code: str) -> str | None: 105 | import ast 106 | 107 | # Parse the code into an AST 108 | try: 109 | tree = ast.parse(code) 110 | except Exception as e: 111 | logger.error(f"Error parsing python code: {e}") 112 | return None 113 | 114 | # Initialize the docstring variable 115 | docstring = None 116 | 117 | # Traverse the AST to find the first function, method, or class definition 118 | for node in ast.walk(tree): 119 | if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): 120 | # Extract the docstring 121 | docstring = ast.get_docstring(node) 122 | break 123 | 124 | return docstring 125 | 126 | 127 | def node2namenode(node: Node, name_nodes: List[Node]) -> Node | None: 128 | tmp = [n for n in name_nodes if n in node.children] 129 | 130 | if len(tmp) > 0: 131 | return tmp[0] 132 | 133 | # method calls 134 | tmp = [n for n in node.children if n.type == "attribute"] 135 | if len(tmp) == 0: 136 | logger.warning(f"Could not find name node for {node}") 137 | return None 138 | # method name 139 | tmp = [n for n in name_nodes if n in tmp[0].children] 140 | 141 | if len(tmp) == 0: 142 | logger.warning(f"Could not find name node for {node}") 143 | return None 144 | 145 | return tmp[0] 146 | 147 | 148 | def namenode2name(node: Node | None) -> str: 149 | return node.text.decode("utf-8") if node else "" 150 | 151 | 152 | def get_def_parents(node: Node, defs: List[Node]) -> List[Node]: 153 | dp = [] 154 | while node.parent is not None: 155 | if node.parent in defs: 156 | dp.append(node.parent) 157 | node = node.parent 158 | return tuple(reversed(dp)) 159 | 160 | 161 | def refs_from_lexer(rel_fname, fname, code, language: str | None = None): 162 | try: 163 | lexer = guess_lexer_for_filename(fname, code) 164 | except ClassNotFound: 165 | return [] 166 | 167 | tokens = list(lexer.get_tokens(code)) 168 | tokens = [token[1] for token in tokens if token[0] in Token.Name] 169 | 170 | out = [ 171 | Tag( 172 | rel_fname=rel_fname, 173 | fname=fname, 174 | name=token, 175 | kind="ref", 176 | line=-1, 177 | end_line=-1, 178 | text="", 179 | byte_range=(0, 0), 180 | language=language, 181 | docstring="", 182 | ) 183 | for token in tokens 184 | ] 185 | return out 186 | 187 | 188 | def get_tags_raw(fname, rel_fname, code) -> list[Tag]: 189 | lang = filename_to_lang(fname) 190 | if not lang: 191 | return [] 192 | 193 | parser = get_parser(lang) 194 | 195 | if not code: 196 | return [] 197 | 198 | ast = parser.parse(bytes(code, "utf-8")) 199 | query = get_query(lang) 200 | if not query: 201 | return [] 202 | 203 | pre_tags = ast_to_tags(code, ast, query, rel_fname, fname, lang) 204 | 205 | saw = set([tag.kind for tag in pre_tags]) 206 | if "ref" in saw or "def" not in saw: 207 | return pre_tags 208 | 209 | # We saw defs, without any refs 210 | # Some tags files only provide defs (cpp, for example) 211 | # Use pygments to backfill refs 212 | refs = refs_from_lexer(rel_fname, fname, code, lang) 213 | out = pre_tags + refs 214 | return out 215 | 216 | 217 | def read_text(filename: str, encoding: str = "utf-8") -> str | None: 218 | try: 219 | with open(str(filename), "r", encoding=encoding) as f: 220 | return f.read() 221 | except FileNotFoundError: 222 | logger.error(f"{filename}: file not found error") 223 | return 224 | except IsADirectoryError: 225 | logger.error(f"{filename}: is a directory") 226 | return 227 | except UnicodeError as e: 228 | logger.error(f"{filename}: {e}") 229 | logger.error("Use encoding parameter to set the unicode encoding.") 230 | return 231 | -------------------------------------------------------------------------------- /motleycoder/codemap/rank.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import defaultdict, Counter 3 | from typing import List, Dict, Collection 4 | 5 | import networkx as nx 6 | import numpy as np 7 | 8 | from .map_args import RepoMapArgs 9 | from .tag import Tag 10 | 11 | 12 | def rank_tags_new( 13 | tag_graph: nx.MultiDiGraph, 14 | args: RepoMapArgs, 15 | diffusion_mult=0.2, 16 | ) -> List[Tag | tuple]: 17 | G = copy.deepcopy(tag_graph) 18 | for tag in G.nodes: 19 | G.nodes[tag]["weight"] = 0.0 20 | 21 | mentioned_entities_clean = set([name.split(".")[-1] for name in args.mentioned_entities]) 22 | 23 | # process mentioned_idents 24 | for tag in tag_graph.nodes: 25 | if tag.kind == "def": 26 | if tag.fname in args.chat_fnames and tag.name in mentioned_entities_clean: 27 | G.nodes[tag]["weight"] += 3.0 28 | 29 | elif tag.name in args.mentioned_idents: 30 | G.nodes[tag]["weight"] += 1.0 31 | 32 | # process mentioned_fnames 33 | mentioned_weights = weights_from_fnames(tag_graph, args.mentioned_fnames) 34 | for tag, weight in mentioned_weights.items(): 35 | G.nodes[tag]["weight"] += 0.2 * weight 36 | 37 | # process chat_fnames 38 | chat_fname_weights = weights_from_fnames(tag_graph, args.chat_fnames) 39 | for tag, weight in chat_fname_weights.items(): 40 | G.nodes[tag]["weight"] += 0.5 * weight 41 | 42 | # process search_terms: 43 | tag_matches = defaultdict(set) 44 | for tag in tag_graph.nodes: 45 | for term in args.search_terms: 46 | if tag.kind == "def" and term in tag.text: 47 | tag_matches[term].add(tag) 48 | 49 | typical_search_count = np.median([len(tags) for tags in tag_matches.values()]) 50 | for term, tags in tag_matches.items(): 51 | for tag in tags: 52 | G.nodes[tag]["weight"] += typical_search_count / len(tags) 53 | 54 | # diffuse these weights through the graph 55 | G1 = copy.deepcopy(G) 56 | for t in G.nodes: 57 | for _, t2 in G.out_edges(t): 58 | G.nodes[t2]["weight"] += G1.nodes[t]["weight"] * diffusion_mult 59 | 60 | # Order the tags by weight 61 | node_list = [(G.nodes[tag]["weight"], tag) for tag in G.nodes] 62 | tags = sorted( 63 | node_list, 64 | key=lambda x: x[0], 65 | reverse=True, 66 | ) 67 | 68 | return [t[1] for t in tags] 69 | 70 | 71 | def rank_tags( 72 | tags: List[Tag], 73 | args: RepoMapArgs, 74 | other_rel_fnames: Collection[str], 75 | ) -> List[tuple]: 76 | """ 77 | The original aider ranking algorithm 78 | """ 79 | defines = defaultdict(set) 80 | references = defaultdict(list) 81 | definitions = defaultdict(set) 82 | 83 | cleaned_fnames = set([(tag.fname, tag.rel_fname) for tag in tags]) 84 | 85 | for tag in tags: 86 | if tag.kind == "def": 87 | defines[tag.name].add(tag.rel_fname) 88 | definitions[(tag.rel_fname, tag.name)].add(tag) 89 | 90 | if tag.kind == "ref": 91 | references[tag.name].append(tag.rel_fname) 92 | 93 | # now construct the graph 94 | 95 | chat_rel_fnames = set() 96 | personalization = dict() 97 | # Default personalization for unspecified files is 1/num_nodes 98 | # https://networkx.org/documentation/stable/_modules/networkx/algorithms/link_analysis/pagerank_alg.html#pagerank 99 | personalize = 10 / (len(cleaned_fnames) + 1) 100 | 101 | for fname, rel_fname in cleaned_fnames: 102 | if fname in args.chat_fnames: 103 | personalization[rel_fname] = personalize 104 | chat_rel_fnames.add(rel_fname) 105 | 106 | if fname in args.mentioned_fnames: 107 | personalization[rel_fname] = personalize 108 | 109 | if not references: 110 | references = dict((k, list(v)) for k, v in defines.items()) 111 | 112 | idents = set(defines.keys()).intersection(set(references.keys())) 113 | 114 | G = nx.MultiDiGraph() 115 | 116 | for ident in idents: 117 | definers = defines[ident] 118 | if ident in args.mentioned_idents: 119 | mul = 10 120 | else: 121 | mul = 1 122 | for referencer, num_refs in Counter(references[ident]).items(): 123 | for definer in definers: 124 | # if referencer == definer: 125 | # continue 126 | G.add_edge(referencer, definer, weight=mul * num_refs, ident=ident) 127 | 128 | if personalization: 129 | pers_args = dict(personalization=personalization, dangling=personalization) 130 | else: 131 | pers_args = dict() 132 | 133 | try: 134 | ranked = nx.pagerank(G, weight="weight", **pers_args) 135 | except ZeroDivisionError: 136 | return [] 137 | 138 | # distribute the rank from each source node, across all of its out edges 139 | ranked_definitions = defaultdict(float) 140 | for src in G.nodes: 141 | src_rank = ranked[src] 142 | total_weight = sum(data["weight"] for _src, _dst, data in G.out_edges(src, data=True)) 143 | # dump(src, src_rank, total_weight) 144 | for _src, dst, data in G.out_edges(src, data=True): 145 | data["rank"] = src_rank * data["weight"] / total_weight 146 | ident = data["ident"] 147 | ranked_definitions[(dst, ident)] += data["rank"] 148 | 149 | ranked_tags = [] 150 | ranked_definitions = sorted(ranked_definitions.items(), reverse=True, key=lambda x: x[1]) 151 | 152 | # dump(ranked_definitions) 153 | 154 | # First collect the definitions in rank order 155 | # Do NOT include the chat-added files - is that because they'll be added in their entirety? 156 | for (fname, ident), rank in ranked_definitions: 157 | # print(f"{rank:.03f} {fname} {ident}") 158 | if fname in chat_rel_fnames: 159 | continue 160 | ranked_tags += list(definitions.get((fname, ident), [])) 161 | 162 | rel_other_fnames_without_tags = set(other_rel_fnames) 163 | 164 | fnames_already_included = set(rt.rel_fname for rt in ranked_tags) 165 | 166 | # Then go through the __files__ ranked earlier, and add them in rank order 167 | # These are just files with references, without definitions, presumably 168 | top_rank = sorted([(rank, node) for (node, rank) in ranked.items()], reverse=True) 169 | for rank, fname in top_rank: 170 | if fname in rel_other_fnames_without_tags: 171 | rel_other_fnames_without_tags.remove(fname) 172 | if fname not in fnames_already_included: 173 | ranked_tags.append((fname,)) 174 | 175 | # At the very tail of the list, append the files that have no tags at all 176 | for fname in rel_other_fnames_without_tags: 177 | ranked_tags.append((fname,)) 178 | 179 | return ranked_tags 180 | 181 | 182 | def weights_from_fnames( 183 | tag_graph: nx.MultiDiGraph, mentioned_fnames: Collection[str] 184 | ) -> Dict[Tag, float]: 185 | tag_weights = defaultdict(float) 186 | fname_counts = defaultdict(int) 187 | for tag in tag_graph.nodes: 188 | if tag.kind == "def" and tag.fname in mentioned_fnames: 189 | fname_counts[tag.fname] += 1 190 | 191 | # Normalize the weights to take into account what's typical in the codebase 192 | typical_count = np.median(np.array(list(fname_counts.values()))) 193 | for tag in tag_graph.nodes: 194 | if tag.fname in fname_counts and tag.kind == "def": 195 | tag_weights[tag] += typical_count / fname_counts[tag.fname] 196 | 197 | return tag_weights 198 | -------------------------------------------------------------------------------- /motleycoder/codemap/render.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Dict 2 | 3 | from grep_ast import TreeContext 4 | 5 | from .tag import Tag 6 | 7 | 8 | class RenderCode: 9 | def __init__(self): 10 | self.tree_cache = {} 11 | self.code_map = {} 12 | 13 | def to_tree( 14 | self, 15 | tags: List[Tag | tuple], 16 | render_file_name: bool = True, 17 | additional_lines: Optional[Dict[str, List[int]]] = None, 18 | ) -> str: 19 | if not tags: 20 | return "" 21 | 22 | additional_lines = additional_lines or {} 23 | 24 | assert ( 25 | render_file_name or len(set(tag.fname for tag in tags)) <= 1 26 | ), "can't render without filenames if there are multiple files" 27 | 28 | tags = sorted(tags, key=lambda x: tuple(x)) 29 | 30 | cur_fname = None 31 | cur_abs_fname = None 32 | lois = None 33 | output = "" 34 | 35 | # add a bogus tag at the end so we trip the this_fname != cur_fname... 36 | dummy_tag = (None,) 37 | for tag in tags + [dummy_tag]: 38 | this_rel_fname = tag[0] 39 | 40 | # ... here ... to output the final real entry in the list 41 | if this_rel_fname != cur_fname: 42 | if lois is not None: 43 | output += "\n" 44 | if render_file_name: 45 | output += cur_fname + ":\n" 46 | output += self.render_tree( 47 | cur_fname, 48 | lois + additional_lines.get(cur_fname, []), 49 | code=self.code_map[cur_abs_fname], 50 | ) 51 | lois = None 52 | elif cur_fname: 53 | if render_file_name: 54 | output += "\n" + cur_fname + "\n" 55 | if type(tag) is Tag: 56 | lois = [] 57 | cur_abs_fname = tag.fname 58 | cur_fname = this_rel_fname 59 | 60 | if lois is not None: 61 | lois.append(tag.line) 62 | 63 | # truncate long lines, in case we get minified js or something else crazy 64 | output = "\n".join([line[:100] for line in output.splitlines()]) + "\n" 65 | 66 | return output 67 | 68 | def render_tree(self, rel_fname, lois, line_number: bool = True, code="") -> str: 69 | key = (rel_fname, tuple(sorted(lois))) 70 | 71 | if key in self.tree_cache: 72 | return self.tree_cache[key] 73 | 74 | if code and not code.endswith("\n"): 75 | code += "\n" 76 | 77 | context = TreeContext( 78 | rel_fname, 79 | code, 80 | color=False, 81 | line_number=line_number, 82 | child_context=False, 83 | last_line=False, 84 | margin=0, 85 | mark_lois=False, 86 | loi_pad=0, 87 | # header_max=30, 88 | show_top_of_file_parent_scope=False, 89 | ) 90 | 91 | context.add_lines_of_interest(lois) 92 | context.add_context() 93 | res = context.format() 94 | self.tree_cache[key] = res 95 | return res 96 | 97 | @staticmethod 98 | def text_with_line_numbers(t: Tag) -> str: 99 | out = [] 100 | for i, line in enumerate(t.text.split("\n")): 101 | re_line = RenderCode.render_line(line, i + 1 + t.line) 102 | out.append(re_line) 103 | return "\n".join(out) 104 | 105 | @staticmethod 106 | def render_line(line: str, number: int) -> str: 107 | return f"{number:3}│{line}" 108 | -------------------------------------------------------------------------------- /motleycoder/codemap/repomap.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import warnings 4 | from typing import List, Set, Optional 5 | 6 | import litellm 7 | from langchain_core.pydantic_v1 import BaseModel, Field 8 | 9 | from motleycrew.common import logger 10 | from .file_group import ( 11 | FileGroup, 12 | get_ident_mentions, 13 | get_ident_filename_matches, 14 | ) 15 | from .graph import TagGraph, build_tag_graph, only_defs # noqa: F402 16 | from .map_args import RepoMapArgs 17 | from .parse import get_tags_raw, read_text # noqa: F402 18 | from .rank import rank_tags_new, rank_tags # noqa: F402 19 | from .render import RenderCode 20 | 21 | # tree_sitter is throwing a FutureWarning 22 | warnings.simplefilter("ignore", category=FutureWarning) 23 | 24 | 25 | class RepoMap: 26 | def __init__( 27 | self, 28 | map_tokens: int = 1024, 29 | root: Optional[str] = None, 30 | llm_name: Optional[str] = None, 31 | repo_content_prefix: Optional[str] = None, 32 | verbose: bool = False, 33 | file_group: FileGroup = None, 34 | use_old_ranking: bool = False, 35 | cache_graphs: bool = False, 36 | ): 37 | self.verbose = verbose 38 | self.use_old_ranking = use_old_ranking 39 | 40 | if not root: 41 | root = os.getcwd() 42 | self.root = root 43 | 44 | self.max_map_tokens = map_tokens 45 | 46 | self.llm_name = llm_name 47 | self.repo_content_prefix = repo_content_prefix 48 | self.file_group = file_group 49 | self.code_renderer = RenderCode() 50 | self.tag_graphs = {} if cache_graphs else None 51 | 52 | def tokenizer(self, text): 53 | return litellm.encode(model=self.llm_name, text=text) 54 | 55 | def token_count(self, messages): 56 | if not self.tokenizer: 57 | return 58 | 59 | if type(messages) is str: 60 | llm_input = messages 61 | else: 62 | llm_input = json.dumps(messages) 63 | 64 | return len(self.tokenizer(llm_input)) 65 | 66 | def get_repo_map( 67 | self, 68 | args: RepoMapArgs, 69 | ): 70 | if self.max_map_tokens <= 0: 71 | return 72 | if not args.other_fnames: 73 | return 74 | 75 | try: 76 | ranked_tags = self.get_ranked_tags(args=args) 77 | files_listing = self.find_best_tag_tree(ranked_tags=ranked_tags) 78 | except RecursionError: 79 | logger.error("Disabling repo map, git repo too large?") 80 | self.max_map_tokens = 0 81 | return 82 | 83 | if not files_listing: 84 | return 85 | 86 | num_tokens = self.token_count(files_listing) 87 | if self.verbose: 88 | logger.info(f"Repo-map: {num_tokens/1024:.1f} k-tokens") 89 | 90 | if self.repo_content_prefix and args.add_prefix: 91 | repo_content = self.repo_content_prefix 92 | else: 93 | repo_content = "" 94 | 95 | repo_content += files_listing 96 | return repo_content 97 | 98 | def get_tag_graph( 99 | self, abs_fnames: List[str] | None = None, with_tests: bool = False 100 | ) -> TagGraph: 101 | if not abs_fnames: 102 | abs_fnames = self.file_group.get_all_filenames(with_tests=with_tests) 103 | clean_fnames = self.file_group.validate_fnames(abs_fnames, with_tests=with_tests) 104 | 105 | if self.tag_graphs is not None: 106 | for files, graph in self.tag_graphs.items(): 107 | if not set(clean_fnames).difference(set(files)): 108 | return graph 109 | 110 | # If no caching or cached graph not found, construct it 111 | all_tags = [] 112 | code_map = {} 113 | for fname in clean_fnames: 114 | code, tags = self.tags_from_filename(fname) 115 | all_tags += tags 116 | code_map[fname] = code 117 | 118 | raw_graph = build_tag_graph(all_tags, code_map) 119 | graph = only_defs(raw_graph) 120 | 121 | if self.tag_graphs is not None: 122 | self.tag_graphs[tuple(clean_fnames)] = graph 123 | return graph 124 | 125 | def tags_from_filename(self, fname): 126 | def get_tags_raw_function(fname): 127 | code = read_text(fname) 128 | rel_fname = self.file_group.get_rel_fname(fname) 129 | data = get_tags_raw(fname, rel_fname, code) 130 | assert isinstance(data, list) 131 | return code, data 132 | 133 | return self.file_group.cached_function_call(fname, get_tags_raw_function) 134 | 135 | def get_ranked_tags( 136 | self, 137 | args: RepoMapArgs, 138 | ): 139 | 140 | # Check file names for validity 141 | fnames = sorted(set(args.chat_fnames).union(set(args.other_fnames))) 142 | cleaned = self.file_group.validate_fnames(fnames) 143 | 144 | # All the source code parsing happens here 145 | tag_graph = self.get_tag_graph(cleaned) 146 | self.code_renderer.code_map = tag_graph.code_renderer.code_map 147 | 148 | tags = list(tag_graph.nodes) 149 | 150 | if self.use_old_ranking: 151 | other_rel_fnames = [self.file_group.get_rel_fname(fname) for fname in args.other_fnames] 152 | ranked_tags = rank_tags(tags, args=args, other_rel_fnames=other_rel_fnames) 153 | else: 154 | ranked_tags = rank_tags_new( 155 | tag_graph, 156 | args=args, 157 | ) 158 | 159 | return ranked_tags 160 | 161 | def find_best_tag_tree( 162 | self, 163 | ranked_tags: list, 164 | ): 165 | """Does a binary search over the number of tags to include in the map, 166 | to find the largest map that fits within the token limit. 167 | """ 168 | num_tags = len(ranked_tags) 169 | lower_bound = 0 170 | upper_bound = num_tags 171 | best_tree = None 172 | best_tree_tokens = 0 173 | 174 | # Guess a small starting number to help with giant repos 175 | middle = min(self.max_map_tokens // 25, num_tags) 176 | 177 | while lower_bound <= upper_bound: 178 | used_tags = [tag for tag in ranked_tags[:middle]] 179 | tree = self.code_renderer.to_tree(used_tags) 180 | num_tokens = self.token_count(tree) 181 | 182 | if self.max_map_tokens > num_tokens > best_tree_tokens: 183 | best_tree = tree 184 | best_tree_tokens = num_tokens 185 | 186 | if num_tokens < self.max_map_tokens: 187 | lower_bound = middle + 1 188 | else: 189 | upper_bound = middle - 1 190 | 191 | middle = (lower_bound + upper_bound) // 2 192 | 193 | return best_tree 194 | 195 | def repo_map_from_message( 196 | self, 197 | message: str, 198 | mentioned_entities: Set[str] | None = None, 199 | add_prefix: bool = False, 200 | llm=None, 201 | ) -> str: 202 | all_files = self.file_group.get_all_filenames() 203 | added_files = self.file_group.files_for_modification 204 | other_files = set(all_files) - set(added_files) 205 | 206 | if llm is not None: 207 | search_terms = search_terms_from_message(message, llm) 208 | else: 209 | search_terms = set() 210 | 211 | mentioned_fnames = self.file_group.get_file_mentions(message) 212 | mentioned_idents = get_ident_mentions(message) 213 | 214 | all_rel_fnames = [self.file_group.get_rel_fname(f) for f in all_files] 215 | mentioned_fnames.update(get_ident_filename_matches(mentioned_idents, all_rel_fnames)) 216 | 217 | args = RepoMapArgs( 218 | chat_fnames=added_files or {}, 219 | other_fnames=other_files or {}, 220 | mentioned_fnames=mentioned_fnames or {}, 221 | mentioned_idents=mentioned_idents or {}, 222 | mentioned_entities=mentioned_entities or {}, 223 | search_terms=search_terms, 224 | add_prefix=add_prefix, 225 | ) 226 | 227 | repo_content = self.get_repo_map(args) 228 | 229 | # fall back to global repo map if files in chat are disjoint from rest of repo 230 | if not repo_content: 231 | args.chat_fnames = set() 232 | args.other_fnames = set(all_files) 233 | 234 | repo_content = self.get_repo_map(args) 235 | 236 | # fall back to completely unhinted repo 237 | if not repo_content: 238 | args = RepoMapArgs(search_terms=search_terms, add_prefix=add_prefix) 239 | 240 | repo_content = self.get_repo_map(args) 241 | 242 | return repo_content 243 | 244 | 245 | def search_terms_from_message(message: str, llm) -> Set[str]: 246 | search_prompt = f"""You are an expert bug fixer. You are given a bug report. 247 | Return a JSON list of at most 10 strings extracted from the bug report, that should be used 248 | in a full-text search of the codebase to find the part of the code that needs to be modified. 249 | Select at most 10 strings that are most likely to be unique to the part of the code that needs to be modified. 250 | ONLY extract strings that you could expect to find verbatim in the code, especially function names, 251 | class names, and error messages. 252 | For method calls, such as `foo.bar()`, extract `.bar(` 253 | 254 | For error messages, extract the bits of the error message that are likely to be found VERBATIM in the code, 255 | for example "File not found: " rather than "File not found: /amger/gae/doc.tcx"; 256 | return "A string is required" rather than "A string is required, not 'MyWeirdClassName'". 257 | 258 | Here is the problem description: 259 | {message}""" 260 | 261 | class ListOfStrings(BaseModel): 262 | strings: List[str] = Field( 263 | description="List of full-text search strings to find the part of the code that needs to be modified." 264 | ) 265 | 266 | out = llm.with_structured_output(ListOfStrings).invoke(search_prompt) 267 | re_out = [x.split(".")[-1] for x in out.strings] 268 | re_out = sum([x.split(",") for x in re_out], []) 269 | return set(re_out) 270 | -------------------------------------------------------------------------------- /motleycoder/codemap/tag.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Tag: 6 | rel_fname: str 7 | line: int 8 | end_line: int 9 | name: str 10 | kind: str 11 | docstring: str 12 | fname: str 13 | text: str 14 | byte_range: tuple[int, int] 15 | parent_names: tuple[str, ...] = () 16 | language: str | None = None 17 | n_defs: int = 0 18 | 19 | @property 20 | def full_name(self): 21 | if self.kind == "ref": 22 | return self.name 23 | else: 24 | return tuple(list(self.parent_names) + [self.name]) 25 | 26 | def to_tuple(self): 27 | return ( 28 | self.rel_fname, 29 | self.line, 30 | self.name, 31 | self.kind, 32 | self.docstring, 33 | self.fname, 34 | self.text, 35 | self.byte_range, 36 | self.parent_names, 37 | ) 38 | 39 | def __getitem__(self, item): 40 | return self.to_tuple()[item] 41 | 42 | def __len__(self): 43 | return len(self.to_tuple()) 44 | 45 | def __hash__(self): 46 | return hash(self.to_tuple()) 47 | -------------------------------------------------------------------------------- /motleycoder/linter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import subprocess 4 | import sys 5 | import traceback 6 | import warnings 7 | from dataclasses import dataclass 8 | from pathlib import Path 9 | 10 | from grep_ast import TreeContext, filename_to_lang 11 | from tree_sitter_languages import get_parser # noqa: E402 12 | 13 | # tree_sitter is throwing a FutureWarning 14 | warnings.simplefilter("ignore", category=FutureWarning) 15 | 16 | 17 | class Linter: 18 | def __init__(self, encoding="utf-8", root=None): 19 | self.encoding = encoding 20 | self.root = root 21 | 22 | self.languages = dict( 23 | python=self.py_lint, 24 | ) 25 | self.all_lint_cmd = None 26 | 27 | def set_linter(self, lang, cmd): 28 | if lang: 29 | self.languages[lang] = cmd 30 | return 31 | 32 | self.all_lint_cmd = cmd 33 | 34 | def get_rel_fname(self, fname): 35 | if self.root: 36 | return os.path.relpath(fname, self.root) 37 | else: 38 | return fname 39 | 40 | def run_cmd(self, cmd, rel_fname, code): 41 | cmd += " " + rel_fname 42 | cmd = cmd.split() 43 | 44 | process = subprocess.Popen( 45 | cmd, cwd=self.root, stdout=subprocess.PIPE, stderr=subprocess.STDOUT 46 | ) 47 | stdout, _ = process.communicate() 48 | errors = stdout.decode() 49 | if process.returncode == 0: 50 | return # zero exit status 51 | 52 | cmd = " ".join(cmd) 53 | res = f"## Running: {cmd}\n\n" 54 | res += errors 55 | 56 | return self.errors_to_lint_result(rel_fname, res) 57 | 58 | def errors_to_lint_result(self, rel_fname, errors): 59 | if not errors: 60 | return 61 | 62 | linenums = [] 63 | filenames_linenums = find_filenames_and_linenums(errors, [rel_fname]) 64 | if filenames_linenums: 65 | filename, linenums = next(iter(filenames_linenums.items())) 66 | linenums = [num - 1 for num in linenums] 67 | 68 | return LintResult(text=errors, lines=linenums) 69 | 70 | def lint(self, fname, cmd=None): 71 | rel_fname = self.get_rel_fname(fname) 72 | code = Path(fname).read_text(self.encoding) 73 | 74 | if cmd: 75 | cmd = cmd.strip() 76 | if not cmd: 77 | lang = filename_to_lang(fname) 78 | if not lang: 79 | return 80 | if self.all_lint_cmd: 81 | cmd = self.all_lint_cmd 82 | else: 83 | cmd = self.languages.get(lang) 84 | 85 | if callable(cmd): 86 | linkres = cmd(fname, rel_fname, code) 87 | elif cmd: 88 | linkres = self.run_cmd(cmd, rel_fname, code) 89 | else: 90 | linkres = basic_lint(rel_fname, code) 91 | 92 | if not linkres: 93 | return 94 | 95 | res = "# Edits applied. Now, fix any errors below, if possible.\n\n" 96 | res += linkres.text 97 | res += "\n" 98 | res += tree_context(rel_fname, code, linkres.lines) 99 | 100 | return res 101 | 102 | def py_lint(self, fname, rel_fname, code): 103 | basic_res = basic_lint(rel_fname, code) 104 | compile_res = lint_python_compile(fname, code) 105 | flake_res = self.flake8_lint(rel_fname) 106 | 107 | text = "" 108 | lines = set() 109 | for res in [basic_res, compile_res, flake_res]: 110 | if not res: 111 | continue 112 | if text: 113 | text += "\n" 114 | text += res.text 115 | lines.update(res.lines) 116 | 117 | if text or lines: 118 | return LintResult(text, lines) 119 | 120 | def flake8_lint(self, rel_fname): 121 | fatal = "E9,F821,F823,F831,F406,F407,F701,F702,F704,F706" 122 | flake8_cmd = [ 123 | sys.executable, 124 | "-m", 125 | "flake8", 126 | f"--select={fatal}", 127 | "--show-source", 128 | "--isolated", 129 | rel_fname, 130 | ] 131 | 132 | text = f"## Running: {' '.join(flake8_cmd)}\n\n" 133 | 134 | try: 135 | result = subprocess.run( 136 | flake8_cmd, 137 | cwd=self.root, 138 | capture_output=True, 139 | text=True, 140 | check=False, 141 | ) 142 | errors = result.stdout + result.stderr 143 | except Exception as e: 144 | errors = f"Error running flake8: {str(e)}" 145 | 146 | if not errors: 147 | return 148 | 149 | text += errors 150 | return self.errors_to_lint_result(rel_fname, text) 151 | 152 | 153 | @dataclass 154 | class LintResult: 155 | text: str 156 | lines: list 157 | 158 | 159 | def lint_python_compile(fname, code): 160 | try: 161 | compile(code, fname, "exec") # USE TRACEBACK BELOW HERE 162 | return 163 | except Exception as err: 164 | end_lineno = getattr(err, "end_lineno", err.lineno) 165 | line_numbers = list(range(err.lineno - 1, end_lineno)) 166 | 167 | tb_lines = traceback.format_exception(type(err), err, err.__traceback__) 168 | last_file_i = 0 169 | 170 | target = "# USE TRACEBACK" 171 | target += " BELOW HERE" 172 | for i in range(len(tb_lines)): 173 | if target in tb_lines[i]: 174 | last_file_i = i 175 | break 176 | 177 | tb_lines = tb_lines[:1] + tb_lines[last_file_i + 1 :] 178 | 179 | res = "".join(tb_lines) 180 | return LintResult(text=res, lines=line_numbers) 181 | 182 | 183 | def basic_lint(fname, code): 184 | """ 185 | Use tree-sitter to look for syntax errors, display them with tree context. 186 | """ 187 | 188 | lang = filename_to_lang(fname) 189 | if not lang: 190 | return 191 | 192 | parser = get_parser(lang) 193 | tree = parser.parse(bytes(code, "utf-8")) 194 | 195 | errors = traverse_tree(tree.root_node) 196 | if not errors: 197 | return 198 | 199 | return LintResult(text="", lines=errors) 200 | 201 | 202 | def tree_context(fname, code, line_nums): 203 | context = TreeContext( 204 | fname, 205 | code, 206 | color=False, 207 | line_number=True, 208 | child_context=False, 209 | last_line=False, 210 | margin=0, 211 | mark_lois=True, 212 | loi_pad=3, 213 | # header_max=30, 214 | show_top_of_file_parent_scope=False, 215 | ) 216 | line_nums = set(line_nums) 217 | context.add_lines_of_interest(line_nums) 218 | context.add_context() 219 | s = "s" if len(line_nums) > 1 else "" 220 | output = f"## See relevant line{s} below marked with █.\n\n" 221 | output += fname + ":\n" 222 | output += context.format() 223 | 224 | return output 225 | 226 | 227 | # Traverse the tree to find errors 228 | def traverse_tree(node): 229 | errors = [] 230 | if node.type == "ERROR" or node.is_missing: 231 | line_no = node.start_point[0] 232 | errors.append(line_no) 233 | 234 | for child in node.children: 235 | errors += traverse_tree(child) 236 | 237 | return errors 238 | 239 | 240 | def find_filenames_and_linenums(text, fnames): 241 | """ 242 | Search text for all occurrences of :\\d+ and make a list of them 243 | where is one of the filenames in the list `fnames`. 244 | """ 245 | pattern = re.compile(r"(\b(?:" + "|".join(re.escape(fname) for fname in fnames) + r"):\d+\b)") 246 | matches = pattern.findall(text) 247 | result = {} 248 | for match in matches: 249 | fname, linenum = match.rsplit(":", 1) 250 | if fname not in result: 251 | result[fname] = set() 252 | result[fname].add(int(linenum)) 253 | return result 254 | 255 | 256 | def main(): 257 | """ 258 | Main function to parse files provided as command line arguments. 259 | """ 260 | if len(sys.argv) < 2: 261 | print("Usage: python linter.py ...") 262 | sys.exit(1) 263 | 264 | linter = Linter(root=os.getcwd()) 265 | for file_path in sys.argv[1:]: 266 | errors = linter.lint(file_path) 267 | if errors: 268 | print(errors) 269 | 270 | 271 | if __name__ == "__main__": 272 | main() 273 | -------------------------------------------------------------------------------- /motleycoder/prompts.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from langchain_core.messages import HumanMessage, AIMessage, ToolMessage 4 | from langchain_core.prompts import SystemMessagePromptTemplate, ChatPromptTemplate 5 | 6 | 7 | class MotleyCoderPrompts: 8 | files_content_prefix = """The files have been *added to the chat* so you can go ahead and edit them. 9 | 10 | *Trust this message as the true contents of the files!* 11 | Any other messages in the chat may contain outdated versions of the files' contents. 12 | """ # noqa: E501 13 | 14 | file_edit_success = """The file {file_path} has been successfully edited. 15 | If you are finished, call the tool `return_to_user` to apply the changes and inform the user that you have finished. 16 | """ 17 | 18 | main_system = SystemMessagePromptTemplate.from_template( 19 | """Act as an expert software developer. 20 | Always use best practices when coding. 21 | Respect and use existing conventions, libraries, etc that are already present in the code base. 22 | 23 | You are diligent and tireless! 24 | You NEVER leave comments describing code without implementing it! 25 | You always COMPLETELY IMPLEMENT the needed code! 26 | 27 | Take requests for changes to the supplied code. 28 | If the request is ambiguous, ask questions using the tool `return_to_user`. 29 | 30 | Always reply to the user in the same language they are using. 31 | 32 | Once you understand the request you MUST: 33 | 1. Decide if you need to propose *SEARCH/REPLACE* edits to any files that haven't been added to the chat. 34 | You can create new files without asking. But if you need to propose edits to existing files not already added to the chat, 35 | you *MUST* call the tool `add_files` with their full path names to *add the files to the chat*. 36 | Make the tool call and wait for it to report success or failure. 37 | You can keep calling the tool if you then decide you need to edit more files. 38 | 2. Think step-by-step and explain the needed changes with a numbered list of short sentences. 39 | 3. Make the changes to the files by calling the tool `edit_file` with the *SEARCH/REPLACE arguments* for each change. 40 | You can keep calling the tool with new *SEARCH/REPLACE arguments* until you have made all the necessary changes. 41 | ONLY EVER RETURN CODE IN THE ARGUMENTS OF THE `edit_file` TOOL CALL! 42 | 4. After making all the necessary changes, you MUST call the tool `return_to_user` to apply the changes and to inform 43 | the user that you have finished. You can't call any tools after this step. 44 | 45 | You have access to the following tools: 46 | {tools} 47 | 48 | All changes to files must be made using the `edit_file` tool. 49 | """ 50 | ) 51 | 52 | example_messages = [ 53 | HumanMessage( 54 | content="Change get_factorial() to use math.factorial", 55 | ), 56 | AIMessage( 57 | content="""To make this change we need to modify `mathweb/flask/app.py` to: 58 | 59 | 1. Import the math package. 60 | 2. Remove the existing factorial() function. 61 | 3. Update get_factorial() to call math.factorial instead. 62 | """, 63 | additional_kwargs={ 64 | "tool_calls": [ 65 | { 66 | "index": 0, 67 | "id": "call_aSuMulBd6JVrHCMjyKSi93na", 68 | "function": { 69 | "arguments": json.dumps( 70 | dict( 71 | file_path="mathweb/flask/app.py", 72 | language="python", 73 | search="""from flask import Flask 74 | """, 75 | replace="""import math 76 | from flask import Flask 77 | """, 78 | ) 79 | ), 80 | "name": "edit_file", 81 | }, 82 | "type": "function", 83 | }, 84 | { 85 | "index": 1, 86 | "id": "call_hJ8ff6hI7dkNhFTy7IhqWlm9", 87 | "function": { 88 | "arguments": json.dumps( 89 | dict( 90 | file_path="mathweb/flask/app.py", 91 | language="python", 92 | search="""def factorial(n): 93 | "compute factorial" 94 | 95 | if n == 0: 96 | return 1 97 | else: 98 | return n * factorial(n-1) 99 | 100 | """, 101 | replace="", 102 | ) 103 | ), 104 | "name": "edit_file", 105 | }, 106 | "type": "function", 107 | }, 108 | { 109 | "index": 2, 110 | "id": "call_8b9slOlKS84JfuUi320KgLnA", 111 | "function": { 112 | "arguments": json.dumps( 113 | dict( 114 | file_path="mathweb/flask/app.py", 115 | language="python", 116 | search=""" return str(factorial(n)) 117 | """, 118 | replace=""" return str(math.factorial(n)) 119 | """, 120 | ) 121 | ), 122 | "name": "edit_file", 123 | }, 124 | "type": "function", 125 | }, 126 | ] 127 | }, 128 | ), 129 | ToolMessage( 130 | content=file_edit_success.format(file_path="mathweb/flask/app.py"), 131 | tool_call_id="call_aSuMulBd6JVrHCMjyKSi93na", 132 | ), 133 | ToolMessage( 134 | content=file_edit_success.format(file_path="mathweb/flask/app.py"), 135 | tool_call_id="call_hJ8ff6hI7dkNhFTy7IhqWlm9", 136 | ), 137 | ToolMessage( 138 | content=file_edit_success.format(file_path="mathweb/flask/app.py"), 139 | tool_call_id="call_8b9slOlKS84JfuUi320KgLnA", 140 | ), 141 | AIMessage( 142 | content="""I have now made all the necessary changes. 143 | I can now apply the changes to the code base. 144 | """, 145 | additional_kwargs={ 146 | "tool_calls": [ 147 | { 148 | "index": 0, 149 | "id": "call_fkr8403nFUv73kc9mN547IVm", 150 | "function": { 151 | "arguments": json.dumps(dict()), 152 | "name": "return_to_user", 153 | }, 154 | "type": "function", 155 | }, 156 | ] 157 | }, 158 | ), 159 | ToolMessage( 160 | content="Changes applied successfully.", 161 | tool_call_id="call_fkr8403nFUv73kc9mN547IVm", 162 | ), 163 | HumanMessage( 164 | content="Refactor hello() into its own file.", 165 | ), 166 | AIMessage( 167 | content="""To make this change we need to modify `main.py` and make a new file `hello.py`: 168 | 169 | 1. Make a new hello.py file with hello() in it. 170 | 2. Remove hello() from main.py and replace it with an import. 171 | """, 172 | additional_kwargs={ 173 | "tool_calls": [ 174 | { 175 | "index": 0, 176 | "id": "call_4lOYq2sR4ZCb22p2xUuq5igP", 177 | "function": { 178 | "arguments": json.dumps(dict(files=["hello.py"])), 179 | "name": "add_files", 180 | }, 181 | "type": "function", 182 | }, 183 | ] 184 | }, 185 | ), 186 | ToolMessage( 187 | content=files_content_prefix 188 | + """ 189 | 190 | hello.py 191 | ``` 192 | 193 | ``` 194 | """, 195 | tool_call_id="call_4lOYq2sR4ZCb22p2xUuq5igP", 196 | ), 197 | AIMessage( 198 | content="", 199 | additional_kwargs={ 200 | "tool_calls": [ 201 | { 202 | "index": 0, 203 | "id": "call_ld94UGfj3fkJn85KEk8vjDoK", 204 | "function": { 205 | "arguments": json.dumps( 206 | dict( 207 | file_path="hello.py", 208 | language="python", 209 | search="", 210 | replace="""def hello(): 211 | '''print a greeting''' 212 | 213 | print("hello") 214 | """, 215 | ) 216 | ), 217 | "name": "edit_file", 218 | }, 219 | "type": "function", 220 | }, 221 | { 222 | "index": 1, 223 | "id": "call_FgodUjbFFGAQ6x9mUeNU3TTR", 224 | "function": { 225 | "arguments": json.dumps( 226 | dict( 227 | file_path="main.py", 228 | language="python", 229 | search="""def hello(): 230 | '''print a greeting''' 231 | 232 | print("hello") 233 | """, 234 | replace="""from hello import hello 235 | """, 236 | ) 237 | ), 238 | "name": "edit_file", 239 | }, 240 | "type": "function", 241 | }, 242 | ] 243 | }, 244 | ), 245 | ToolMessage( 246 | content=file_edit_success.format(file_path="hello.py"), 247 | tool_call_id="call_ld94UGfj3fkJn85KEk8vjDoK", 248 | ), 249 | ToolMessage( 250 | content=file_edit_success.format(file_path="main.py"), 251 | tool_call_id="call_FgodUjbFFGAQ6x9mUeNU3TTR", 252 | ), 253 | AIMessage( 254 | content="""I have now made all the necessary changes. 255 | I can now apply the changes to the code base. 256 | """, 257 | additional_kwargs={ 258 | "tool_calls": [ 259 | { 260 | "index": 0, 261 | "id": "call_7TmRhiBSX5ud8DW1RyZSEcDf", 262 | "function": { 263 | "arguments": json.dumps(dict()), 264 | "name": "return_to_user", 265 | }, 266 | "type": "function", 267 | }, 268 | ] 269 | }, 270 | ), 271 | ToolMessage( 272 | content="Changes applied successfully.", 273 | tool_call_id="call_7TmRhiBSX5ud8DW1RyZSEcDf", 274 | ), 275 | ] 276 | 277 | system_reminder = SystemMessagePromptTemplate.from_template( 278 | """# `edit_file` tool call Rules: 279 | 280 | Every *SEARCH* argument must *EXACTLY MATCH* the existing source code, character for character, including all comments, docstrings, etc. 281 | 282 | `edit_file` tool will replace *all* matching occurrences. 283 | Include enough lines to make the SEARCH blocks unique. 284 | 285 | Include *ALL* the code being searched and replaced! 286 | 287 | Only call `edit_file` for files that the user has added to the chat! 288 | 289 | To move code within a file, use 2 `edit_file` calls: 1 to delete it from its current location, 1 to insert it in the new location. 290 | 291 | If you want to put code in a new file, call the `edit_file` tool with: 292 | - A new file path, including dir name if needed 293 | - An empty `SEARCH` argument 294 | - The new file's contents in the `REPLACE` argument 295 | 296 | You are diligent and tireless! 297 | You NEVER leave comments describing code without implementing it! 298 | You always COMPLETELY IMPLEMENT the needed code! 299 | 300 | ONLY EVER RETURN CODE IN THE ARGUMENTS OF THE `edit_file` TOOL CALL! 301 | 302 | You have access to the following tools: 303 | {tools} 304 | 305 | All changes to files must be made using the `edit_file` tool. 306 | """ 307 | ) 308 | 309 | def __init__(self): 310 | messages = [self.main_system] 311 | 312 | messages += self.example_messages 313 | messages += [ 314 | HumanMessage( 315 | "I switched to a new code base. Please don't consider the above files" 316 | " or try to edit them any longer." 317 | ), 318 | AIMessage("Ok."), 319 | ] 320 | 321 | messages += [self.system_reminder] 322 | 323 | self.prompt_template = ChatPromptTemplate.from_messages(messages) 324 | -------------------------------------------------------------------------------- /motleycoder/queries/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Credits 3 | 4 | Aider uses modified versions of the tags.scm files from these open source 5 | tree-sitter language implementations: 6 | 7 | * [https://github.com/tree-sitter/tree-sitter-c](https://github.com/tree-sitter/tree-sitter-c) — licensed under the MIT License. 8 | * [https://github.com/tree-sitter/tree-sitter-c-sharp](https://github.com/tree-sitter/tree-sitter-c-sharp) — licensed under the MIT License. 9 | * [https://github.com/tree-sitter/tree-sitter-cpp](https://github.com/tree-sitter/tree-sitter-cpp) — licensed under the MIT License. 10 | * [https://github.com/Wilfred/tree-sitter-elisp](https://github.com/Wilfred/tree-sitter-elisp) — licensed under the MIT License. 11 | * [https://github.com/elixir-lang/tree-sitter-elixir](https://github.com/elixir-lang/tree-sitter-elixir) — licensed under the Apache License, Version 2.0. 12 | * [https://github.com/elm-tooling/tree-sitter-elm](https://github.com/elm-tooling/tree-sitter-elm) — licensed under the MIT License. 13 | * [https://github.com/tree-sitter/tree-sitter-go](https://github.com/tree-sitter/tree-sitter-go) — licensed under the MIT License. 14 | * [https://github.com/tree-sitter/tree-sitter-java](https://github.com/tree-sitter/tree-sitter-java) — licensed under the MIT License. 15 | * [https://github.com/tree-sitter/tree-sitter-javascript](https://github.com/tree-sitter/tree-sitter-javascript) — licensed under the MIT License. 16 | * [https://github.com/tree-sitter/tree-sitter-ocaml](https://github.com/tree-sitter/tree-sitter-ocaml) — licensed under the MIT License. 17 | * [https://github.com/tree-sitter/tree-sitter-php](https://github.com/tree-sitter/tree-sitter-php) — licensed under the MIT License. 18 | * [https://github.com/tree-sitter/tree-sitter-python](https://github.com/tree-sitter/tree-sitter-python) — licensed under the MIT License. 19 | * [https://github.com/tree-sitter/tree-sitter-ql](https://github.com/tree-sitter/tree-sitter-ql) — licensed under the MIT License. 20 | * [https://github.com/r-lib/tree-sitter-r](https://github.com/r-lib/tree-sitter-r) — licensed under the MIT License. 21 | * [https://github.com/tree-sitter/tree-sitter-ruby](https://github.com/tree-sitter/tree-sitter-ruby) — licensed under the MIT License. 22 | * [https://github.com/tree-sitter/tree-sitter-rust](https://github.com/tree-sitter/tree-sitter-rust) — licensed under the MIT License. 23 | * [https://github.com/tree-sitter/tree-sitter-typescript](https://github.com/tree-sitter/tree-sitter-typescript) — licensed under the MIT License. 24 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-c-tags.scm: -------------------------------------------------------------------------------- 1 | (struct_specifier name: (type_identifier) @name.definition.class body:(_)) @definition.class 2 | 3 | (declaration type: (union_specifier name: (type_identifier) @name.definition.class)) @definition.class 4 | 5 | (function_declarator declarator: (identifier) @name.definition.function) @definition.function 6 | 7 | (type_definition declarator: (type_identifier) @name.definition.type) @definition.type 8 | 9 | (enum_specifier name: (type_identifier) @name.definition.type) @definition.type 10 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-c_sharp-tags.scm: -------------------------------------------------------------------------------- 1 | (class_declaration 2 | name: (identifier) @name.definition.class 3 | ) @definition.class 4 | 5 | (class_declaration 6 | bases: (base_list (_) @name.reference.class) 7 | ) @reference.class 8 | 9 | (interface_declaration 10 | name: (identifier) @name.definition.interface 11 | ) @definition.interface 12 | 13 | (interface_declaration 14 | bases: (base_list (_) @name.reference.interface) 15 | ) @reference.interface 16 | 17 | (method_declaration 18 | name: (identifier) @name.definition.method 19 | ) @definition.method 20 | 21 | (object_creation_expression 22 | type: (identifier) @name.reference.class 23 | ) @reference.class 24 | 25 | (type_parameter_constraints_clause 26 | target: (identifier) @name.reference.class 27 | ) @reference.class 28 | 29 | (type_constraint 30 | type: (identifier) @name.reference.class 31 | ) @reference.class 32 | 33 | (variable_declaration 34 | type: (identifier) @name.reference.class 35 | ) @reference.class 36 | 37 | (invocation_expression 38 | function: 39 | (member_access_expression 40 | name: (identifier) @name.reference.send 41 | ) 42 | ) @reference.send 43 | 44 | (namespace_declaration 45 | name: (identifier) @name.definition.module 46 | ) @definition.module 47 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-cpp-tags.scm: -------------------------------------------------------------------------------- 1 | (struct_specifier name: (type_identifier) @name.definition.class body:(_)) @definition.class 2 | 3 | (declaration type: (union_specifier name: (type_identifier) @name.definition.class)) @definition.class 4 | 5 | (function_declarator declarator: (identifier) @name.definition.function) @definition.function 6 | 7 | (function_declarator declarator: (field_identifier) @name.definition.function) @definition.function 8 | 9 | (function_declarator declarator: (qualified_identifier scope: (namespace_identifier) @scope name: (identifier) @name.definition.method)) @definition.method 10 | 11 | (type_definition declarator: (type_identifier) @name.definition.type) @definition.type 12 | 13 | (enum_specifier name: (type_identifier) @name.definition.type) @definition.type 14 | 15 | (class_specifier name: (type_identifier) @name.definition.class) @definition.class 16 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-elisp-tags.scm: -------------------------------------------------------------------------------- 1 | ;; defun/defsubst 2 | (function_definition name: (symbol) @name.definition.function) @definition.function 3 | 4 | ;; Treat macros as function definitions for the sake of TAGS. 5 | (macro_definition name: (symbol) @name.definition.function) @definition.function 6 | 7 | ;; Match function calls 8 | (list (symbol) @name.reference.function) @reference.function 9 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-elixir-tags.scm: -------------------------------------------------------------------------------- 1 | ; Definitions 2 | 3 | ; * modules and protocols 4 | (call 5 | target: (identifier) @ignore 6 | (arguments (alias) @name.definition.module) 7 | (#match? @ignore "^(defmodule|defprotocol)$")) @definition.module 8 | 9 | ; * functions/macros 10 | (call 11 | target: (identifier) @ignore 12 | (arguments 13 | [ 14 | ; zero-arity functions with no parentheses 15 | (identifier) @name.definition.function 16 | ; regular function clause 17 | (call target: (identifier) @name.definition.function) 18 | ; function clause with a guard clause 19 | (binary_operator 20 | left: (call target: (identifier) @name.definition.function) 21 | operator: "when") 22 | ]) 23 | (#match? @ignore "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @definition.function 24 | 25 | ; References 26 | 27 | ; ignore calls to kernel/special-forms keywords 28 | (call 29 | target: (identifier) @ignore 30 | (#match? @ignore "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp|defmodule|defprotocol|defimpl|defstruct|defexception|defoverridable|alias|case|cond|else|for|if|import|quote|raise|receive|require|reraise|super|throw|try|unless|unquote|unquote_splicing|use|with)$")) 31 | 32 | ; ignore module attributes 33 | (unary_operator 34 | operator: "@" 35 | operand: (call 36 | target: (identifier) @ignore)) 37 | 38 | ; * function call 39 | (call 40 | target: [ 41 | ; local 42 | (identifier) @name.reference.call 43 | ; remote 44 | (dot 45 | right: (identifier) @name.reference.call) 46 | ]) @reference.call 47 | 48 | ; * pipe into function call 49 | (binary_operator 50 | operator: "|>" 51 | right: (identifier) @name.reference.call) @reference.call 52 | 53 | ; * modules 54 | (alias) @name.reference.module @reference.module 55 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-elm-tags.scm: -------------------------------------------------------------------------------- 1 | (value_declaration (function_declaration_left (lower_case_identifier) @name.definition.function)) @definition.function 2 | 3 | (function_call_expr (value_expr (value_qid) @name.reference.function)) @reference.function 4 | (exposed_value (lower_case_identifier) @name.reference.function)) @reference.function 5 | (type_annotation ((lower_case_identifier) @name.reference.function) (colon)) @reference.function 6 | 7 | (type_declaration ((upper_case_identifier) @name.definition.type) ) @definition.type 8 | 9 | (type_ref (upper_case_qid (upper_case_identifier) @name.reference.type)) @reference.type 10 | (exposed_type (upper_case_identifier) @name.reference.type)) @reference.type 11 | 12 | (type_declaration (union_variant (upper_case_identifier) @name.definition.union)) @definition.union 13 | 14 | (value_expr (upper_case_qid (upper_case_identifier) @name.reference.union)) @reference.union 15 | 16 | 17 | (module_declaration 18 | (upper_case_qid (upper_case_identifier)) @name.definition.module 19 | ) @definition.module 20 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-go-tags.scm: -------------------------------------------------------------------------------- 1 | ( 2 | (comment)* @doc 3 | . 4 | (function_declaration 5 | name: (identifier) @name.definition.function) @definition.function 6 | (#strip! @doc "^//\\s*") 7 | (#set-adjacent! @doc @definition.function) 8 | ) 9 | 10 | ( 11 | (comment)* @doc 12 | . 13 | (method_declaration 14 | name: (field_identifier) @name.definition.method) @definition.method 15 | (#strip! @doc "^//\\s*") 16 | (#set-adjacent! @doc @definition.method) 17 | ) 18 | 19 | (call_expression 20 | function: [ 21 | (identifier) @name.reference.call 22 | (parenthesized_expression (identifier) @name.reference.call) 23 | (selector_expression field: (field_identifier) @name.reference.call) 24 | (parenthesized_expression (selector_expression field: (field_identifier) @name.reference.call)) 25 | ]) @reference.call 26 | 27 | (type_spec 28 | name: (type_identifier) @name.definition.type) @definition.type 29 | 30 | (type_identifier) @name.reference.type @reference.type 31 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-java-tags.scm: -------------------------------------------------------------------------------- 1 | (class_declaration 2 | name: (identifier) @name.definition.class) @definition.class 3 | 4 | (method_declaration 5 | name: (identifier) @name.definition.method) @definition.method 6 | 7 | (method_invocation 8 | name: (identifier) @name.reference.call 9 | arguments: (argument_list) @reference.call) 10 | 11 | (interface_declaration 12 | name: (identifier) @name.definition.interface) @definition.interface 13 | 14 | (type_list 15 | (type_identifier) @name.reference.implementation) @reference.implementation 16 | 17 | (object_creation_expression 18 | type: (type_identifier) @name.reference.class) @reference.class 19 | 20 | (superclass (type_identifier) @name.reference.class) @reference.class 21 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-javascript-tags.scm: -------------------------------------------------------------------------------- 1 | ( 2 | (comment)* @doc 3 | . 4 | (method_definition 5 | name: (property_identifier) @name.definition.method) @definition.method 6 | (#not-eq? @name.definition.method "constructor") 7 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") 8 | (#select-adjacent! @doc @definition.method) 9 | ) 10 | 11 | ( 12 | (comment)* @doc 13 | . 14 | [ 15 | (class 16 | name: (_) @name.definition.class) 17 | (class_declaration 18 | name: (_) @name.definition.class) 19 | ] @definition.class 20 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") 21 | (#select-adjacent! @doc @definition.class) 22 | ) 23 | 24 | ( 25 | (comment)* @doc 26 | . 27 | [ 28 | (function 29 | name: (identifier) @name.definition.function) 30 | (function_declaration 31 | name: (identifier) @name.definition.function) 32 | (generator_function 33 | name: (identifier) @name.definition.function) 34 | (generator_function_declaration 35 | name: (identifier) @name.definition.function) 36 | ] @definition.function 37 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") 38 | (#select-adjacent! @doc @definition.function) 39 | ) 40 | 41 | ( 42 | (comment)* @doc 43 | . 44 | (lexical_declaration 45 | (variable_declarator 46 | name: (identifier) @name.definition.function 47 | value: [(arrow_function) (function)]) @definition.function) 48 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") 49 | (#select-adjacent! @doc @definition.function) 50 | ) 51 | 52 | ( 53 | (comment)* @doc 54 | . 55 | (variable_declaration 56 | (variable_declarator 57 | name: (identifier) @name.definition.function 58 | value: [(arrow_function) (function)]) @definition.function) 59 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") 60 | (#select-adjacent! @doc @definition.function) 61 | ) 62 | 63 | (assignment_expression 64 | left: [ 65 | (identifier) @name.definition.function 66 | (member_expression 67 | property: (property_identifier) @name.definition.function) 68 | ] 69 | right: [(arrow_function) (function)] 70 | ) @definition.function 71 | 72 | (pair 73 | key: (property_identifier) @name.definition.function 74 | value: [(arrow_function) (function)]) @definition.function 75 | 76 | ( 77 | (call_expression 78 | function: (identifier) @name.reference.call) @reference.call 79 | (#not-match? @name.reference.call "^(require)$") 80 | ) 81 | 82 | (call_expression 83 | function: (member_expression 84 | property: (property_identifier) @name.reference.call) 85 | arguments: (_) @reference.call) 86 | 87 | (new_expression 88 | constructor: (_) @name.reference.class) @reference.class 89 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-ocaml-tags.scm: -------------------------------------------------------------------------------- 1 | ; Modules 2 | ;-------- 3 | 4 | ( 5 | (comment)? @doc . 6 | (module_definition (module_binding (module_name) @name.definition.module) @definition.module) 7 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") 8 | ) 9 | 10 | (module_path (module_name) @name.reference.module) @reference.module 11 | 12 | ; Modules types 13 | ;-------------- 14 | 15 | ( 16 | (comment)? @doc . 17 | (module_type_definition (module_type_name) @name.definition.interface) @definition.interface 18 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") 19 | ) 20 | 21 | (module_type_path (module_type_name) @name.reference.implementation) @reference.implementation 22 | 23 | ; Functions 24 | ;---------- 25 | 26 | ( 27 | (comment)? @doc . 28 | (value_definition 29 | [ 30 | (let_binding 31 | pattern: (value_name) @name.definition.function 32 | (parameter)) 33 | (let_binding 34 | pattern: (value_name) @name.definition.function 35 | body: [(fun_expression) (function_expression)]) 36 | ] @definition.function 37 | ) 38 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") 39 | ) 40 | 41 | ( 42 | (comment)? @doc . 43 | (external (value_name) @name.definition.function) @definition.function 44 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") 45 | ) 46 | 47 | (application_expression 48 | function: (value_path (value_name) @name.reference.call)) @reference.call 49 | 50 | (infix_expression 51 | left: (value_path (value_name) @name.reference.call) 52 | (infix_operator) @reference.call 53 | (#eq? @reference.call "@@")) 54 | 55 | (infix_expression 56 | (infix_operator) @reference.call 57 | right: (value_path (value_name) @name.reference.call) 58 | (#eq? @reference.call "|>")) 59 | 60 | ; Operator 61 | ;--------- 62 | 63 | ( 64 | (comment)? @doc . 65 | (value_definition 66 | (let_binding 67 | pattern: (parenthesized_operator [ 68 | (prefix_operator) 69 | (infix_operator) 70 | (hash_operator) 71 | (indexing_operator) 72 | (let_operator) 73 | (and_operator) 74 | (match_operator) 75 | ] @name.definition.function)) @definition.function) 76 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") 77 | ) 78 | 79 | [ 80 | (prefix_operator) 81 | (sign_operator) 82 | (infix_operator) 83 | (hash_operator) 84 | (indexing_operator) 85 | (let_operator) 86 | (and_operator) 87 | (match_operator) 88 | ] @name.reference.call @reference.call 89 | 90 | ; Classes 91 | ;-------- 92 | 93 | ( 94 | (comment)? @doc . 95 | [ 96 | (class_definition (class_binding (class_name) @name.definition.class) @definition.class) 97 | (class_type_definition (class_type_binding (class_type_name) @name.definition.class) @definition.class) 98 | ] 99 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") 100 | ) 101 | 102 | [ 103 | (class_path (class_name) @name.reference.class) 104 | (class_type_path (class_type_name) @name.reference.class) 105 | ] @reference.class 106 | 107 | ; Methods 108 | ;-------- 109 | 110 | ( 111 | (comment)? @doc . 112 | (method_definition (method_name) @name.definition.method) @definition.method 113 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") 114 | ) 115 | 116 | (method_invocation (method_name) @name.reference.call) @reference.call 117 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-php-tags.scm: -------------------------------------------------------------------------------- 1 | (class_declaration 2 | name: (name) @name.definition.class) @definition.class 3 | 4 | (function_definition 5 | name: (name) @name.definition.function) @definition.function 6 | 7 | (method_declaration 8 | name: (name) @name.definition.function) @definition.function 9 | 10 | (object_creation_expression 11 | [ 12 | (qualified_name (name) @name.reference.class) 13 | (variable_name (name) @name.reference.class) 14 | ]) @reference.class 15 | 16 | (function_call_expression 17 | function: [ 18 | (qualified_name (name) @name.reference.call) 19 | (variable_name (name)) @name.reference.call 20 | ]) @reference.call 21 | 22 | (scoped_call_expression 23 | name: (name) @name.reference.call) @reference.call 24 | 25 | (member_call_expression 26 | name: (name) @name.reference.call) @reference.call 27 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-python-tags.scm: -------------------------------------------------------------------------------- 1 | (class_definition 2 | name: (identifier) @name.definition.class) @definition.class 3 | 4 | (function_definition 5 | name: (identifier) @name.definition.function) @definition.function 6 | 7 | (call 8 | function: [ 9 | (identifier) @name.reference.call 10 | (attribute 11 | attribute: (identifier) @name.reference.call) 12 | ]) @reference.call 13 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-ql-tags.scm: -------------------------------------------------------------------------------- 1 | (classlessPredicate 2 | name: (predicateName) @name.definition.function) @definition.function 3 | 4 | (memberPredicate 5 | name: (predicateName) @name.definition.method) @definition.method 6 | 7 | (aritylessPredicateExpr 8 | name: (literalId) @name.reference.call) @reference.call 9 | 10 | (module 11 | name: (moduleName) @name.definition.module) @definition.module 12 | 13 | (dataclass 14 | name: (className) @name.definition.class) @definition.class 15 | 16 | (datatype 17 | name: (className) @name.definition.class) @definition.class 18 | 19 | (datatypeBranch 20 | name: (className) @name.definition.class) @definition.class 21 | 22 | (qualifiedRhs 23 | name: (predicateName) @name.reference.call) @reference.call 24 | 25 | (typeExpr 26 | name: (className) @name.reference.type) @reference.type 27 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-r-tags.scm: -------------------------------------------------------------------------------- 1 | (equals_assignment 2 | (identifier) @name 3 | (function_definition) @definition.function 4 | ) @definition.function 5 | 6 | (left_assignment 7 | (identifier) @name 8 | (function_definition) @definition.function 9 | ) @definition.function 10 | 11 | (call 12 | function: (identifier) @name 13 | ) @reference.call 14 | 15 | (call 16 | function: (binary 17 | left: (identifier) 18 | right: (identifier) @name 19 | ) 20 | ) @reference.call 21 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-ruby-tags.scm: -------------------------------------------------------------------------------- 1 | ; Method definitions 2 | 3 | ( 4 | (comment)* @doc 5 | . 6 | [ 7 | (method 8 | name: (_) @name.definition.method) @definition.method 9 | (singleton_method 10 | name: (_) @name.definition.method) @definition.method 11 | ] 12 | (#strip! @doc "^#\\s*") 13 | (#select-adjacent! @doc @definition.method) 14 | ) 15 | 16 | (alias 17 | name: (_) @name.definition.method) @definition.method 18 | 19 | (setter 20 | (identifier) @ignore) 21 | 22 | ; Class definitions 23 | 24 | ( 25 | (comment)* @doc 26 | . 27 | [ 28 | (class 29 | name: [ 30 | (constant) @name.definition.class 31 | (scope_resolution 32 | name: (_) @name.definition.class) 33 | ]) @definition.class 34 | (singleton_class 35 | value: [ 36 | (constant) @name.definition.class 37 | (scope_resolution 38 | name: (_) @name.definition.class) 39 | ]) @definition.class 40 | ] 41 | (#strip! @doc "^#\\s*") 42 | (#select-adjacent! @doc @definition.class) 43 | ) 44 | 45 | ; Module definitions 46 | 47 | ( 48 | (module 49 | name: [ 50 | (constant) @name.definition.module 51 | (scope_resolution 52 | name: (_) @name.definition.module) 53 | ]) @definition.module 54 | ) 55 | 56 | ; Calls 57 | 58 | (call method: (identifier) @name.reference.call) @reference.call 59 | 60 | ( 61 | [(identifier) (constant)] @name.reference.call @reference.call 62 | (#is-not? local) 63 | (#not-match? @name.reference.call "^(lambda|load|require|require_relative|__FILE__|__LINE__)$") 64 | ) 65 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-rust-tags.scm: -------------------------------------------------------------------------------- 1 | ; ADT definitions 2 | 3 | (struct_item 4 | name: (type_identifier) @name.definition.class) @definition.class 5 | 6 | (enum_item 7 | name: (type_identifier) @name.definition.class) @definition.class 8 | 9 | (union_item 10 | name: (type_identifier) @name.definition.class) @definition.class 11 | 12 | ; type aliases 13 | 14 | (type_item 15 | name: (type_identifier) @name.definition.class) @definition.class 16 | 17 | ; method definitions 18 | 19 | (declaration_list 20 | (function_item 21 | name: (identifier) @name.definition.method)) @definition.method 22 | 23 | ; function definitions 24 | 25 | (function_item 26 | name: (identifier) @name.definition.function) @definition.function 27 | 28 | ; trait definitions 29 | (trait_item 30 | name: (type_identifier) @name.definition.interface) @definition.interface 31 | 32 | ; module definitions 33 | (mod_item 34 | name: (identifier) @name.definition.module) @definition.module 35 | 36 | ; macro definitions 37 | 38 | (macro_definition 39 | name: (identifier) @name.definition.macro) @definition.macro 40 | 41 | ; references 42 | 43 | (call_expression 44 | function: (identifier) @name.reference.call) @reference.call 45 | 46 | (call_expression 47 | function: (field_expression 48 | field: (field_identifier) @name.reference.call)) @reference.call 49 | 50 | (macro_invocation 51 | macro: (identifier) @name.reference.call) @reference.call 52 | 53 | ; implementations 54 | 55 | (impl_item 56 | trait: (type_identifier) @name.reference.implementation) @reference.implementation 57 | 58 | (impl_item 59 | type: (type_identifier) @name.reference.implementation 60 | !trait) @reference.implementation 61 | -------------------------------------------------------------------------------- /motleycoder/queries/tree-sitter-typescript-tags.scm: -------------------------------------------------------------------------------- 1 | (function_signature 2 | name: (identifier) @name.definition.function) @definition.function 3 | 4 | (method_signature 5 | name: (property_identifier) @name.definition.method) @definition.method 6 | 7 | (abstract_method_signature 8 | name: (property_identifier) @name.definition.method) @definition.method 9 | 10 | (abstract_class_declaration 11 | name: (type_identifier) @name.definition.class) @definition.class 12 | 13 | (module 14 | name: (identifier) @name.definition.module) @definition.module 15 | 16 | (interface_declaration 17 | name: (type_identifier) @name.definition.interface) @definition.interface 18 | 19 | (type_annotation 20 | (type_identifier) @name.reference.type) @reference.type 21 | 22 | (new_expression 23 | constructor: (identifier) @name.reference.class) @reference.class 24 | 25 | (function_declaration 26 | name: (identifier) @name.definition.function) @definition.function 27 | 28 | (method_definition 29 | name: (property_identifier) @name.definition.method) @definition.method 30 | 31 | (class_declaration 32 | name: (type_identifier) @name.definition.class) @definition.class 33 | 34 | (interface_declaration 35 | name: (type_identifier) @name.definition.class) @definition.class 36 | 37 | (type_alias_declaration 38 | name: (type_identifier) @name.definition.type) @definition.type 39 | 40 | (enum_declaration 41 | name: (identifier) @name.definition.enum) @definition.enum 42 | -------------------------------------------------------------------------------- /motleycoder/repo.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path, PurePosixPath 2 | 3 | import git 4 | 5 | 6 | class GitRepo: 7 | def __init__(self, repo_path): 8 | self.repo = git.Repo(repo_path, search_parent_directories=True, odbt=git.GitDB) 9 | self.root = Path(self.repo.working_dir).resolve() 10 | 11 | def diff_commits(self, pretty, from_commit, to_commit): 12 | args = [] 13 | if pretty: 14 | args += ["--color"] 15 | 16 | args += [from_commit, to_commit] 17 | diffs = self.repo.git.diff(*args) 18 | 19 | return diffs 20 | 21 | def get_tracked_files(self): 22 | if not self.repo: 23 | return [] 24 | 25 | try: 26 | commit = self.repo.head.commit 27 | except ValueError: 28 | commit = None 29 | 30 | files = [] 31 | if commit: 32 | for blob in commit.tree.traverse(): 33 | if blob.type == "blob": # blob is a file 34 | files.append(blob.path) 35 | 36 | # Add staged files 37 | index = self.repo.index 38 | staged_files = [path for path, _ in index.entries.keys()] 39 | 40 | files.extend(staged_files) 41 | 42 | # convert to appropriate os.sep, since git always normalizes to / 43 | res = set(self.normalize_path(path) for path in files) 44 | 45 | return res 46 | 47 | def normalize_path(self, path): 48 | return str(Path(PurePosixPath((Path(self.root) / path).relative_to(self.root)))) 49 | 50 | def path_in_repo(self, path): 51 | if not self.repo: 52 | return 53 | 54 | tracked_files = set(self.get_tracked_files()) 55 | return self.normalize_path(path) in tracked_files 56 | 57 | def abs_root_path(self, path): 58 | res = Path(self.root) / path 59 | return str(res.resolve()) 60 | 61 | def get_dirty_files(self): 62 | """ 63 | Returns a list of all files which are dirty (not committed), either staged or in the working 64 | directory. 65 | """ 66 | dirty_files = set() 67 | 68 | # Get staged files 69 | staged_files = self.repo.git.diff("--name-only", "--cached").splitlines() 70 | dirty_files.update(staged_files) 71 | 72 | # Get unstaged files 73 | unstaged_files = self.repo.git.diff("--name-only").splitlines() 74 | dirty_files.update(unstaged_files) 75 | 76 | return list(dirty_files) 77 | 78 | def is_dirty(self, path=None): 79 | if path and not self.path_in_repo(path): 80 | return True 81 | 82 | return self.repo.is_dirty(path=path) 83 | -------------------------------------------------------------------------------- /motleycoder/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .add_files_tool import AddFilesTool 2 | from .file_edit_tool import FileEditTool 3 | from .inspect_entity_tool import InspectEntityTool 4 | from .return_to_user_tool import ReturnToUserTool 5 | -------------------------------------------------------------------------------- /motleycoder/tools/add_files_tool.py: -------------------------------------------------------------------------------- 1 | from langchain_core.pydantic_v1 import BaseModel, Field 2 | from langchain_core.tools import StructuredTool 3 | from motleycrew.common import logger 4 | from motleycrew.tools import MotleyTool 5 | 6 | from motleycoder.codemap.file_group import FileGroup 7 | from motleycoder.user_interface import UserInterface 8 | 9 | 10 | class AddFilesToolInput(BaseModel): 11 | files: list[str] = Field(description="List of file paths to add to the chat.") 12 | 13 | 14 | class AddFilesTool(MotleyTool): 15 | def __init__( 16 | self, 17 | file_group: FileGroup, 18 | user_interface: UserInterface, 19 | name: str = "add_files", 20 | ): 21 | self.file_group = file_group 22 | self.user_interface = user_interface 23 | 24 | langchain_tool = StructuredTool.from_function( 25 | func=self.add_files, 26 | name=name, 27 | description="""Add files to the list of files available for modification. 28 | Only files that are already in the list of files available for modification can be modified.""", 29 | args_schema=AddFilesToolInput, 30 | ) 31 | super().__init__(langchain_tool) 32 | 33 | def add_files(self, files: list[str]): 34 | added_files = [] 35 | for file in files: 36 | if not self.user_interface.confirm(f"Add {file} to the list of modifiable files?"): 37 | continue 38 | 39 | abs_filename = self.file_group.abs_root_path(file) 40 | logger.info(f"Trying to add to the list of modifiable files: {abs_filename}") 41 | 42 | content = self.read_text_file(abs_filename) 43 | if content is None: 44 | logger.error(f"Error reading {abs_filename}, skipping it.") 45 | continue 46 | self.file_group.add_for_modification(file) 47 | added_files.append(file) 48 | logger.info(f"Added {abs_filename} to the list of modifiable files.") 49 | 50 | if not added_files: 51 | return "No files were added to the list of modifiable files." 52 | else: 53 | return ( 54 | f"Added the following files to the list of modifiable files: {', '.join(added_files)}, " 55 | f"please use the `inspect_entity` tool to inspect them." 56 | ) 57 | 58 | # Should be using the inspect_object_tool instead 59 | # def make_files_content_prompt(self, files): 60 | # prompt = self.coder.gpt_prompts.files_content_prefix 61 | # for filename, content in self.get_files_content(files): 62 | # if not is_image_file(filename): 63 | # prompt += "\n" 64 | # prompt += filename 65 | # 66 | # prompt += f"\n```\n" 67 | # prompt += content 68 | # prompt += f"```\n" 69 | # 70 | # return prompt 71 | 72 | # def get_file_content(self, files: list[str]): 73 | # for filename in files: 74 | # abs_filename = self.coder.abs_root_path(filename) 75 | # content = self.read_text_file(abs_filename) 76 | # 77 | # if content is None: 78 | # logger.warning(f"Error reading {filename}, dropping it from the chat.") 79 | # self.coder.abs_fnames.remove(abs_filename) 80 | # else: 81 | # yield filename, content 82 | 83 | # TODO: move this to the FileGroup class 84 | def read_text_file(self, filename: str): 85 | try: 86 | with open(str(filename), "r", encoding="utf-8") as f: 87 | return f.read() 88 | except FileNotFoundError: 89 | logger.error(f"{filename}: file not found error") 90 | except IsADirectoryError: 91 | logger.error(f"{filename}: is a directory") 92 | return 93 | except UnicodeError as e: 94 | logger.error(f"{filename}: {e}") 95 | return 96 | -------------------------------------------------------------------------------- /motleycoder/tools/file_edit_tool.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from pathlib import Path 3 | from typing import TYPE_CHECKING, Optional 4 | 5 | from langchain_core.pydantic_v1 import BaseModel, Field 6 | from langchain_core.tools import StructuredTool 7 | 8 | from motleycoder.codemap.file_group import FileGroup 9 | from motleycoder.codemap.repomap import RepoMap 10 | from motleycoder.linter import Linter 11 | from motleycoder.prompts import MotleyCoderPrompts 12 | from motleycoder.user_interface import UserInterface 13 | from motleycrew.common import logger 14 | from motleycrew.tools import MotleyTool 15 | 16 | if TYPE_CHECKING: 17 | pass 18 | 19 | 20 | class FileEditToolInput(BaseModel): 21 | file_path: str = Field(description="The file path to edit.") 22 | language: str = Field(description="The programming language of the file.") 23 | search: str = Field(description="The SEARCH block.") 24 | replace: str = Field(description="The REPLACE block.") 25 | 26 | 27 | class FileEditTool(MotleyTool): 28 | def __init__( 29 | self, 30 | file_group: FileGroup, 31 | user_interface: UserInterface, 32 | repo_map: RepoMap, 33 | prompts: Optional[MotleyCoderPrompts] = None, 34 | linter: Optional[Linter] = None, 35 | name: str = "edit_file", 36 | ): 37 | # TODO: replace coder with specific components 38 | self.file_group = file_group 39 | self.user_interface = user_interface 40 | self.repo_map = repo_map 41 | 42 | self.prompts = prompts 43 | self.linter = linter 44 | 45 | langchain_tool = StructuredTool.from_function( 46 | func=self.edit_file, 47 | name=name, 48 | description="Make changes to a file using a *SEARCH/REPLACE* block.", 49 | args_schema=FileEditToolInput, 50 | ) 51 | super().__init__(langchain_tool) 52 | 53 | def edit_file(self, file_path: str, language: str, search: str, replace: str) -> str: 54 | error_message = self.edit_file_inner(file_path, search, replace) 55 | if error_message: # TODO: max_reflections 56 | return error_message 57 | 58 | if self.prompts: 59 | return self.prompts.file_edit_success.format(file_path=file_path) 60 | 61 | return f"Successfully edited file {file_path}." 62 | 63 | def prepare_file_for_edit(self, file_path: str): 64 | abs_path = self.file_group.abs_root_path(file_path) 65 | if abs_path not in self.file_group.files_for_modification: 66 | if not self.user_interface.confirm(f"Add {file_path} to the list of modifiable files?"): 67 | raise Exception( 68 | f"The user rejected adding {file_path} to the list of modifiable files." 69 | ) 70 | 71 | self.file_group.add_for_modification(file_path) 72 | 73 | if not Path(abs_path).exists(): 74 | if not self.user_interface.confirm(f"Allow creation of new file {file_path}?"): 75 | raise Exception(f"User rejected creation of new file {file_path}.") 76 | 77 | Path(abs_path).parent.mkdir(parents=True, exist_ok=True) 78 | Path(abs_path).touch() 79 | 80 | def invalidate_tag_graphs(self, file_path: str): 81 | if not self.repo_map.tag_graphs: 82 | return 83 | 84 | for files, graph in self.repo_map.tag_graphs.copy().items(): 85 | if file_path in files: 86 | self.repo_map.tag_graphs.pop(files) 87 | 88 | def edit_file_inner(self, file_path: str, search: str, replace: str) -> str: 89 | if not search or search[-1] != "\n": 90 | search += "\n" 91 | if not replace or replace[-1] != "\n": 92 | replace += "\n" 93 | 94 | logger.info( 95 | f"""Trying to edit file {file_path} 96 | <<<<<<< SEARCH 97 | {search}======= 98 | {replace}>>>>>>> REPLACE 99 | """ 100 | ) 101 | 102 | try: 103 | self.prepare_file_for_edit(file_path) 104 | except Exception as err: 105 | logger.error(f"Error preparing file for edit: {err}") 106 | return "Cannot edit file: " + str(err) 107 | 108 | try: 109 | # self.coder.dirty_commit() # Add the file to the repo if it's not already there 110 | result, close_match = self.file_group.edit_file(file_path, search, replace) 111 | # self.git_repo.commit_changes(f"Edit file {file_path}") 112 | except Exception as err: 113 | logger.warning("Exception while updating file:") 114 | logger.warning(str(err)) 115 | 116 | traceback.print_exc() 117 | return str(err) 118 | 119 | if not result: 120 | res = ( 121 | f"## SearchReplaceNoExactMatch: This SEARCH argument failed to exactly match " 122 | f"lines in {file_path}" 123 | ) 124 | if close_match: 125 | res += ( 126 | f"\nDid you mean to match some of these actual lines from {file_path}?\n" 127 | f"```\n{close_match}\n```" 128 | ) 129 | return res 130 | 131 | self.file_group.edited_files.add(file_path) 132 | self.invalidate_tag_graphs(file_path) 133 | 134 | if self.linter: 135 | errors = self.linter.lint(self.file_group.abs_root_path(file_path)) 136 | if errors: 137 | logger.error(f"Lint errors in {file_path}: {errors}") 138 | if self.user_interface.confirm("Attempt to fix lint errors?"): 139 | return errors 140 | 141 | 142 | if __name__ == "__main__": 143 | from motleycoder.codemap.repomap import RepoMap 144 | from motleycoder.repo import GitRepo 145 | 146 | repo_path = "/Users/whimo/codegen/motleycrew" 147 | 148 | repo = GitRepo(repo_path) 149 | file_group = FileGroup(repo) 150 | 151 | repo_map = RepoMap( 152 | root=repo.root, 153 | llm_name="gpt-4o", 154 | repo_content_prefix=None, 155 | file_group=file_group, 156 | cache_graphs=True, 157 | ) 158 | 159 | tool = FileEditTool( 160 | file_group=file_group, 161 | user_interface=UserInterface(yes=True), 162 | linter=Linter(), 163 | repo_map=repo_map, 164 | prompts=MotleyCoderPrompts(), 165 | ) 166 | print( 167 | tool.edit_file( 168 | file_path="motleycrew/agents/parent.py", 169 | language="python", 170 | search="""from motleycrew import MotleyCrew""", 171 | replace="""from motleycrew import MotleyCrew 172 | from a import b 173 | """, 174 | ) 175 | ) 176 | -------------------------------------------------------------------------------- /motleycoder/tools/get_full_text_tool.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from langchain_core.pydantic_v1 import BaseModel, Field 4 | from langchain_core.tools import StructuredTool 5 | from typing import List, Optional 6 | 7 | from motleycrew.common import logger 8 | from motleycrew.tools import MotleyTool 9 | 10 | from aider.codemap.repomap import RepoMap 11 | from aider.codemap.render import RenderCode 12 | from aider.codemap.tag import Tag 13 | 14 | if TYPE_CHECKING: 15 | from .motleycrew_coder import MotleyCrewCoder 16 | 17 | 18 | class GetFullTextToolInput(BaseModel): 19 | entity_name: Optional[str] = Field(description="Name of the entity to inspect.", default=None) 20 | file_name: Optional[str] = Field( 21 | description="Full or partial name of the file(s) to inspect", default=None 22 | ) 23 | first_line: Optional[int] = Field( 24 | description="First line of the code snippet to return", default=None 25 | ) 26 | 27 | 28 | class GetFullTextTool(MotleyTool): 29 | def __init__(self, repo_map: RepoMap): 30 | self.repo_map = repo_map 31 | self.requested_tags = set() 32 | 33 | langchain_tool = StructuredTool.from_function( 34 | func=self.get_full_text, 35 | name="Get_Full_Text", 36 | description=""""Get the full code text of the entity with a given name. 37 | Valid entities are function names, class names, 38 | method names prefixed with class, like `Foo.bar`. 39 | You can restrict your search to specific files by supplying the optional file_name argument. 40 | Do NOT use this tool to inspect whole files - use the inspect_entity tool for that. 41 | You MUST supply at least the entity_name. 42 | """, 43 | args_schema=GetFullTextToolInput, 44 | ) 45 | super().__init__(langchain_tool) 46 | 47 | def get_full_text( 48 | self, 49 | entity_name: Optional[str] = None, 50 | file_name: Optional[str] = None, 51 | first_line: Optional[int] = None, 52 | ) -> str: 53 | if entity_name is None: 54 | return "Please make sure to supply an entity name as an input to this tool" 55 | 56 | entity_name = entity_name.replace("()", "") 57 | 58 | if (entity_name, file_name) in self.requested_tags: 59 | return "You've already requested that one!" 60 | else: 61 | self.requested_tags.add((entity_name, file_name)) 62 | 63 | tag_graph = self.repo_map.get_tag_graph() 64 | 65 | re_tags = tag_graph.get_tags_from_entity_name(entity_name, file_name) 66 | 67 | if not re_tags: # maybe it was an explicit import? 68 | return f"Definition of entity {entity_name} not found in the repo" 69 | elif len(re_tags) == 1: 70 | return RenderCode.text_with_line_numbers(re_tags[0]) 71 | else: 72 | # Can get multiple tags eg when requesting a whole file 73 | if isinstance(first_line, int): 74 | sort = sorted(re_tags, key=lambda x: abs(x.line - first_line)) 75 | return RenderCode.text_with_line_numbers(sort[0]) 76 | 77 | return """Your query matches more than one entity, see the summary of the matches below. 78 | Please refine your query to match only one entity. 79 | """ + tag_graph.code_renderer.to_tree( 80 | re_tags 81 | ) 82 | -------------------------------------------------------------------------------- /motleycoder/tools/get_modifiable_files_tool.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from langchain_core.pydantic_v1 import BaseModel 4 | from langchain_core.tools import StructuredTool 5 | 6 | from motleycoder.codemap.file_group import FileGroup 7 | from motleycrew.tools import MotleyTool 8 | 9 | 10 | class GetModifiableFilesToolInput(BaseModel): 11 | pass 12 | 13 | 14 | class GetModifiableFilesTool(MotleyTool): 15 | def __init__(self, file_group: FileGroup, name: str = "get_modifiable_files"): 16 | 17 | langchain_tool = StructuredTool.from_function( 18 | func=self.get_modifiable_files, 19 | name=name, 20 | description="Get the relative paths files that can be modified.", 21 | args_schema=GetModifiableFilesToolInput, 22 | ) 23 | super().__init__(langchain_tool) 24 | 25 | self.file_group = file_group 26 | 27 | def get_modifiable_files(self) -> List[str]: 28 | files = self.file_group.files_for_modification 29 | return [self.file_group.get_rel_fname(file) for file in files] 30 | -------------------------------------------------------------------------------- /motleycoder/tools/inspect_entity_tool.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from collections import deque 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | from langchain_core.pydantic_v1 import BaseModel, Field 7 | from langchain_core.tools import StructuredTool 8 | 9 | from motleycoder.codemap.file_group import FileGroup 10 | from motleycoder.codemap.repomap import RepoMap 11 | from motleycoder.repo import GitRepo 12 | from motleycrew.tools import MotleyTool 13 | 14 | 15 | class InspectObjectToolInput(BaseModel): 16 | entity_name: Optional[str] = Field(description="Name of the entity to inspect.", default=None) 17 | file_name: Optional[str] = Field( 18 | description="Full or partial name of the file(s) to inspect", default=None 19 | ) 20 | 21 | 22 | class InspectEntityTool(MotleyTool): 23 | def __init__( 24 | self, 25 | repo_map: RepoMap, 26 | show_other_files: bool = False, 27 | max_lines_long=400, 28 | max_lines_short=50, 29 | block_identical_calls=2, 30 | ): 31 | self.repo_map = repo_map 32 | self.show_other_files = show_other_files 33 | self.max_lines_long = max_lines_long 34 | self.max_lines_short = max_lines_short 35 | 36 | self.requested_tags = deque(maxlen=block_identical_calls) 37 | 38 | langchain_tool = StructuredTool.from_function( 39 | func=self.get_object_summary, 40 | name="inspect_entity", 41 | description=""""Get the code of the entity with a given name, 42 | including summary of the entities it references. Valid entities 43 | are function names, class names, method names (prefix them by method name to disambiguate, like "Foo.bar") 44 | 45 | ONLY supply the file name/relative path if you need it to disambiguate the entity name, 46 | or if you want to inspect a whole file; in all other cases, just supply the entity name. 47 | You can also supply a partial file or directory name to get all files whose relative paths 48 | contain the partial name you supply. 49 | You can also request a whole file by name by omitting the entity name. 50 | """, 51 | args_schema=InspectObjectToolInput, 52 | ) 53 | super().__init__(langchain_tool) 54 | 55 | def get_object_summary( 56 | self, entity_name: Optional[str] = None, file_name: Optional[str] = None 57 | ) -> str: 58 | entity_name = entity_name or None 59 | file_name = file_name or None 60 | 61 | if entity_name is None and file_name is None: 62 | return "Please supply either the file name or the entity name" 63 | 64 | if entity_name is not None: 65 | entity_name = entity_name.strip().replace("()", "") 66 | 67 | if (entity_name, file_name) in self.requested_tags: 68 | return ( 69 | "You've already requested this entity recently, its contents are visible above. " 70 | "Please use existing information or request a different entity." 71 | ) 72 | else: 73 | self.requested_tags.append((entity_name, file_name)) 74 | 75 | tag_graph = self.repo_map.get_tag_graph(with_tests=True) 76 | 77 | if not entity_name: 78 | abs_file_path = self.repo_map.file_group.abs_root_path(file_name) 79 | try: 80 | file_content = Path(abs_file_path).read_text() 81 | except FileNotFoundError: 82 | return f"File {file_name} not found in the repo" 83 | except IsADirectoryError: 84 | files = self.repo_map.file_group.get_rel_fnames_in_directory( 85 | abs_file_path, level=None, with_tests=True 86 | ) 87 | return f"{file_name} is a directory. Files in it:\n{"\n".join(sorted(files))}" 88 | 89 | if not file_content: 90 | return f"File {file_name} is empty" 91 | 92 | return tag_graph.get_file_representation( 93 | file_name=abs_file_path, 94 | file_content=file_content, 95 | ) 96 | 97 | out = "" 98 | 99 | # TODO: if file_name is a directory, just list the files in it? 100 | re_tags = tag_graph.get_tags_from_entity_name(entity_name, file_name) 101 | if not len(re_tags) and entity_name is not None and "." in entity_name: 102 | entity_name_short = entity_name.split(".")[-1] 103 | out += f"Entity {entity_name} not found, searching for {entity_name_short}...\n" 104 | re_tags = tag_graph.get_tags_from_entity_name(entity_name_short, file_name) 105 | 106 | if not re_tags: # maybe it was an explicit import? 107 | if entity_name is not None: 108 | out += ( 109 | f"Definition of entity {entity_name} not found in the repo. " 110 | f"You can specify the entity name more broadly or omit it " 111 | f"for reading the whole file." 112 | ) 113 | if file_name is None or self.to_dir(file_name) is None: 114 | return out # Absolutely no directories to work with 115 | else: 116 | candidate_dirs = [self.to_dir(file_name)] 117 | else: 118 | return f"File {file_name} not found in the repo" 119 | 120 | elif all( 121 | tag.fname == re_tags[0].fname and tag.full_name == re_tags[0].full_name 122 | for tag in re_tags 123 | ): 124 | repr_parts = [ 125 | tag_graph.get_tag_representation( 126 | t, 127 | parent_details=True, 128 | max_lines=self.max_lines_long, 129 | force_include_full_text=True, 130 | ) 131 | for t in re_tags 132 | ] 133 | out += "\n".join(repr_parts) 134 | candidate_dirs = [self.to_dir(t.fname) for t in re_tags] 135 | else: # Can get multiple tags eg when requesting a whole file 136 | # TODO: this could be neater 137 | repr_parts = [ 138 | tag_graph.get_tag_representation( 139 | t, parent_details=False, max_lines=self.max_lines_short 140 | ) 141 | for t in re_tags 142 | ] 143 | repr = "\n".join(repr_parts) 144 | 145 | if len(repr.split("\n")) < self.max_lines_long: 146 | out += repr 147 | else: 148 | repr = tag_graph.code_renderer.to_tree(re_tags) 149 | if len(repr.split("\n")) < self.max_lines_long: 150 | out += repr 151 | else: 152 | fnames = sorted(list(set(t.rel_fname for t in re_tags))) 153 | 154 | out += ( 155 | "There are too many matches for the given query in the repo. " 156 | "You can narrow down the search by specifying the file and/or entity name " 157 | "more precisely. Here are the files that match the query:\n" 158 | ) 159 | out += "\n".join(fnames) 160 | 161 | candidate_dirs = list(set([self.to_dir(t.fname) for t in re_tags])) 162 | 163 | files = set( 164 | sum( 165 | [self.repo_map.file_group.get_rel_fnames_in_directory(d, with_tests=True) for d in candidate_dirs], 166 | [], 167 | ) 168 | ) 169 | 170 | mentioned_fnames = set([t.fname for t in re_tags]) 171 | other_fnames = files - mentioned_fnames 172 | if other_fnames and self.show_other_files: 173 | out += "\nOther files in same directory(s):\n" + "\n".join(sorted(list(other_fnames))) 174 | return out 175 | 176 | def to_dir(self, rel_fname: str) -> str: 177 | abs_dir = self.repo_map.file_group.abs_root_path(rel_fname) 178 | if os.path.isfile(abs_dir): 179 | abs_dir = os.path.dirname(abs_dir) 180 | if os.path.isdir(abs_dir): 181 | return abs_dir 182 | else: 183 | return None 184 | 185 | 186 | if __name__ == "__main__": 187 | from motleycoder.codemap.repomap import RepoMap 188 | 189 | repo_path = "/Users/whimo/codegen/motleycrew" 190 | 191 | repo = GitRepo(repo_path) 192 | file_group = FileGroup(repo) 193 | 194 | repo_map = RepoMap( 195 | root=repo.root, 196 | llm_name="gpt-4o", 197 | repo_content_prefix=None, 198 | file_group=file_group, 199 | cache_graphs=True, 200 | ) 201 | 202 | tool = InspectEntityTool(repo_map) 203 | print( 204 | tool.get_object_summary( 205 | file_name="tests" 206 | ) 207 | ) 208 | -------------------------------------------------------------------------------- /motleycoder/tools/return_to_user_tool.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional 2 | 3 | from motleycrew.agents import MotleyOutputHandler 4 | from motleycrew.common import Defaults 5 | from motleycrew.common.exceptions import InvalidOutput 6 | 7 | from motleycoder.user_interface import UserInterface 8 | 9 | 10 | class ReturnToUserTool(MotleyOutputHandler): 11 | _name = "return_to_user" 12 | 13 | def __init__( 14 | self, 15 | user_interface: UserInterface, 16 | tests_runner: Optional[Callable] = None, 17 | max_iterations: int = Defaults.DEFAULT_OUTPUT_HANDLER_MAX_ITERATIONS, 18 | ): 19 | self.user_interface = user_interface 20 | self.tests_runner = tests_runner 21 | super().__init__(max_iterations=max_iterations) 22 | 23 | self._iteration = 0 24 | 25 | def handle_output(self): 26 | self._iteration += 1 27 | 28 | out = self.tests_runner() 29 | if out is None: 30 | self._iteration = 0 31 | return "Tests passed!" 32 | elif self._iteration >= self.max_iterations: 33 | self._iteration = 0 34 | return "Maximum output handler iterations exceeded. Last test attempt failed:\n" + out 35 | else: 36 | if self.user_interface.confirm("Attempt to fix test errors?"): 37 | raise InvalidOutput("Existing tests failed:\n" + out) 38 | return "Last test attempt failed:\n" + out 39 | -------------------------------------------------------------------------------- /motleycoder/user_interface.py: -------------------------------------------------------------------------------- 1 | from motleycrew.common import logger 2 | 3 | class UserInterface: 4 | def __init__(self, yes: bool = False): 5 | self.yes = yes 6 | 7 | def confirm(self, message: str) -> bool: 8 | if self.yes: 9 | approved = True 10 | else: 11 | approved = input(f"{message} [y/n] ").lower().startswith("y") 12 | 13 | logger.info(f"{message} {"approved" if approved else "rejected"}") 14 | return approved 15 | -------------------------------------------------------------------------------- /motleycoder_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "da587728", 6 | "metadata": {}, 7 | "source": [ 8 | "# MotleyCoder: a set of tools and utilties for coding agents\n", 9 | "\n", 10 | "Getting an agent to write sensible code is a challenging task. Besides the concerns of correctness and efficiency, the agent must be able to interact with the code base in first place. This is where MotleyCoder comes in. MotleyCoder is a set of tools and utilities that help agents write better code.\n", 11 | "\n", 12 | "This notebook is a brief demo of MotleyCoder in action. You can regard the agent we use towards the end of this notebook as a ready-to-use AI coder. Feel free to customize it or create your own coding agent: MotleyCoder is built just for that.\n", 13 | "\n", 14 | "MotleyCoder consists of the following main elements:\n", 15 | "- `RepoMap`: provides the agent with an initial overview of the parts of the codebase relevant to the current task, so that the agent at least knows where to start.\n", 16 | "- `InspectEntityTool`: a tool given to the agent so it can inspect and navigate the codebase, read the code of specific entities or files, and list directories.\n", 17 | "- `FileEditTool`: a tool that allows editing code in a way an LLM can comprehend.\n", 18 | "\n", 19 | "Plain RAG does not work well with code, because the important connections between entities in it are simply ignored.\n", 20 | "MotleyCoder uses a combination of static code analysis and retrieval techniques to build a map of the codebase, with an emphasis on the parts relevant to the task. The map is then rendered into a view suitable for feeding into an LLM." 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 3, 26 | "id": "c53923f0", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# We'll try out the code our agents write!\n", 31 | "%load_ext autoreload\n", 32 | "%autoreload 2" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 4, 38 | "id": "initial_id", 39 | "metadata": { 40 | "ExecuteTime": { 41 | "end_time": "2024-08-21T18:05:29.393724Z", 42 | "start_time": "2024-08-21T18:05:29.277195Z" 43 | }, 44 | "collapsed": true 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "from motleycoder.codemap.file_group import FileGroup\n", 49 | "from motleycoder.codemap.repomap import RepoMap\n", 50 | "from motleycoder.repo import GitRepo\n", 51 | "from motleycoder.tools import FileEditTool, InspectEntityTool\n", 52 | "from motleycoder.user_interface import UserInterface\n", 53 | "\n", 54 | "from motleycrew.agents.langchain import ReActToolCallingMotleyAgent\n", 55 | "from motleycrew.common import LLMFramework, LLMFamily\n", 56 | "from motleycrew.common.llms import init_llm\n", 57 | "from motleycrew.common import configure_logging\n", 58 | "\n", 59 | "configure_logging(verbose=True)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "a332f25c", 65 | "metadata": {}, 66 | "source": [ 67 | "Let's construct the essential parts of MotleyCoder." 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 5, 73 | "id": "f8d565d6f6c821d0", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "llm_name = \"gpt-4o\"\n", 78 | "\n", 79 | "repo = GitRepo(\"../motleycrew\") # The object responsible for interacting with the git repository\n", 80 | "\n", 81 | "file_group = FileGroup(repo) # Represents a group of files on the local disk we're dealing with\n", 82 | "repo_map = RepoMap( # Will provide the agent with a view of the repository\n", 83 | " root=repo.root,\n", 84 | " file_group=file_group,\n", 85 | " llm_name=llm_name,\n", 86 | ")" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 6, 92 | "id": "cbaa6d76fcb0b248", 93 | "metadata": { 94 | "ExecuteTime": { 95 | "end_time": "2024-08-21T18:03:43.790730Z", 96 | "start_time": "2024-08-21T18:03:43.790442Z" 97 | } 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "ui = UserInterface() # Interface for interacting with the user (in this case, via command line)\n", 102 | "\n", 103 | "inspect_entity_tool = InspectEntityTool( # Tool for inspecting entities in the code\n", 104 | " repo_map=repo_map\n", 105 | ")\n", 106 | "file_edit_tool = FileEditTool( # Tool for editing files\n", 107 | " file_group=file_group,\n", 108 | " user_interface=ui,\n", 109 | " repo_map=repo_map,\n", 110 | ")" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "id": "fafdbda0", 116 | "metadata": {}, 117 | "source": [ 118 | "`InspectEntityTool` and `FileEditTool` are tools that can be given to just about any agent that works with function calling models. We suggest using them with motleycrew's `ReActToolCallingAgent`.\n", 119 | "\n", 120 | "First, we'll build a trivial agent in that fashion. It will solve a simple task involving adding additional logging in one of motleycrew's classes." 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 7, 126 | "id": "38a8a2b135e24b5e", 127 | "metadata": { 128 | "ExecuteTime": { 129 | "end_time": "2024-08-21T18:03:43.792086Z", 130 | "start_time": "2024-08-21T18:03:43.792020Z" 131 | } 132 | }, 133 | "outputs": [], 134 | "source": [ 135 | "llm = init_llm(LLMFramework.LANGCHAIN, LLMFamily.OPENAI, llm_name=llm_name)\n", 136 | "agent = ReActToolCallingMotleyAgent(\n", 137 | " name=\"coder\",\n", 138 | " tools=[inspect_entity_tool, file_edit_tool],\n", 139 | " chat_history=False,\n", 140 | " llm=llm,\n", 141 | ")" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 8, 147 | "id": "43c9b8c6c6b3e550", 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "name": "stderr", 152 | "output_type": "stream", 153 | "text": [ 154 | "/Users/whimo/codegen/motleycoder/.venv/lib/python3.12/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", 155 | " return _methods._mean(a, axis=axis, dtype=dtype,\n", 156 | "/Users/whimo/codegen/motleycoder/.venv/lib/python3.12/site-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", 157 | " ret = ret.dtype.type(ret / rcount)\n" 158 | ] 159 | } 160 | ], 161 | "source": [ 162 | "message = \"In the MotleyCrew class, add logging when a task unit is added to the graph.\"\n", 163 | "\n", 164 | "repo_map_str = repo_map.repo_map_from_message(message, llm=llm)\n", 165 | "\n", 166 | "prompt = f\"\"\"Here is an overview of the repo:\n", 167 | "{repo_map_str}\n", 168 | "\n", 169 | "{message}\"\"\"" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 9, 175 | "id": "3975f95521674d20", 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "name": "stderr", 180 | "output_type": "stream", 181 | "text": [ 182 | "2024-08-26 16:12:28,075 - motleycrew - INFO - Agent factory accepts output handler, passing it\n", 183 | "2024-08-26 16:12:32,879 - motleycrew - INFO - Trying to edit file motleycrew/crew/crew.py\n", 184 | "<<<<<<< SEARCH\n", 185 | "def add_task_unit_to_graph(self, task: Task, unit: TaskUnitType):\n", 186 | "=======\n", 187 | "def add_task_unit_to_graph(self, task: Task, unit: TaskUnitType):\n", 188 | " logger.info('Adding task unit %s to the graph for task %s', unit, task)\n", 189 | ">>>>>>> REPLACE\n", 190 | "\n", 191 | "2024-08-26 16:12:50,404 - motleycrew - INFO - Add motleycrew/crew/crew.py to the list of modifiable files? approved\n" 192 | ] 193 | }, 194 | { 195 | "data": { 196 | "text/plain": [ 197 | "'Final Answer: Logging has been added to the MotleyCrew class to log when a task unit is added to the graph. The `add_task_unit_to_graph` method now includes a logging statement that logs the task unit and the task to which it belongs.'" 198 | ] 199 | }, 200 | "execution_count": 9, 201 | "metadata": {}, 202 | "output_type": "execute_result" 203 | } 204 | ], 205 | "source": [ 206 | "agent.invoke({\"prompt\": prompt})" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "id": "0c5b6044", 212 | "metadata": {}, 213 | "source": [ 214 | "The above example is trivial and involved no prompt engineering. We are yet to show you an important part of MotleyCoder: a refined set of prompts that make the agent much more robust and reliable.\n", 215 | "\n", 216 | "Also, using a linter is crucial for eliminating bad edits. MotleyCoder's built-in `Linter` class provides basic linting by parsing code using tree-sitter, and also advanced linting for Python using flake8. Adding custom linters for other languages is also easy.\n", 217 | "\n", 218 | "Another way to boost your agent's reliability right away is running tests after it's done with the edits. This is a natural usage pattern for motleycrew's output handler: the agent calls a special tool to inform that it has finished editing, and the tests are run inside that tool. If the tests fail, their output is fed back into the agent so it can fix them." 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "id": "8f91a318", 224 | "metadata": {}, 225 | "source": [ 226 | "## A reliable coding agent setup\n", 227 | "\n", 228 | "In this example, we'll showcase a MotleyCoder-based developer agent that can solve more complex tasks." 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 10, 234 | "id": "2a3b1e2a", 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "from motleycoder.prompts import MotleyCoderPrompts\n", 239 | "from motleycoder.linter import Linter\n", 240 | "from motleycoder.tools import ReturnToUserTool\n", 241 | "\n", 242 | "from motleycrew.tasks import SimpleTask\n", 243 | "from motleycrew import MotleyCrew\n", 244 | "from langchain_core.tools import render_text_description" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 11, 250 | "id": "1ef48f1b", 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "prompts = MotleyCoderPrompts()\n", 255 | "linter = Linter()" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 12, 261 | "id": "2ed65e0e", 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "file_edit_tool = FileEditTool(\n", 266 | " file_group=file_group,\n", 267 | " user_interface=ui,\n", 268 | " repo_map=repo_map,\n", 269 | " linter=linter,\n", 270 | " prompts=prompts,\n", 271 | ")\n", 272 | "\n", 273 | "inspect_entity_tool = InspectEntityTool( # Tool for inspecting entities in the code\n", 274 | " repo_map=repo_map\n", 275 | ")" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 13, 281 | "id": "15581863", 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "import subprocess\n", 286 | "\n", 287 | "\n", 288 | "def run_tests():\n", 289 | " \"\"\"Run tests in the repository and return the output.\"\"\"\n", 290 | " work_dir = repo.root\n", 291 | " python_path = f\"{work_dir}/.venv/bin/python\"\n", 292 | " try:\n", 293 | " subprocess.run(\n", 294 | " [python_path, \"-m\", \"pytest\"],\n", 295 | " shell=False,\n", 296 | " check=True,\n", 297 | " cwd=work_dir,\n", 298 | " stdout=subprocess.PIPE,\n", 299 | " stderr=subprocess.PIPE,\n", 300 | " )\n", 301 | " print(\"Tests passed.\")\n", 302 | " return\n", 303 | " except subprocess.CalledProcessError as e:\n", 304 | " stdout = e.stdout.decode(\"utf-8\") if e.stdout else \"\"\n", 305 | " stderr = e.stderr.decode(\"utf-8\") if e.stderr else \"\"\n", 306 | " print(\"Tests failed:\")\n", 307 | " if stdout:\n", 308 | " print(\"STDOUT:\\n\" + stdout)\n", 309 | " if stderr:\n", 310 | " print(\"STDERR:\\n\" + stderr)\n", 311 | " return stdout + stderr" 312 | ] 313 | }, 314 | { 315 | "cell_type": "markdown", 316 | "id": "06afea3d", 317 | "metadata": {}, 318 | "source": [ 319 | "The prompt prefix we give to the agent here provides various instructions for the agent, describes the tools it can use etc." 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 14, 325 | "id": "56418c5d", 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "tools = [inspect_entity_tool, file_edit_tool]\n", 330 | "\n", 331 | "agent = ReActToolCallingMotleyAgent(\n", 332 | " name=\"coder\",\n", 333 | " tools=tools,\n", 334 | " prompt_prefix=prompts.prompt_template.partial(tools=render_text_description(tools)),\n", 335 | " chat_history=False,\n", 336 | " output_handler=ReturnToUserTool(user_interface=ui, tests_runner=run_tests),\n", 337 | " llm=llm,\n", 338 | ")" 339 | ] 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "id": "29e0604f", 344 | "metadata": {}, 345 | "source": [ 346 | "We'll ask the agent to write an entire method in a Python class. This is a more complex task than the previous one, and the agent will need to understand the context of the class and the purpose of the method it's supposed to write." 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 15, 352 | "id": "a8d73225", 353 | "metadata": {}, 354 | "outputs": [ 355 | { 356 | "name": "stderr", 357 | "output_type": "stream", 358 | "text": [ 359 | "/Users/whimo/codegen/motleycoder/.venv/lib/python3.12/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", 360 | " return _methods._mean(a, axis=axis, dtype=dtype,\n", 361 | "/Users/whimo/codegen/motleycoder/.venv/lib/python3.12/site-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", 362 | " ret = ret.dtype.type(ret / rcount)\n" 363 | ] 364 | } 365 | ], 366 | "source": [ 367 | "message = (\n", 368 | " \"In the Task class, add a method 'get_done_upstream_task_units' \"\n", 369 | " \"that lists all upstream task units whose status is 'done'.\"\n", 370 | ")\n", 371 | "\n", 372 | "repo_map_str = repo_map.repo_map_from_message(message, llm=llm)\n", 373 | "\n", 374 | "prompt = f\"\"\"Here is the overview of the repo:\n", 375 | "{repo_map_str}\n", 376 | "\n", 377 | "{message}\"\"\"" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 16, 383 | "id": "013e4ada", 384 | "metadata": {}, 385 | "outputs": [ 386 | { 387 | "name": "stderr", 388 | "output_type": "stream", 389 | "text": [ 390 | "2024-08-26 16:12:53,409 - motleycrew - INFO - No db_path provided, creating temporary directory for database\n", 391 | "2024-08-26 16:12:53,409 - motleycrew - INFO - Using Kuzu graph store with path: /var/folders/fv/tyhll76x0fn6l7j_q2nhvyg00000gn/T/tmpbde3lh7j/kuzu_db\n", 392 | "2024-08-26 16:12:53,422 - motleycrew - INFO - Node table TaskNode does not exist in the database, creating\n", 393 | "2024-08-26 16:12:53,427 - motleycrew - INFO - Property name not present in table for label TaskNode, creating\n", 394 | "2024-08-26 16:12:53,429 - motleycrew - INFO - Property done not present in table for label TaskNode, creating\n", 395 | "2024-08-26 16:12:53,430 - motleycrew - INFO - Node table SimpleTaskUnit does not exist in the database, creating\n", 396 | "2024-08-26 16:12:53,431 - motleycrew - INFO - Property status not present in table for label SimpleTaskUnit, creating\n", 397 | "2024-08-26 16:12:53,433 - motleycrew - INFO - Property output not present in table for label SimpleTaskUnit, creating\n", 398 | "2024-08-26 16:12:53,433 - motleycrew - WARNING - No known Cypher type matching annotation typing.Optional[typing.Any], will use JSON string\n", 399 | "2024-08-26 16:12:53,435 - motleycrew - INFO - Property name not present in table for label SimpleTaskUnit, creating\n", 400 | "2024-08-26 16:12:53,436 - motleycrew - INFO - Property prompt not present in table for label SimpleTaskUnit, creating\n", 401 | "2024-08-26 16:12:53,438 - motleycrew - INFO - Property additional_params not present in table for label SimpleTaskUnit, creating\n", 402 | "2024-08-26 16:12:53,438 - motleycrew - WARNING - No known Cypher type matching annotation typing.Optional[dict[str, typing.Any]], will use JSON string\n", 403 | "2024-08-26 16:12:53,440 - motleycrew - INFO - Relation table SimpleTaskUnit_belongs from SimpleTaskUnit to TaskNode does not exist in the database, creating\n", 404 | "2024-08-26 16:12:53,441 - motleycrew - INFO - Inserting new node with label TaskNode: name='Add method to Task class' done=False\n", 405 | "2024-08-26 16:12:53,445 - motleycrew - INFO - Node created OK\n", 406 | "2024-08-26 16:12:53,445 - motleycrew - INFO - Relation table task_is_upstream from TaskNode to TaskNode does not exist in the database, creating\n", 407 | "2024-08-26 16:12:53,457 - motleycrew - INFO - Available tasks: [SimpleTask(name=Add method to Task class, done=False)]\n", 408 | "2024-08-26 16:12:53,459 - motleycrew - INFO - Available tasks: [SimpleTask(name=Add method to Task class, done=False)]\n", 409 | "2024-08-26 16:12:53,459 - motleycrew - INFO - Processing task: SimpleTask(name=Add method to Task class, done=False)\n", 410 | "2024-08-26 16:12:53,463 - motleycrew - INFO - Got a matching unit for task SimpleTask(name=Add method to Task class, done=False)\n", 411 | "2024-08-26 16:12:53,464 - motleycrew - INFO - Processing unit: TaskUnit(status=pending)\n", 412 | "2024-08-26 16:12:53,464 - motleycrew - INFO - Assigned unit TaskUnit(status=pending) to agent ReActToolCallingMotleyAgent(name=coder), dispatching\n", 413 | "2024-08-26 16:12:53,464 - motleycrew - INFO - Adding task unit TaskUnit(status=running) to the graph for task SimpleTask(name=Add method to Task class, done=False)\n", 414 | "2024-08-26 16:12:53,464 - motleycrew - INFO - Node TaskUnit(status=running) does not exist, creating\n", 415 | "2024-08-26 16:12:53,465 - motleycrew - INFO - Inserting new node with label SimpleTaskUnit: TaskUnit(status=running)\n", 416 | "2024-08-26 16:12:53,465 - motleycrew - WARNING - No known Cypher type matching annotation typing.Optional[typing.Any], will use JSON string\n", 417 | "2024-08-26 16:12:53,465 - motleycrew - WARNING - No known Cypher type matching annotation typing.Optional[dict[str, typing.Any]], will use JSON string\n", 418 | "2024-08-26 16:12:53,467 - motleycrew - INFO - Node created OK\n", 419 | "2024-08-26 16:12:53,472 - motleycrew - INFO - Relation from TaskUnit(status=running) to name='Add method to Task class' done=False does not exist, creating\n", 420 | "2024-08-26 16:12:53,474 - motleycrew - INFO - Creating relation SimpleTaskUnit_belongs from SimpleTaskUnit:0 to TaskNode:0\n", 421 | "2024-08-26 16:12:53,477 - motleycrew - INFO - Relation created OK\n", 422 | "2024-08-26 16:12:53,477 - motleycrew - INFO - Agent factory accepts output handler, passing it\n", 423 | "2024-08-26 16:13:00,931 - motleycrew - INFO - Trying to edit file motleycrew/tasks/task.py\n", 424 | "<<<<<<< SEARCH\n", 425 | " def on_unit_dispatch(self, unit: TaskUnitType) -> None:\n", 426 | " \"\"\"Method that is called by the crew when a unit of the task is dispatched.\n", 427 | "\n", 428 | " Should be implemented by the subclass if needed.\n", 429 | "\n", 430 | " Args:\n", 431 | " unit: Task unit that is dispatched.\n", 432 | " \"\"\"\n", 433 | " pass\n", 434 | "=======\n", 435 | " def on_unit_dispatch(self, unit: TaskUnitType) -> None:\n", 436 | " \"\"\"Method that is called by the crew when a unit of the task is dispatched.\n", 437 | "\n", 438 | " Should be implemented by the subclass if needed.\n", 439 | "\n", 440 | " Args:\n", 441 | " unit: Task unit that is dispatched.\n", 442 | " \"\"\"\n", 443 | " pass\n", 444 | "\n", 445 | " def get_done_upstream_task_units(self) -> List[TaskUnitType]:\n", 446 | " \"\"\"Get all upstream task units whose status is 'done'.\n", 447 | "\n", 448 | " Returns:\n", 449 | " List of done upstream task units.\n", 450 | " \"\"\"\n", 451 | " done_units = []\n", 452 | " upstream_tasks = self.get_upstream_tasks()\n", 453 | " for task in upstream_tasks:\n", 454 | " done_units.extend(task.get_units(status='done'))\n", 455 | " return done_units\n", 456 | "\n", 457 | "\n", 458 | "\n", 459 | ">>>>>>> REPLACE\n", 460 | "\n", 461 | "2024-08-26 16:13:04,140 - motleycrew - INFO - Add motleycrew/tasks/task.py to the list of modifiable files? approved\n", 462 | "2024-08-26 16:13:12,816 - motleycrew - WARNING - No known Cypher type matching annotation typing.Optional[typing.Any], will use JSON string\n", 463 | "2024-08-26 16:13:12,822 - motleycrew - INFO - Task unit TaskUnit(status=running) completed, marking as done\n", 464 | "2024-08-26 16:13:12,828 - motleycrew - INFO - Available tasks: []\n", 465 | "2024-08-26 16:13:12,830 - motleycrew - INFO - Available tasks: []\n", 466 | "2024-08-26 16:13:12,830 - motleycrew - INFO - Nothing left to do, exiting\n" 467 | ] 468 | }, 469 | { 470 | "name": "stdout", 471 | "output_type": "stream", 472 | "text": [ 473 | "Tests passed.\n" 474 | ] 475 | }, 476 | { 477 | "data": { 478 | "text/plain": [ 479 | "[TaskUnit(status=done)]" 480 | ] 481 | }, 482 | "execution_count": 16, 483 | "metadata": {}, 484 | "output_type": "execute_result" 485 | } 486 | ], 487 | "source": [ 488 | "crew = MotleyCrew()\n", 489 | "\n", 490 | "task = SimpleTask(\n", 491 | " name=\"Add method to Task class\",\n", 492 | " description=prompt,\n", 493 | " crew=crew,\n", 494 | " agent=agent,\n", 495 | ")\n", 496 | "\n", 497 | "crew.run()" 498 | ] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "id": "d3aa3e0e", 503 | "metadata": {}, 504 | "source": [ 505 | "If you look closely at the logs, you'll see the log lines the agent added in the previous example :)\n", 506 | "\n", 507 | "Now let's create a task to write a test for the method the agent just wrote. As a bonus, we can try out our new method at the same time!\n", 508 | "\n", 509 | "The new test will of course also be executed in the output handler. This allows for a very tight feedback loop, where the agent can immediately see if the test fails and fix it." 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": 17, 515 | "id": "1615ad2b", 516 | "metadata": {}, 517 | "outputs": [ 518 | { 519 | "name": "stderr", 520 | "output_type": "stream", 521 | "text": [ 522 | "2024-08-26 16:13:12,915 - motleycrew - INFO - Inserting new node with label TaskNode: name='Add test' done=False\n", 523 | "2024-08-26 16:13:12,916 - motleycrew - INFO - Node created OK\n", 524 | "2024-08-26 16:13:12,919 - motleycrew - INFO - Creating relation task_is_upstream from TaskNode:0 to TaskNode:1\n", 525 | "2024-08-26 16:13:12,923 - motleycrew - INFO - Relation created OK\n" 526 | ] 527 | }, 528 | { 529 | "name": "stdout", 530 | "output_type": "stream", 531 | "text": [ 532 | "[TaskUnit(status=done)]\n" 533 | ] 534 | } 535 | ], 536 | "source": [ 537 | "test_task = SimpleTask(\n", 538 | " name=\"Add test\",\n", 539 | " description=\"Add a test for the 'get_done_upstream_task_units' method in the appropriate place.\",\n", 540 | " crew=crew,\n", 541 | " agent=agent,\n", 542 | ")\n", 543 | "test_task.set_upstream(task)\n", 544 | "\n", 545 | "print(test_task.get_done_upstream_task_units()) # Let's try out the new method!" 546 | ] 547 | }, 548 | { 549 | "cell_type": "markdown", 550 | "id": "8d3ae4f2", 551 | "metadata": {}, 552 | "source": [ 553 | "The new method seems to be working! Now let's kickoff the test writing task." 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": 18, 559 | "id": "f47c7c48", 560 | "metadata": {}, 561 | "outputs": [ 562 | { 563 | "name": "stderr", 564 | "output_type": "stream", 565 | "text": [ 566 | "2024-08-26 16:13:12,943 - motleycrew - INFO - Available tasks: [SimpleTask(name=Add test, done=False)]\n", 567 | "2024-08-26 16:13:12,945 - motleycrew - INFO - Available tasks: [SimpleTask(name=Add test, done=False)]\n", 568 | "2024-08-26 16:13:12,945 - motleycrew - INFO - Processing task: SimpleTask(name=Add test, done=False)\n", 569 | "2024-08-26 16:13:12,950 - motleycrew - INFO - Got a matching unit for task SimpleTask(name=Add test, done=False)\n", 570 | "2024-08-26 16:13:12,950 - motleycrew - INFO - Processing unit: TaskUnit(status=pending)\n", 571 | "2024-08-26 16:13:12,950 - motleycrew - INFO - Assigned unit TaskUnit(status=pending) to agent ReActToolCallingMotleyAgent(name=coder), dispatching\n", 572 | "2024-08-26 16:13:12,950 - motleycrew - INFO - Adding task unit TaskUnit(status=running) to the graph for task SimpleTask(name=Add test, done=False)\n", 573 | "2024-08-26 16:13:12,951 - motleycrew - INFO - Node TaskUnit(status=running) does not exist, creating\n", 574 | "2024-08-26 16:13:12,951 - motleycrew - INFO - Inserting new node with label SimpleTaskUnit: TaskUnit(status=running)\n", 575 | "2024-08-26 16:13:12,951 - motleycrew - WARNING - No known Cypher type matching annotation typing.Optional[typing.Any], will use JSON string\n", 576 | "2024-08-26 16:13:12,951 - motleycrew - WARNING - No known Cypher type matching annotation typing.Optional[dict[str, typing.Any]], will use JSON string\n", 577 | "2024-08-26 16:13:12,953 - motleycrew - INFO - Node created OK\n", 578 | "2024-08-26 16:13:12,958 - motleycrew - INFO - Relation from TaskUnit(status=running) to name='Add test' done=False does not exist, creating\n", 579 | "2024-08-26 16:13:12,960 - motleycrew - INFO - Creating relation SimpleTaskUnit_belongs from SimpleTaskUnit:1 to TaskNode:1\n", 580 | "2024-08-26 16:13:12,962 - motleycrew - INFO - Relation created OK\n", 581 | "2024-08-26 16:13:25,471 - motleycrew - INFO - Trying to edit file tests/test_tasks/test_task.py\n", 582 | "<<<<<<< SEARCH\n", 583 | "class TestTask:\n", 584 | " def test_set_done(self, task_1):\n", 585 | " assert not task_1.done\n", 586 | " assert not task_1.node.done\n", 587 | " task_1.set_done()\n", 588 | " assert task_1.done\n", 589 | " assert task_1.node.done\n", 590 | "=======\n", 591 | "class TestTask:\n", 592 | " def test_set_done(self, task_1):\n", 593 | " assert not task_1.done\n", 594 | " assert not task_1.node.done\n", 595 | " task_1.set_done()\n", 596 | " assert task_1.done\n", 597 | " assert task_1.node.done\n", 598 | "\n", 599 | " def test_get_done_upstream_task_units(self, task_1, task_2):\n", 600 | " # Mock the get_upstream_tasks method to return a list of tasks\n", 601 | " task_1.get_upstream_tasks = lambda: [task_2]\n", 602 | " # Mock the get_units method of task_2 to return a list of done units\n", 603 | " task_2.get_units = lambda status: ['unit1', 'unit2'] if status == 'done' else []\n", 604 | " done_units = task_1.get_done_upstream_task_units()\n", 605 | " assert done_units == ['unit1', 'unit2']\n", 606 | ">>>>>>> REPLACE\n", 607 | "\n", 608 | "2024-08-26 16:13:29,203 - motleycrew - INFO - Add tests/test_tasks/test_task.py to the list of modifiable files? approved\n", 609 | "2024-08-26 16:13:38,051 - motleycrew - WARNING - No known Cypher type matching annotation typing.Optional[typing.Any], will use JSON string\n", 610 | "2024-08-26 16:13:38,055 - motleycrew - INFO - Task unit TaskUnit(status=running) completed, marking as done\n", 611 | "2024-08-26 16:13:38,061 - motleycrew - INFO - Available tasks: []\n", 612 | "2024-08-26 16:13:38,063 - motleycrew - INFO - Available tasks: []\n", 613 | "2024-08-26 16:13:38,063 - motleycrew - INFO - Nothing left to do, exiting\n" 614 | ] 615 | }, 616 | { 617 | "name": "stdout", 618 | "output_type": "stream", 619 | "text": [ 620 | "Tests passed.\n" 621 | ] 622 | }, 623 | { 624 | "data": { 625 | "text/plain": [ 626 | "[TaskUnit(status=done)]" 627 | ] 628 | }, 629 | "execution_count": 18, 630 | "metadata": {}, 631 | "output_type": "execute_result" 632 | } 633 | ], 634 | "source": [ 635 | "crew.run()" 636 | ] 637 | } 638 | ], 639 | "metadata": { 640 | "kernelspec": { 641 | "display_name": "Python 3", 642 | "language": "python", 643 | "name": "python3" 644 | }, 645 | "language_info": { 646 | "codemirror_mode": { 647 | "name": "ipython", 648 | "version": 3 649 | }, 650 | "file_extension": ".py", 651 | "mimetype": "text/x-python", 652 | "name": "python", 653 | "nbconvert_exporter": "python", 654 | "pygments_lexer": "ipython3", 655 | "version": "3.12.1" 656 | } 657 | }, 658 | "nbformat": 4, 659 | "nbformat_minor": 5 660 | } 661 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "motleycoder" 3 | version = "0.1.0" 4 | description = "A set of tools for building AI coders" 5 | authors = ["MotleyCrew "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = ">=3.10,<=3.13" 10 | motleycrew = "^0.1.12" 11 | tree-sitter = "0.21.3" 12 | tree-sitter-languages = "^1.10.2" 13 | grep-ast = "^0.3.3" 14 | litellm = "^1.44.5" 15 | GitPython = "^3.1.43" 16 | rich = "^13.7.1" 17 | diskcache = "^5.6.3" 18 | networkx = "^3.3" 19 | flake8 = "^7.1.1" 20 | 21 | 22 | [build-system] 23 | requires = ["poetry-core"] 24 | build-backend = "poetry.core.masonry.api" 25 | --------------------------------------------------------------------------------