├── .gitignore ├── LICENSE ├── README.md ├── interfaces └── i_main.cairo ├── protostar.toml ├── pyproject.toml ├── src ├── cairo_toolkit │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── generator.cpython-39.pyc │ │ ├── interface_parser.cpython-39.pyc │ │ ├── order_imports.cpython-39.pyc │ │ └── utils.cpython-39.pyc │ ├── generator.py │ ├── interface_parser.py │ ├── order_imports.py │ └── utils.py ├── cli.py └── logic.py └── test ├── main.cairo ├── main_imports_test.cairo ├── nested_test └── main_imports_test_nested.cairo └── types.cairo /.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | build 3 | src/__pycache__ 4 | src/cairo_toolkit/__pycache__ 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Mathieu Saugier 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cairo-toolkit 2 | 3 | A set of useful tools for cairo / starknet development. 4 | 5 | - Generate / check the interfaces corresponding to your Starknet contracts. 6 | - Easily order your imports 7 | 8 | ## Installation 9 | 10 | `pip install cairo-toolkit` 11 | 12 | ## Usage 13 | 14 | ``` 15 | cairo-toolkit [OPTIONS] COMMAND [ARGS]... 16 | 17 | Options: 18 | --version 19 | --help Show this message and exit. 20 | 21 | Commands: 22 | check-interface 23 | generate-interface 24 | order-imports 25 | ``` 26 | 27 | ### Generate interfaces 28 | 29 | ``` 30 | Usage: cairo-toolkit generate-interface [OPTIONS] 31 | 32 | Options: 33 | -f, --files TEXT File paths 34 | -p, --protostar Uses `protostar.toml` to get file paths 35 | -d, --directory TEXT Output directory for the interfaces. If unspecified, 36 | they will be created in the same directory as the 37 | contracts 38 | --help Show this message and exit. 39 | ``` 40 | 41 | ### Check existing interfaces 42 | 43 | ``` 44 | Usage: cairo-toolkit check-interface [OPTIONS] 45 | 46 | Options: 47 | --files TEXT Contracts to check 48 | -p, --protostar Uses `protostar.toml` to get file paths 49 | -d, --directory TEXT Directory of the interfaces to check. Interfaces must 50 | be named `i_.cairo` 51 | --help Show this message and exit. 52 | ``` 53 | 54 | ### Ordering imports in existing file 55 | 56 | ``` 57 | Usage: cairo-toolkit order-imports [OPTIONS] 58 | 59 | Options: 60 | -d, --directory TEXT Directory with cairo files to format 61 | -f, --files TEXT File paths 62 | -i, --imports TEXT Imports order 63 | --help Show this message and exit. 64 | ``` 65 | 66 | ## Example 67 | 68 | Generate interfaces for the contracts in `contracts/` and put them in `interfaces/`: 69 | 70 | ``` 71 | find contracts/ -iname '*.cairo' -exec cairo-toolkit generate-interface --files {} \; 72 | ``` 73 | 74 | Check the interface for `test/main.cairo` against the interface `i_main.cairo` in interfaces/: 75 | 76 | ``` 77 | cairo-toolkit check-interface --files test/main.cairo -d interfaces 78 | ``` 79 | 80 | Order imports for all cairo files under `test` 81 | 82 | ``` 83 | cairo-toolkit order-imports -d test 84 | ``` 85 | 86 | ## Protostar 87 | 88 | You can use cairo-toolkit in a protostar project. 89 | This can be paired with a github action to automatically generate the interfaces for the contracts 90 | that specified inside the `protostar.toml` file. 91 | -------------------------------------------------------------------------------- /interfaces/i_main.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | from test.main import MyStruct 4 | from test.types import ImportedStruct 5 | 6 | @contract_interface 7 | namespace IMain { 8 | func struct_in_arg(amount: felt, _struct: MyStruct, array_len: felt, array: felt*) { 9 | } 10 | 11 | func struct_ptr_in_return() -> ( 12 | res_len: felt, res: felt*, arr_len: felt, arr: ImportedStruct* 13 | ) { 14 | } 15 | 16 | func tuple_in_signature(tuple: (felt, felt)) { 17 | } 18 | 19 | func implicit_type(untyped_arg) { 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /protostar.toml: -------------------------------------------------------------------------------- 1 | ["protostar.config"] 2 | protostar_version = "0.4.2" 3 | 4 | ["protostar.project"] 5 | libs_path = "lib" 6 | 7 | ["protostar.contracts"] 8 | main = [ 9 | "test/main.cairo", 10 | ] 11 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "cairo-toolkit" 3 | version = "1.0.1" 4 | description = "A set of useful tools for cairo / starknet development." 5 | authors = ["msaug "] 6 | license = "MIT" 7 | readme = "README.md" 8 | packages = [{include = "cairo_toolkit", from = "src"},{include="src"}] 9 | 10 | [tool.poetry.dependencies] 11 | python = "^3.9" 12 | cairo-lang = "^0.10.0" 13 | toml = "0.10.2" 14 | click = "8.1.3" 15 | 16 | 17 | [build-system] 18 | requires = ["poetry-core"] 19 | build-backend = "poetry.core.masonry.api" 20 | 21 | [tool.poetry.scripts] 22 | cairo-toolkit = 'src.cli:main' -------------------------------------------------------------------------------- /src/cairo_toolkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enitrat/cairo-toolkit/9a17f94dd5acb4fcf5257ea737cf4ad0d5c0b076/src/cairo_toolkit/__init__.py -------------------------------------------------------------------------------- /src/cairo_toolkit/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enitrat/cairo-toolkit/9a17f94dd5acb4fcf5257ea737cf4ad0d5c0b076/src/cairo_toolkit/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/cairo_toolkit/__pycache__/generator.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enitrat/cairo-toolkit/9a17f94dd5acb4fcf5257ea737cf4ad0d5c0b076/src/cairo_toolkit/__pycache__/generator.cpython-39.pyc -------------------------------------------------------------------------------- /src/cairo_toolkit/__pycache__/interface_parser.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enitrat/cairo-toolkit/9a17f94dd5acb4fcf5257ea737cf4ad0d5c0b076/src/cairo_toolkit/__pycache__/interface_parser.cpython-39.pyc -------------------------------------------------------------------------------- /src/cairo_toolkit/__pycache__/order_imports.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enitrat/cairo-toolkit/9a17f94dd5acb4fcf5257ea737cf4ad0d5c0b076/src/cairo_toolkit/__pycache__/order_imports.cpython-39.pyc -------------------------------------------------------------------------------- /src/cairo_toolkit/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enitrat/cairo-toolkit/9a17f94dd5acb4fcf5257ea737cf4ad0d5c0b076/src/cairo_toolkit/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /src/cairo_toolkit/generator.py: -------------------------------------------------------------------------------- 1 | from starkware.cairo.lang.compiler.ast.code_elements import CodeElementFunction, CodeBlock, CodeElementImport 2 | from starkware.cairo.lang.compiler.ast.visitor import Visitor 3 | 4 | from cairo_toolkit.utils import to_camel_case 5 | 6 | 7 | class Generator(Visitor): 8 | """ 9 | Generates an interface from a Cairo contract. 10 | """ 11 | 12 | def __init__(self, contract_dir: str, contract_name: str): 13 | super().__init__() 14 | self.contract_dir = contract_dir 15 | self.contract_name = contract_name 16 | self.imports = {} 17 | self.required_import_paths = [] 18 | self.functions = "" 19 | 20 | def generate_contract_interface(self, module): 21 | self.visit(module) 22 | interface = \ 23 | "%lang starknet\n\n" + \ 24 | "\n".join(self.required_import_paths) + "\n\n" \ 25 | "@contract_interface\n" + \ 26 | f"namespace I{to_camel_case(self.contract_name)}{{\n" \ 27 | f"{self.functions} \n" \ 28 | "}" 29 | return interface 30 | 31 | def parse_imports(self, elm: CodeBlock): 32 | # We want to keep track of the file imports so that we can import types inside the contract interface 33 | for x in elm.code_elements: 34 | if isinstance(x.code_elm, CodeElementImport): 35 | path = x.code_elm.path.name 36 | imported_items = x.code_elm.import_items 37 | for item in imported_items: 38 | self.imports[item.orig_identifier.name] = path 39 | 40 | def parse_function_signature(self, elm: CodeElementFunction): 41 | # We only visit proper functions decorated with 'external' or 'view'. 42 | need_instrumentation = any(decorator.name in [ 43 | "external", "view"] for decorator in elm.decorators) 44 | 45 | if not need_instrumentation: 46 | return 47 | 48 | # func name 49 | fn_signature = f"func {elm.name}(" 50 | 51 | # func arguments 52 | for i, arg in enumerate(elm.arguments.identifiers): 53 | arg_type = '' 54 | if (arg.expr_type == None): 55 | arg_type = 'felt' 56 | else: 57 | arg_type = arg.expr_type.format().replace('*', '') 58 | 59 | self.parse_type(arg_type) 60 | 61 | fn_signature += f"{arg.format()}" 62 | if i != len(elm.arguments.identifiers) - 1: 63 | fn_signature += "," 64 | fn_signature += ")" 65 | 66 | # func return values 67 | if elm.returns != None: 68 | fn_signature += " -> " 69 | fn_signature += elm.returns.format() 70 | 71 | # non-felt return types need to be imported 72 | return_elems = elm.returns.get_children() 73 | for elem in return_elems: 74 | type = elem.typ.format().replace('*', '') 75 | if type != 'felt': 76 | self.add_import_path(type) 77 | 78 | fn_signature += "{\n}\n\n" 79 | 80 | self.functions += fn_signature 81 | 82 | def parse_type(self, type: str): 83 | type = type.replace('*', '') 84 | if ('(' in type): 85 | type = type.replace('(', '') 86 | type = type.replace(')', '') 87 | tuple = type.split(',') 88 | for tuple_type in tuple: 89 | stripped = tuple_type.strip() 90 | self.parse_type(stripped) 91 | # non-felt types need to be imported 92 | elif type != 'felt': 93 | self.add_import_path(type) 94 | 95 | def add_import_path(self, arg_type: str): 96 | # If we have imported types, we need to add the import path to our interface 97 | # If we use namespace, we want to import the namespace and not the type itself. 98 | # if the type comes from a namespace, we only import the namespace 99 | import_name = arg_type.split('.')[0] 100 | import_path = self.imports.get( 101 | import_name) or f"{self.contract_dir.replace('/', '.')}.{self.contract_name}" # this is a bad practice, when the type is directly declared in the contract. 102 | import_statement = f"from {import_path} import {import_name}" 103 | if import_statement in self.required_import_paths: 104 | return 105 | self.required_import_paths.append(import_statement) 106 | 107 | def _visit_default(self, obj): 108 | # top-level code is not generated 109 | return obj 110 | 111 | def visit_CodeElementFunction(self, elm: CodeElementFunction): 112 | self.parse_function_signature(elm) 113 | return super().visit_CodeElementFunction(elm) 114 | 115 | def visit_CodeBlock(self, elm: CodeBlock): 116 | self.parse_imports(elm) 117 | return super().visit_CodeBlock(elm) 118 | -------------------------------------------------------------------------------- /src/cairo_toolkit/interface_parser.py: -------------------------------------------------------------------------------- 1 | from starkware.cairo.lang.compiler.ast.code_elements import CodeElementFunction, CodeBlock, CodeElementImport 2 | from starkware.cairo.lang.compiler.ast.visitor import Visitor 3 | 4 | 5 | class InterfaceParser(Visitor): 6 | """ 7 | Parses a Cairo interface. Call `parse_interface` to return the interface as a dictionary. 8 | """ 9 | 10 | def __init__(self, contract_name: str): 11 | super().__init__() 12 | self.contract_name = contract_name 13 | self.imports = {} # map import_name => import_path 14 | # list of required import statements for the interface 15 | self.required_import_paths = [] 16 | self.functions = {} # TODO should be a list 17 | self.namespace_name = '' 18 | 19 | def parse_interface(self, module): 20 | self.visit(module) 21 | return { 22 | 'name': self.namespace_name, 23 | 'functions': self.functions, 24 | 'imports': self.imports 25 | } 26 | 27 | def parse_imports(self, elm: CodeBlock): 28 | # We want to keep track of the file imports so that we can import types inside the contract interface 29 | for x in elm.code_elements: 30 | if isinstance(x.code_elm, CodeElementImport): 31 | path = x.code_elm.path.name 32 | imported_items = x.code_elm.import_items 33 | for item in imported_items: 34 | self.imports[item.orig_identifier.name] = path 35 | 36 | def parse_code_elm_function(self, elm: CodeElementFunction): 37 | 38 | if elm.element_type == 'namespace': 39 | self.parse_namespace(elm) 40 | return 41 | 42 | # We only visit proper functions decorated with 'external' or 'view'. 43 | self.parse_function(elm) 44 | return 45 | 46 | def parse_namespace(self, elm): 47 | namespace_name = elm.name 48 | self.namespace_name = namespace_name 49 | 50 | def parse_function(self, elm: CodeElementFunction): 51 | 52 | fn_params = [] 53 | fn_returns = [] 54 | # func arguments 55 | for arg in elm.arguments.identifiers: 56 | fn_params.append(arg.format()) 57 | 58 | # func return values 59 | if elm.returns != None: 60 | return_elems = elm.returns.get_children() 61 | for elem in return_elems: 62 | str = f"{elem.name}: {elem.typ.format()}" 63 | fn_returns.append(str) 64 | 65 | self.functions[elm.name] = {'params': fn_params, 'returns': fn_returns} 66 | 67 | def _visit_default(self, obj): 68 | # top-level code is not generated 69 | return obj 70 | 71 | def visit_CodeElementFunction(self, elm: CodeElementFunction): 72 | self.parse_code_elm_function(elm) 73 | return super().visit_CodeElementFunction(elm) 74 | 75 | def visit_CodeBlock(self, elm: CodeBlock): 76 | self.parse_imports(elm) 77 | return super().visit_CodeBlock(elm) 78 | -------------------------------------------------------------------------------- /src/cairo_toolkit/order_imports.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import allow_connection_pickling 2 | from symbol import import_name 3 | from typing import List, Dict 4 | from collections import OrderedDict 5 | from numpy import isin 6 | from starkware.cairo.lang.compiler.ast.code_elements import CodeBlock, CodeElementImport, CodeElementEmptyLine, CommentedCodeElement, CodeElementDirective 7 | from starkware.cairo.lang.compiler.ast.visitor import Visitor 8 | 9 | class OrderImports(Visitor): 10 | """ 11 | Orders imports in Cairo files 12 | """ 13 | 14 | def __init__(self, import_order_names: List[str]): 15 | super().__init__() 16 | self.import_order_names=import_order_names 17 | 18 | def _visit_default(self, obj): 19 | # top-level code is not generated 20 | return obj 21 | 22 | def visit_CodeBlock(self, elm: CodeBlock): 23 | return self.extract_imports(elm) 24 | 25 | def extract_imports(self, elm): 26 | code_elements = elm.code_elements 27 | all_imports: Dict[str, List] = OrderedDict() 28 | all_imports = {x: [] for x in self.import_order_names} 29 | first_occurance_of_import = -1 30 | for i, x in enumerate(code_elements): 31 | if isinstance(x.code_elm, CodeElementImport): 32 | if first_occurance_of_import == -1: 33 | first_occurance_of_import = i 34 | # order the import_items 35 | x.code_elm.import_items.sort(key=lambda x: x.orig_identifier.name) 36 | import_first_word = x.code_elm.path.name.split(".")[0] 37 | # group additional elements if not specified in initial list 38 | if import_first_word not in self.import_order_names: 39 | self.import_order_names.append(import_first_word) 40 | all_imports[import_first_word] = [] 41 | for import_order_name in self.import_order_names: 42 | if (import_order_name == x.code_elm.path.name.split(".")[0]): 43 | all_imports[import_order_name].append(x) 44 | break 45 | code_elements = list(filter(lambda x: not(isinstance(x.code_elm, CodeElementImport)), code_elements)) 46 | all_imports = {x: all_imports[x] for x in self.import_order_names} 47 | 48 | ordered_imports = [] 49 | for _, v in all_imports.items(): 50 | v.sort(key=lambda x: x.code_elm.path.name) 51 | ordered_imports += [self.get_empty_element()] + v 52 | elm.code_elements = code_elements[:first_occurance_of_import] + ordered_imports + code_elements[first_occurance_of_import:] 53 | 54 | def get_empty_element(self): 55 | return CommentedCodeElement(code_elm=CodeElementEmptyLine(), comment=None, location=None) 56 | 57 | 58 | def create_ordered_imports(self, cairo_module): 59 | res = self.visit(cairo_module) 60 | return res 61 | -------------------------------------------------------------------------------- /src/cairo_toolkit/utils.py: -------------------------------------------------------------------------------- 1 | 2 | def to_camel_case(string: str) -> str: 3 | return "".join(word.title() for word in string.split("_")) 4 | -------------------------------------------------------------------------------- /src/cli.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import click 4 | import traceback 5 | 6 | from typing import List 7 | from starkware.cairo.lang.version import __version__ 8 | from src.logic import check_files, generate_interfaces, get_contracts_from_protostar, print_version, generate_ordered_imports 9 | 10 | 11 | @click.group() 12 | @click.option('--version', is_flag=True, callback=print_version, 13 | expose_value=False, is_eager=True) 14 | def cli(): 15 | pass 16 | 17 | 18 | @click.command() 19 | @click.option("--files", '-f', multiple=True, default=[], help="File paths") 20 | @click.option('--protostar', '-p', is_flag=True, help='Uses `protostar.toml` to get file paths') 21 | @click.option('--directory', '-d', help='Output directory for the interfaces. If unspecified, they will be created in the same directory as the contracts') 22 | def generate_interface(protostar: bool, directory: str, files: List[str]): 23 | if protostar: 24 | protostar_path = os.path.join(os.getcwd(), "protostar.toml") 25 | files = get_contracts_from_protostar(protostar_path) 26 | 27 | sys.exit(generate_interfaces(directory, files)) 28 | 29 | 30 | @click.command() 31 | @click.option("--files", multiple=True, default=[], help="Contracts to check") 32 | @click.option('--protostar', '-p', is_flag=True, help='Uses `protostar.toml` to get file paths') 33 | @click.option('--directory', '-d', help='Directory of the interfaces to check. Interfaces must be named `i_.cairo`') 34 | def check_interface(protostar: bool, directory: str, files: List[str]): 35 | if protostar: 36 | protostar_path = os.path.join(os.getcwd(), "protostar.toml") 37 | files = get_contracts_from_protostar(protostar_path) 38 | sys.exit(check_files(directory, files)) 39 | 40 | # this command may be run with: 41 | # python src/cli.py order-imports -f test/main_imports_test.cairo -i starkware -i openzeppelin 42 | # python src/cli.py order-imports -d test/ -i starkware -i openzeppelin 43 | 44 | 45 | @click.command() 46 | @click.option('--directory', '-d', help="Directory with cairo files to format") 47 | @click.option("--files", '-f', multiple=True, default=[], help="File paths") 48 | @click.option("--imports", '-i', multiple=True, default=["starkware", "openzeppelin"], help="Imports order") 49 | def order_imports(directory: str, files: List[str], imports: List[str]): 50 | files_to_order = [] 51 | if directory: 52 | path = os.path.join(os.getcwd(), directory) 53 | for (root, _, cairo_files) in os.walk(path, topdown=True): 54 | for f in cairo_files: 55 | files_to_order.append(os.path.join(root, f)) 56 | else: 57 | files_to_order = files 58 | 59 | sys.exit(generate_ordered_imports(files_to_order, imports)) 60 | 61 | 62 | cli.add_command(generate_interface) 63 | cli.add_command(check_interface) 64 | cli.add_command(order_imports) 65 | 66 | 67 | def main(): 68 | cli() 69 | 70 | 71 | if __name__ == "__main__": 72 | sys.exit(main()) 73 | -------------------------------------------------------------------------------- /src/logic.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from typing import Dict, List 3 | import click 4 | import toml 5 | import os 6 | import subprocess 7 | from starkware.cairo.lang.compiler.ast.module import CairoModule 8 | from starkware.cairo.lang.compiler.parser import parse_file 9 | 10 | from cairo_toolkit.generator import Generator 11 | from cairo_toolkit.interface_parser import InterfaceParser 12 | from cairo_toolkit.order_imports import OrderImports 13 | 14 | def cairo_parser(code, filename): return parse_file( 15 | code=code, filename=filename) 16 | 17 | 18 | def print_version(ctx, param, value): 19 | if not value or ctx.resilient_parsing: 20 | return 21 | click.echo('Version 0.1.5') 22 | ctx.exit() 23 | 24 | 25 | def generate_interfaces(directory: str, files: List[str]): 26 | for path in files: 27 | contract_file = open(path).read() 28 | dirpath, filename = os.path.split(path) 29 | contract_name = filename.split(".")[0] 30 | newfilename = f"i_" + filename 31 | newpath = os.path.join(directory or dirpath, newfilename) 32 | 33 | try: 34 | 35 | # Generate the AST of the cairo contract, visit it and generate the interface 36 | contract = CairoModule( 37 | cairo_file=cairo_parser(contract_file, filename), 38 | module_name=path, 39 | ) 40 | 41 | generator = Generator(dirpath, contract_name) 42 | contract_interface_str = generator.generate_contract_interface( 43 | contract) 44 | 45 | # Generate the AST from the cairo interface, format it, and write it to a file 46 | contract_interface = CairoModule( 47 | cairo_file=cairo_parser(contract_interface_str, newfilename), 48 | module_name=path, 49 | ) 50 | formatted_interface = contract_interface.format() 51 | 52 | except Exception as exc: 53 | print(traceback.format_exc()) 54 | return 1 55 | 56 | print(f"Generating interface {newpath}") 57 | open(newpath, "w").write(formatted_interface) 58 | return 0 59 | 60 | def generate_ordered_imports(files: List[str], imports: List[str]): 61 | for path in files: 62 | contract_file = open(path).read() 63 | _, filename = os.path.split(path) 64 | 65 | try: 66 | # Generate the AST of the cairo contract, visit it and generate the interface 67 | contract = CairoModule( 68 | cairo_file=cairo_parser(contract_file, filename), 69 | module_name=path, 70 | ) 71 | OrderImports([*imports]).create_ordered_imports(contract) 72 | contract = contract.format() 73 | 74 | except Exception as exc: 75 | print(traceback.format_exc()) 76 | return 1 77 | 78 | open(path, "w").write(contract) 79 | return 0 80 | 81 | def check_files(directory, files): 82 | errors = [] 83 | for path in files: 84 | contract_file = open(path).read() 85 | dirpath, filename = os.path.split(path) 86 | contract_name = filename.split(".")[0] 87 | interface_name = f"i_" + filename 88 | interface_path = f"{directory}/{interface_name}" 89 | tempfile_name = interface_name + ".tmp" 90 | try: 91 | interface_file = open(interface_path).read() 92 | except: 93 | print( 94 | f"Couldn't open corresponding interface file for {interface_path}") 95 | continue 96 | 97 | newpath = os.path.join(directory or dirpath, tempfile_name) 98 | 99 | try: 100 | 101 | # Generate the AST of the cairo contract, visit it and generate the interface 102 | contract = CairoModule( 103 | cairo_file=cairo_parser(contract_file, filename), 104 | module_name=path, 105 | ) 106 | 107 | generator = Generator(dirpath, contract_name) 108 | contract_interface_str = generator.generate_contract_interface( 109 | contract) 110 | 111 | # Generate the AST from the cairo interface, format it, and write it to a file 112 | contract_interface = CairoModule( 113 | cairo_file=cairo_parser( 114 | contract_interface_str, interface_name), 115 | module_name=path, 116 | ) 117 | parsed_generated_interface = InterfaceParser( 118 | contract_name).parse_interface(contract_interface) 119 | 120 | existing_interface = CairoModule( 121 | cairo_file=cairo_parser(interface_file, tempfile_name), 122 | module_name=path, 123 | ) 124 | parsed_existing_interface = InterfaceParser( 125 | contract_name).parse_interface(existing_interface) 126 | 127 | def check_name(generated: Dict, existing: Dict): 128 | if generated['name'] != existing['name']: 129 | errors.append( 130 | f"Name mismatch between contract and interface for {contract_name}") 131 | 132 | def check_functions(source: Dict, comparison: Dict, source_is_correct): 133 | error_detail = "is missing from the interface" if source_is_correct else "is not in the contract" 134 | for func_name in source: 135 | if func_name not in comparison: 136 | errors.append( 137 | f"Function <{func_name}> {error_detail} for {contract_name}") 138 | continue 139 | source_params = source[func_name]['params'] 140 | for source_param in source_params: 141 | if source_param not in comparison[func_name]['params']: 142 | errors.append( 143 | f"Parameter <{source_param}> {error_detail} for {contract_name}:{func_name}") 144 | continue 145 | source_returns = source[func_name]['returns'] 146 | for source_return in source_returns: 147 | if source_return not in comparison[func_name]['returns']: 148 | errors.append( 149 | f"Return <{source_return}> {error_detail} for {contract_name}:{func_name}") 150 | continue 151 | 152 | def check_imports(source: Dict, comparison: Dict, source_is_correct): 153 | error_detail = "is missing from the interface" if source_is_correct else "is not in the contract" 154 | for import_name in source: 155 | if import_name not in comparison: 156 | errors.append( 157 | f"Import <{import_name}> {error_detail} for {contract_name}") 158 | continue 159 | source_path = source[import_name] 160 | if source_path not in comparison[import_name]: 161 | errors.append( 162 | f"Import path <{source_path}> {error_detail} for {contract_name}:{import_name}") 163 | continue 164 | 165 | check_name(parsed_generated_interface, parsed_existing_interface) 166 | 167 | # Check if the existing interface has missing elements 168 | check_functions(parsed_generated_interface['functions'], 169 | parsed_existing_interface['functions'], True) 170 | # Check if the existing interface has extra elements 171 | check_functions( 172 | parsed_existing_interface['functions'], parsed_generated_interface['functions'], False) 173 | 174 | check_imports( 175 | parsed_generated_interface['imports'], parsed_existing_interface['imports'], True) 176 | check_imports( 177 | parsed_existing_interface['imports'], parsed_generated_interface['imports'], False) 178 | 179 | except Exception as exc: 180 | print(traceback.format_exc()) 181 | return 1 182 | 183 | print('\n'.join(str(x) for x in errors)) 184 | try: 185 | assert len(errors) == 0 186 | except: 187 | return 1 188 | return 0 189 | 190 | 191 | def get_contracts_from_protostar(protostar_path: str): 192 | config = toml.load(protostar_path) 193 | contracts = config['protostar.contracts'] 194 | contracts_paths = [contract[0] for contract in contracts.values()] 195 | return contracts_paths 196 | -------------------------------------------------------------------------------- /test/main.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | from starkware.cairo.common.cairo_builtins import HashBuiltin 4 | 5 | from test.types import ImportedStruct 6 | 7 | struct MyStruct { 8 | index: felt, 9 | } 10 | 11 | @storage_var 12 | func storage_skipped() -> (res: felt) { 13 | } 14 | 15 | @external 16 | func struct_in_arg{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 17 | amount: felt, _struct: MyStruct, array_len: felt, array: felt* 18 | ) { 19 | return (); 20 | } 21 | 22 | @view 23 | func struct_ptr_in_return{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( 24 | res_len: felt, res: felt*, arr_len: felt, arr: ImportedStruct* 25 | ) { 26 | return (1, new (1), 1, new ImportedStruct(0)); 27 | } 28 | 29 | @view 30 | func tuple_in_signature{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 31 | tuple: (felt, felt) 32 | ) { 33 | return (); 34 | } 35 | 36 | @view 37 | func implicit_type{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(untyped_arg) { 38 | return (); 39 | } 40 | 41 | @constructor 42 | func constructor{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() { 43 | return (); 44 | } 45 | -------------------------------------------------------------------------------- /test/main_imports_test.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | from starkware.cairo.common.alloc import alloc 4 | from starkware.cairo.common.cairo_builtins import HashBuiltin 5 | from starkware.cairo.common.cairo_builtins import BitwiseBuiltin 6 | from starkware.cairo.common.math import assert_nn_le, unsigned_div_rem 7 | from starkware.cairo.common.math_cmp import is_le, is_nn 8 | from starkware.cairo.common.memcpy import memcpy 9 | from starkware.cairo.common.memset import memset 10 | from starkware.cairo.common.pow import pow 11 | from starkware.cairo.common.registers import get_fp_and_pc 12 | 13 | from test.types import ImportedStruct 14 | 15 | struct MyStruct { 16 | index: felt, 17 | } 18 | 19 | @storage_var 20 | func storage_skipped() -> (res: felt) { 21 | } 22 | 23 | @external 24 | func struct_in_arg{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 25 | amount: felt, _struct: MyStruct, array_len: felt, array: felt* 26 | ) { 27 | return (); 28 | } 29 | 30 | @view 31 | func struct_ptr_in_return{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( 32 | res_len: felt, res: felt*, arr_len: felt, arr: ImportedStruct* 33 | ) { 34 | return (1, new (1), 1, new ImportedStruct(0)); 35 | } 36 | 37 | @constructor 38 | func constructor{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() { 39 | return (); 40 | } 41 | -------------------------------------------------------------------------------- /test/nested_test/main_imports_test_nested.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | from starkware.cairo.common.alloc import alloc 4 | from starkware.cairo.common.cairo_builtins import HashBuiltin 5 | from starkware.cairo.common.cairo_builtins import BitwiseBuiltin 6 | from starkware.cairo.common.math import assert_nn_le, unsigned_div_rem 7 | from starkware.cairo.common.math_cmp import is_le 8 | from starkware.cairo.common.memcpy import memcpy 9 | from starkware.cairo.common.memset import memset 10 | from starkware.cairo.common.pow import pow 11 | from starkware.cairo.common.registers import get_fp_and_pc 12 | 13 | from test.types import ImportedStruct 14 | 15 | struct MyStruct { 16 | index: felt, 17 | } 18 | 19 | @storage_var 20 | func storage_skipped() -> (res: felt) { 21 | } 22 | 23 | @external 24 | func struct_in_arg{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 25 | amount: felt, _struct: MyStruct, array_len: felt, array: felt* 26 | ) { 27 | return (); 28 | } 29 | 30 | @view 31 | func struct_ptr_in_return{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( 32 | res_len: felt, res: felt*, arr_len: felt, arr: ImportedStruct* 33 | ) { 34 | return (1, new (1), 1, new ImportedStruct(0)); 35 | } 36 | 37 | @constructor 38 | func constructor{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() { 39 | return (); 40 | } 41 | -------------------------------------------------------------------------------- /test/types.cairo: -------------------------------------------------------------------------------- 1 | struct ImportedStruct { 2 | index: felt, 3 | } 4 | --------------------------------------------------------------------------------