├── .gitattributes ├── .gitignore ├── .gitmodules ├── README.md ├── environment.yml ├── examples └── under_over.cairo ├── shell.nix ├── smc ├── __init__.py ├── interfaces │ ├── module_introspection.cairo │ └── module_registry.cairo ├── libraries │ └── module_registry.cairo ├── main.cairo ├── modules │ ├── module_introspection.cairo │ └── module_registry.cairo └── testing │ ├── __init__.py │ └── modular_contract.py └── tests ├── conftest.py └── test_smc.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.cairo linguist-language=python -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.mamba 2 | /.pytest_cache 3 | __pycache__ 4 | /artifacts -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "vendor/cairo-contracts"] 2 | path = vendor/cairo-contracts 3 | url = https://github.com/OpenZeppelin/cairo-contracts.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StarkNet Modular Contracts (SMC) Standard 2 | 3 | This repository presents a way to build and deploy modular StarkNet contracts. 4 | It is heavily inspired by the [Diamond Standard](https://eips.ethereum.org/EIPS/eip-2535), but uses a different naming convention to avoid confusion. 5 | 6 | ## Running 7 | 8 | You can use Nile to compile the smart contracts: 9 | 10 | ```bash 11 | nile compile 12 | ``` 13 | 14 | You can use pytest to run the tests (from the repository root): 15 | 16 | ```bash 17 | PYTHONPATH=. pytest tests 18 | ``` 19 | 20 | ## Modular Contracts 21 | 22 | The idea is to deploy a `main` contract and then add modules to it 23 | to increase its functionality.
Modules can be replaced to perform an 24 | upgrade. The `main` contract automatically registers the `ModuleRegistry` 25 | module on deployment. You can make a contract immutable by removing the 26 | `ModuleRegistry`. 27 | 28 | ## Built-in modules 29 | 30 | * `ModuleRegistry`: provides a function (`changeModules`) to add, replace, and remove module functions. 31 | * `ModuleIntrospection`: provides functions to inspect registered modules. 32 | 33 | 34 | ## Authoring modules 35 | 36 | New modules are best implemented using the [extensibility 37 | pattern](https://github.com/OpenZeppelin/cairo-contracts/blob/main/docs/Extensibility.md) 38 | proposed by OpenZeppelin. 39 | 40 | Modules should never import functions from other modules; they should instead 41 | import functions from libraries. Importing from modules results in accidentally 42 | exporting the module's external functions. 43 | 44 | Modules are like contracts, but they don't have a 45 | constructor; initialization is provided by defining an `@external` `initializer` 46 | function. `changeModules` runs it (via `delegate_call`, so it writes to the `main` contract's storage) when it is given the module's address and non-empty calldata.

--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: diamond-standard
channels:
  - conda-forge
dependencies:
  - pip
  - python=3.8
  - pip:
    - tox
    - cairo-lang
    - cairo-nile
--------------------------------------------------------------------------------
/examples/under_over.cairo:
--------------------------------------------------------------------------------
%lang starknet

# Example module: keeps a reference number and reports whether an input is at
# or above it. Used by tests/test_smc.py to exercise adding/removing modules.

from starkware.cairo.common.cairo_builtins import HashBuiltin
from starkware.cairo.common.math_cmp import is_le

from openzeppelin.access.ownable import Ownable_only_owner

# Reference number compared against by `underOver`. `main` reaches this module
# through delegate_call, so the value lives in the calling `main` contract's
# storage (every `main` has its own copy).
@storage_var
func _reference() -> (reference : felt):
end

# Sets the initial reference number.
# Not owner-guarded: the module registry delegate-calls it with the
# `changeModules` calldata when the module is added. smc.testing never
# registers `initializer` as a routed selector.
@external
func initializer{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(
        reference : felt) -> ():
    _reference.write(reference)
    return ()
end

# Owner-only setter for the reference number.
@external
func setReference{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(
        reference : felt) -> ():
    Ownable_only_owner()
    _reference.write(reference)
    return ()
end

# Returns the current reference number (0 until it is set).
@view
func getReference{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}() -> (
        reference : felt):
    let (reference) = _reference.read()
    return (reference=reference)
end

# Returns is_over = 1 when `n` >= reference (is_le(reference, n)), else 0.
# NOTE(review): is_le only behaves as an integer comparison for values in the
# range-checked interval; confirm callers never pass "negative" felts.
@view
func underOver{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(n : felt) -> (
        is_over : felt):
    alloc_locals
    let (reference) = _reference.read()

    # is_le does not thread syscall_ptr; keep it in a local so the reference
    # returned by the storage read is still usable at `return`.
    local syscall_ptr : felt* = syscall_ptr

    let (is_over) = is_le(reference, n)

    return (is_over=is_over)
end

--------------------------------------------------------------------------------
/shell.nix:
--------------------------------------------------------------------------------
with import <nixpkgs> {};
let
  fhs = pkgs.buildFHSUserEnv {
    name = "starknet-diamond-standard";

    targetPkgs = _: [
      pkgs.bash
      pkgs.bashInteractive
      pkgs.gcc
      pkgs.micromamba
    ];

    multiPkgs = _: [
      pkgs.gmpxx.dev
    ];

    profile = ''
      set -e
      eval "$(micromamba shell hook -s bash)"
      export MAMBA_ROOT_PREFIX=${builtins.getEnv "PWD"}/.mamba
      micromamba create -q -f environment.yml --yes -c conda-forge
      micromamba activate diamond-standard
      set +e
    '';
  };
in fhs.env
--------------------------------------------------------------------------------
/smc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apibara/starknet-modular-contracts-standard/b4bb0ea7d268c79fdc4080ee687997b4d5c07e7f/smc/__init__.py
--------------------------------------------------------------------------------
/smc/interfaces/module_introspection.cairo:
--------------------------------------------------------------------------------
%lang starknet

# Interface of the ModuleIntrospection module
# (smc/modules/module_introspection.cairo).
@contract_interface
namespace IModuleIntrospection:
    # Returns every registered selector routed to `module_address`.
    func moduleFunctionSelectors(module_address : felt) -> (
            selectors_len : felt, selectors : felt*):
    end

    # Returns each distinct module address that handles at least one selector.
    func moduleAddresses() -> (module_addresses_len : felt, module_addresses : felt*):
    end

    # Returns the module handling `selector`, or 0 if it is not registered.
    # Must be spelled `moduleAddress` to match the @view in the module; the
    # previous `moduleAddess` produced a selector no module implements.
    func moduleAddress(selector : felt) -> (module_address : felt):
    end
end

--------------------------------------------------------------------------------
/smc/interfaces/module_registry.cairo:
--------------------------------------------------------------------------------
%lang starknet

# Action codes carried in `ModuleFunctionAction.action`.
const MODULE_FUNCTION_ADD = 0
const MODULE_FUNCTION_REPLACE = 1
const MODULE_FUNCTION_REMOVE = 2

# One routing change: apply `action` to `selector` using `module_address`.
struct ModuleFunctionAction:
    member module_address : felt
    member action : felt
    member selector : felt
end

@contract_interface
namespace IModuleRegistry:
    # Applies `actions` in order. If `address` is non-zero, its `initializer`
    # is then delegate-called with `calldata`.
    func changeModules(
            actions_len : felt, actions : ModuleFunctionAction*, address : felt,
            calldata_len : felt, calldata : felt*):
    end
end

# Emitted once per successful `changeModules`, echoing its arguments.
@event
func ModuleFunctionChange(
        actions_len : felt, actions : ModuleFunctionAction*, address : felt, calldata_len : felt,
        calldata : felt*):
end

--------------------------------------------------------------------------------
/smc/libraries/module_registry.cairo:
--------------------------------------------------------------------------------
%lang starknet

from starkware.cairo.common.cairo_builtins import HashBuiltin
from starkware.cairo.common.math import assert_not_zero, assert_not_equal
from starkware.starknet.common.syscalls import delegate_call

from smc.interfaces.module_registry import (
    ModuleFunctionAction, ModuleFunctionChange, MODULE_FUNCTION_ADD, MODULE_FUNCTION_REPLACE,
    MODULE_FUNCTION_REMOVE)

# Map function selectors to the modules that execute the function.
@storage_var
func _module_registry_modules(selector : felt) -> (module_address : felt):
end

# List of all selectors
@storage_var
func _module_registry_selectors(index : felt) -> (selector : felt):
end

# Length of _module_registry_selectors
@storage_var
func _module_registry_selectors_len() -> (len : felt):
end

# get_selector_from_name('changeModules')
const CHANGE_MODULES_SELECTOR = 1808683055422503325942160754016371337440997851558534157930265361990569747463

# get_selector_from_name('initializer')
const INITIALIZER_SELECTOR = 1295919550572838631247819983596733806859788957403169325509326258146877103642

# ---------------------------------------------------------------------------- #
# Manage Module Functions
# ---------------------------------------------------------------------------- #

# Applies `actions` to the registry, optionally runs an initializer, then emits
# ModuleFunctionChange with the unchanged arguments.
#
# If `address` is non-zero, `initializer` is delegate-called on it with
# `calldata` after all actions are applied. `calldata_len` must then be
# non-zero, and the initializer must return no data.
func module_registry_change_modules{
        syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(
        actions_len : felt, actions : ModuleFunctionAction*, address : felt, calldata_len : felt,
        calldata : felt*):
    alloc_locals

    _module_registry_change_modules(actions_len, actions, address, calldata_len, calldata)

    # `emit` does not thread pedersen_ptr; keep it in a local so the reference
    # survives that call and is available at `return`.
    local pedersen_ptr : HashBuiltin* = pedersen_ptr

    ModuleFunctionChange.emit(actions_len, actions, address, calldata_len, calldata)

    return ()
end

# Returns the module registered for `selector`, or 0 if there is none.
func module_registry_get_module_address{
        syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(selector : felt) -> (
        address : felt):
    let (address) = _module_registry_modules.read(selector)
    return (address=address)
end

# Applies the actions, then runs the optional initializer (see
# module_registry_change_modules). Does not emit the event.
func _module_registry_change_modules{
        syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(
        actions_len : felt, actions : ModuleFunctionAction*, address : felt, calldata_len : felt,
        calldata : felt*):
    _change_modules_loop(actions_len, actions)

    if address != 0:
        with_attr error_message("initializer calldata is empty"):
            assert_not_zero(calldata_len)
        end

        let (retdata_size : felt, retdata : felt*) = delegate_call(
            contract_address=address,
            function_selector=INITIALIZER_SELECTOR,
            calldata_size=calldata_len,
            calldata=calldata)

        with_attr error_message("initializer must not return data"):
            assert retdata_size = 0
        end
        return ()
    else:
        return ()
    end
end

# Applies each ModuleFunctionAction in `actions`, in order.
# Fails with "invalid module action" when an action code is not
# MODULE_FUNCTION_ADD/REPLACE/REMOVE. Such entries used to be skipped silently
# while still being reported in the emitted event.
func _change_modules_loop{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(
        actions_len : felt, actions : ModuleFunctionAction*):
    if actions_len == 0:
        return ()
    end

    let module_action : ModuleFunctionAction = [actions]

    if module_action.action == MODULE_FUNCTION_ADD:
        _add_registry_module(module_action.selector, module_action.module_address)
        return _change_modules_loop(actions_len - 1, actions + ModuleFunctionAction.SIZE)
    end

    if module_action.action == MODULE_FUNCTION_REPLACE:
        _replace_registry_module(module_action.selector, module_action.module_address)
        return _change_modules_loop(actions_len - 1, actions + ModuleFunctionAction.SIZE)
    end

    if module_action.action == MODULE_FUNCTION_REMOVE:
        _remove_registry_module(module_action.selector, module_action.module_address)
        return _change_modules_loop(actions_len - 1, actions + ModuleFunctionAction.SIZE)
    end

    # Only reached for an unrecognized action code (which is != ADD), so this
    # assertion always fails.
    with_attr error_message("invalid module action"):
        assert module_action.action = MODULE_FUNCTION_ADD
    end
    return ()
end

# Registers `selector` to be handled by `module_address`.
# Rejects a zero module address: 0 means "unregistered" in the mapping. With a
# zero address the selector would sit in the list with no module, which
# allows duplicate re-adds and breaks the introspection views.
func _add_registry_module{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(
        selector : felt, module_address : felt):
    # checks: module address is usable
    with_attr error_message("module address is zero"):
        assert_not_zero(module_address)
    end

    # checks: selector is not already registered
    let (existing_module_address) = _module_registry_modules.read(selector)
    with_attr error_message("selector already exists"):
        assert existing_module_address = 0
    end

    # effects: add selector to module list
    _module_registry_modules.write(selector, module_address)

    # effects: update list of selectors
    let (selectors_len) = _module_registry_selectors_len.read()
    _module_registry_selectors_len.write(selectors_len + 1)
    _module_registry_selectors.write(selectors_len, selector)

    return ()
end

# Routes an already-registered `selector` to a new `module_address`.
# The selector list is unchanged. A zero address is rejected because it would
# unregister the selector while leaving it in the list.
func _replace_registry_module{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(
        selector : felt, module_address : felt):
    # checks: new module address is usable
    with_attr error_message("module address is zero"):
        assert_not_zero(module_address)
    end

    # checks: selector is registered
    let (existing_module_address) = _module_registry_modules.read(selector)
    with_attr error_message("selector does not exists"):
        assert_not_zero(existing_module_address)
    end

    # effects: point the selector at the new module
    _module_registry_modules.write(selector, module_address)

    return ()
end

# Unregisters `selector` and drops it from the selector list. The last list
# entry is moved into its slot, so list order is not preserved.
# NOTE(review): `module_address` is not compared with the registered module;
# any registered selector is removed regardless of the address passed.
func _remove_registry_module{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(
        selector : felt, module_address : felt):
    alloc_locals

    # checks: selector is registered
    let (existing_module_address) = _module_registry_modules.read(selector)
    with_attr error_message("selector does not exists"):
        assert_not_zero(existing_module_address)
    end

    # effects: remove selector from module list
    _module_registry_modules.write(selector, 0)

    # effects: update list of selectors
    # fill the hole left by this selector by moving the last selector to it
    let (local selectors_len) = _module_registry_selectors_len.read()

    let (selector_index) = _find_selector_index_loop(selector, selectors_len, 0)
    # checks: selector found
    with_attr error_message("selector not in selector list"):
        assert_not_equal(selector_index, selectors_len)
    end

    # notice: the selector was found, so selectors_len >= 1
    let (last_selector) = _module_registry_selectors.read(selectors_len - 1)

    _module_registry_selectors.write(selector_index, last_selector)
    _module_registry_selectors_len.write(selectors_len - 1)

    return ()
end

# Returns the index of `selector` in the selector list, scanning from
# `current_index`, or `selectors_len` if it is not present.
func _find_selector_index_loop{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(
        selector : felt, selectors_len : felt, current_index : felt) -> (index : felt):
    if current_index == selectors_len:
        return (index=selectors_len)
    end

    let (current_selector) = _module_registry_selectors.read(current_index)

    if current_selector == selector:
        return (index=current_index)
    else:
        return _find_selector_index_loop(selector, selectors_len, current_index + 1)
    end
end

--------------------------------------------------------------------------------
/smc/main.cairo:
--------------------------------------------------------------------------------
%lang starknet

from starkware.cairo.common.cairo_builtins import HashBuiltin
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.math import assert_not_zero
from starkware.starknet.common.syscalls import delegate_call

from openzeppelin.access.ownable import Ownable_initializer

from smc.interfaces.module_registry import ModuleFunctionAction, MODULE_FUNCTION_ADD
from smc.libraries.module_registry import (
    module_registry_change_modules,
module_registry_get_module_address, 13 | CHANGE_MODULES_SELECTOR) 14 | 15 | @constructor 16 | func constructor{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}( 17 | owner : felt, module_registry_address : felt): 18 | alloc_locals 19 | 20 | # effects: set owner of this contract 21 | Ownable_initializer(owner) 22 | 23 | local module_action : ModuleFunctionAction 24 | module_action.action = MODULE_FUNCTION_ADD 25 | module_action.selector = CHANGE_MODULES_SELECTOR 26 | module_action.module_address = module_registry_address 27 | 28 | let (actions : ModuleFunctionAction*) = alloc() 29 | assert [actions] = module_action 30 | 31 | let (calldata : felt*) = alloc() 32 | 33 | # effects: initialize registry 34 | module_registry_change_modules(1, actions, 0, 0, calldata) 35 | 36 | return () 37 | end 38 | 39 | @external 40 | @raw_input 41 | @raw_output 42 | func __default__{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}( 43 | selector : felt, calldata_size : felt, calldata : felt*) -> ( 44 | retdata_size : felt, retdata : felt*): 45 | let (address) = module_registry_get_module_address(selector) 46 | 47 | with_attr error_message("selector not found"): 48 | assert_not_zero(address) 49 | end 50 | 51 | let (retdata_size : felt, retdata : felt*) = delegate_call( 52 | contract_address=address, 53 | function_selector=selector, 54 | calldata_size=calldata_size, 55 | calldata=calldata) 56 | return (retdata_size=retdata_size, retdata=retdata) 57 | end 58 | -------------------------------------------------------------------------------- /smc/modules/module_introspection.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | from starkware.cairo.common.alloc import alloc 4 | from starkware.cairo.common.cairo_builtins import HashBuiltin 5 | from starkware.cairo.common.math import assert_not_zero 6 | 7 | from smc.libraries.module_registry import ( 8 | module_registry_get_module_address, 
_module_registry_selectors_len, _module_registry_selectors) 9 | 10 | # ---------------------------------------------------------------------------- # 11 | # # 12 | # IModuleIntrospection interface # 13 | # # 14 | # ---------------------------------------------------------------------------- # 15 | 16 | @view 17 | func moduleFunctionSelectors{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}( 18 | module_address : felt) -> (selectors_len : felt, selectors : felt*): 19 | alloc_locals 20 | 21 | let (selectors_len) = _module_registry_selectors_len.read() 22 | let (local module_selectors : felt*) = alloc() 23 | 24 | let (module_selectors_len) = _collect_module_selectors_loop( 25 | module_address, selectors_len, 0, 0, module_selectors) 26 | 27 | return (selectors_len=module_selectors_len, selectors=module_selectors) 28 | end 29 | 30 | @view 31 | func moduleAddresses{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}() -> ( 32 | module_addresses_len : felt, module_addresses : felt*): 33 | alloc_locals 34 | 35 | let (selectors_len) = _module_registry_selectors_len.read() 36 | let (local module_addresses : felt*) = alloc() 37 | let (module_addresses_len) = _collect_module_addresses_loop( 38 | selectors_len, 0, 0, module_addresses) 39 | 40 | return (module_addresses_len=module_addresses_len, module_addresses=module_addresses) 41 | end 42 | 43 | @view 44 | func moduleAddress{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}( 45 | selector : felt) -> (module_address : felt): 46 | let (module_address) = module_registry_get_module_address(selector) 47 | return (module_address=module_address) 48 | end 49 | 50 | # ---------------------------------------------------------------------------- # 51 | # # 52 | # Private Functions # 53 | # # 54 | # ---------------------------------------------------------------------------- # 55 | 56 | func _collect_module_selectors_loop{ 57 | syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, 
range_check_ptr}( 58 | module_address : felt, selectors_len : felt, current_index : felt, 59 | module_selectors_len : felt, module_selectors : felt*) -> (module_selectors_len : felt): 60 | if current_index == selectors_len: 61 | return (module_selectors_len=module_selectors_len) 62 | end 63 | 64 | # checks: selector exists and has a module 65 | let (selector) = _module_registry_selectors.read(current_index) 66 | assert_not_zero(selector) 67 | let (selector_module_address) = module_registry_get_module_address(selector) 68 | assert_not_zero(selector_module_address) 69 | 70 | if selector_module_address == module_address: 71 | assert [module_selectors] = selector 72 | return _collect_module_selectors_loop( 73 | module_address, 74 | selectors_len, 75 | current_index + 1, 76 | module_selectors_len + 1, 77 | module_selectors + 1) 78 | else: 79 | return _collect_module_selectors_loop( 80 | module_address, 81 | selectors_len, 82 | current_index + 1, 83 | module_selectors_len, 84 | module_selectors) 85 | end 86 | end 87 | 88 | func _collect_module_addresses_loop{ 89 | syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}( 90 | selectors_len : felt, current_index : felt, module_addresses_len : felt, 91 | module_addresses : felt*) -> (module_addresses_len : felt): 92 | alloc_locals 93 | 94 | if current_index == selectors_len: 95 | return (module_addresses_len=module_addresses_len) 96 | end 97 | 98 | # checks: selector exists and has a module 99 | let (selector) = _module_registry_selectors.read(current_index) 100 | assert_not_zero(selector) 101 | let (local selector_module_address) = module_registry_get_module_address(selector) 102 | assert_not_zero(selector_module_address) 103 | 104 | # offset module_addresses by -1 because the last element does not contain 105 | # any address 106 | let (address_included) = _module_addresses_contains_loop( 107 | selector_module_address, module_addresses_len, module_addresses - 1) 108 | 109 | if address_included == 0: 110 | assert 
[module_addresses] = selector_module_address 111 | return _collect_module_addresses_loop( 112 | selectors_len, current_index + 1, module_addresses_len + 1, module_addresses + 1) 113 | else: 114 | return _collect_module_addresses_loop( 115 | selectors_len, current_index + 1, module_addresses_len, module_addresses) 116 | end 117 | end 118 | 119 | func _module_addresses_contains_loop{ 120 | syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}( 121 | target_module_address : felt, module_addresses_len : felt, module_addresses : felt*) -> ( 122 | contains : felt): 123 | # if it looped this far then it doesn't contain the target module address 124 | if module_addresses_len == 0: 125 | return (contains=0) 126 | end 127 | 128 | let current_address = [module_addresses] 129 | 130 | if current_address == target_module_address: 131 | return (contains=1) 132 | else: 133 | return _module_addresses_contains_loop( 134 | target_module_address, module_addresses_len - 1, module_addresses - 1) 135 | end 136 | end 137 | -------------------------------------------------------------------------------- /smc/modules/module_registry.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | from starkware.cairo.common.cairo_builtins import HashBuiltin 4 | 5 | from openzeppelin.access.ownable import Ownable_only_owner 6 | 7 | from smc.interfaces.module_registry import ModuleFunctionAction 8 | from smc.libraries.module_registry import module_registry_change_modules 9 | 10 | # ---------------------------------------------------------------------------- # 11 | # # 12 | # IModuleRegistry interface # 13 | # # 14 | # ---------------------------------------------------------------------------- # 15 | 16 | @external 17 | func changeModules{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}( 18 | actions_len : felt, actions : ModuleFunctionAction*, address : felt, calldata_len : felt, 19 | calldata : felt*): 20 | 
Ownable_only_owner() 21 | return module_registry_change_modules(actions_len, actions, address, calldata_len, calldata) 22 | end 23 | -------------------------------------------------------------------------------- /smc/testing/__init__.py: -------------------------------------------------------------------------------- 1 | from .modular_contract import ModularContract -------------------------------------------------------------------------------- /smc/testing/modular_contract.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from enum import Enum 3 | 4 | from starkware.starknet.public.abi import get_selector_from_name 5 | from starkware.starknet.testing.contract import StarknetContract 6 | from starkware.starknet.testing.contract_utils import StructManager, EventManager 7 | 8 | 9 | class ModuleAction(Enum): 10 | ADD = 0 11 | REPLACE = 1 12 | REMOVE = 2 13 | 14 | 15 | class ModularContract(StarknetContract): 16 | """ 17 | A high-level interface to a StarkNet contract that follows the 18 | Modular Contracts Standard. 
19 | """ 20 | def add_module(self, module: StarknetContract, initializer_args: Optional[List[int]] = None): 21 | res = self._change_module(module, ModuleAction.ADD, initializer_args) 22 | # add functions, structs, and events from module to this one 23 | _add_to_struct_manager(self.struct_manager, module.struct_manager) 24 | _add_to_event_manager(self.event_manager, module.event_manager) 25 | _add_functions(self, module) 26 | return res 27 | 28 | def remove_module(self, module: StarknetContract): 29 | res = self._change_module(module, ModuleAction.REMOVE) 30 | # keep structs and events, but remove functions 31 | _remove_functions(self, module) 32 | return res 33 | 34 | def _change_module(self, module: StarknetContract, action: ModuleAction, initializer_args: Optional[List[int]] = None): 35 | actions = [] 36 | for func in module._abi_function_mapping.keys(): 37 | if func == 'initializer': 38 | continue 39 | actions.append(( 40 | module.contract_address, 41 | action.value, 42 | get_selector_from_name(func), 43 | )) 44 | 45 | if initializer_args is None: 46 | return self.changeModules(actions, 0, []) 47 | 48 | return self.changeModules(actions, module.contract_address, initializer_args) 49 | 50 | 51 | def _add_to_struct_manager(dest: StructManager, src: StructManager): 52 | for k, v in src._struct_definition_mapping.items(): 53 | dest._struct_definition_mapping[k] = v 54 | 55 | 56 | def _add_to_event_manager(dest: EventManager, src: EventManager): 57 | for k, v in src._abi_event_mapping.items(): 58 | dest._abi_event_mapping[k] = v 59 | 60 | for k, v in src._selector_to_name.items(): 61 | dest._selector_to_name[k] = v 62 | 63 | 64 | def _add_functions(dest: StarknetContract, src: StarknetContract): 65 | for k, v in src._abi_function_mapping.items(): 66 | dest._abi_function_mapping[k] = v 67 | 68 | 69 | def _remove_functions(dest: StarknetContract, src: StarknetContract): 70 | for k, _v in src._abi_function_mapping.items(): 71 | del dest._abi_function_mapping[k] 72 | 
-------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from pathlib import Path 3 | from typing import List, Protocol, Tuple 4 | 5 | import pytest 6 | from starkware.starknet.compiler.compile import compile_starknet_files 7 | from starkware.starknet.testing.contract import StarknetContract 8 | from starkware.starknet.testing.state import StarknetState 9 | 10 | 11 | ALICE = 1189998819991197253 12 | BOB = 118999881999119725311 13 | 14 | 15 | class StarknetFactory(Protocol): 16 | def __call__(self) -> Tuple[StarknetState, List[StarknetContract]]: 17 | ... 18 | 19 | 20 | @pytest.fixture(scope="module") 21 | def event_loop(): 22 | return asyncio.new_event_loop() 23 | 24 | 25 | 26 | _root_dir = Path(__file__).parent.parent 27 | _examples_dir = _root_dir / 'examples' 28 | _contracts_dir = _root_dir / 'smc' 29 | 30 | 31 | def compile_examples_contract(contract_name): 32 | filename = str(_examples_dir / f'{contract_name}.cairo') 33 | return compile_starknet_files( 34 | [filename], 35 | debug_info=True, 36 | cairo_path=[ 37 | str(_contracts_dir), 38 | str(_root_dir / 'vendor' / 'cairo-contracts'), 39 | ], 40 | ) 41 | 42 | 43 | def compile_smc_contract(contract_name): 44 | filename = contract_path(contract_name) 45 | return compile_starknet_files( 46 | [filename], 47 | debug_info=True, 48 | cairo_path=[ 49 | str(_contracts_dir), 50 | str(_root_dir / 'vendor' / 'cairo-contracts'), 51 | ], 52 | ) 53 | 54 | 55 | def contract_path(contract_name): 56 | return str(_contracts_dir / f'{contract_name}.cairo') 57 | -------------------------------------------------------------------------------- /tests/test_smc.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | import pytest_asyncio 5 | from starkware.starknet.testing.contract import 
StarknetContract
from starkware.starknet.testing.state import StarknetState
from starkware.starknet.public.abi import get_selector_from_name
from starkware.starkware_utils.error_handling import StarkException

from smc.testing import ModularContract

from conftest import ALICE, BOB, StarknetFactory, compile_examples_contract, compile_smc_contract


@pytest_asyncio.fixture(scope='module')
async def starknet_factory():
    """
    Deploy the registry module, two `main` contracts (owned by ALICE and BOB),
    the `under_over` example and the introspection module once per test module.

    Returns a callable producing a fresh copy of that state together with
    ``[alice_main, bob_main, module_registry, under_over, module_introspection]``.
    """
    starknet = await StarknetState.empty()

    async def deploy(definition, calldata):
        return await starknet.deploy(
            contract_definition=definition,
            constructor_calldata=calldata,
        )

    module_registry_def = compile_smc_contract('modules/module_registry')
    module_registry_addr, module_registry_exec_info = await deploy(module_registry_def, [])

    smc_main_def = compile_smc_contract('main')
    alice_main_addr, alice_main_exec_info = await deploy(smc_main_def, [ALICE, module_registry_addr])
    bob_main_addr, bob_main_exec_info = await deploy(smc_main_def, [BOB, module_registry_addr])

    under_over_module_def = compile_examples_contract('under_over')
    under_over_module_addr, under_over_module_exec_info = await deploy(under_over_module_def, [])

    module_introspection_def = compile_smc_contract('modules/module_introspection')
    module_introspection_addr, module_introspection_exec_info = await deploy(module_introspection_def, [])

    def _f():
        state = starknet.copy()

        def wrap(cls, definition, address, exec_info):
            return cls(
                state=state,
                abi=definition.abi,
                contract_address=address,
                deploy_execution_info=exec_info,
            )

        module_registry = wrap(
            StarknetContract, module_registry_def, module_registry_addr, module_registry_exec_info)
        # `main` forwards every call to a module, so its wrappers start from the
        # registry ABI (changeModules) and grow as modules are added.
        alice_main = wrap(
            ModularContract, module_registry_def, alice_main_addr, alice_main_exec_info)
        bob_main = wrap(
            ModularContract, module_registry_def, bob_main_addr, bob_main_exec_info)
        under_over = wrap(
            StarknetContract, under_over_module_def, under_over_module_addr, under_over_module_exec_info)
        module_introspection = wrap(
            StarknetContract, module_introspection_def, module_introspection_addr,
            module_introspection_exec_info)

        return state, [alice_main, bob_main, module_registry, under_over, module_introspection]

    return _f


async def _assert_selector_not_routed(state, contract, selector):
    """Assert that calling `selector` on `contract` (as ALICE) is rejected."""
    with pytest.raises(StarkException):
        await state.invoke_raw(
            contract_address=contract.contract_address,
            selector=selector,
            calldata=[],
            caller_address=ALICE,
        )


@pytest.mark.asyncio
async def test_it_works(starknet_factory: StarknetFactory):
    starknet, [alice_main, bob_main, _module_registry, under_over, _module_introspection] = starknet_factory()

    # before the module is added, getReference is not routed
    await _assert_selector_not_routed(starknet, alice_main, 'getReference')

    await alice_main.add_module(under_over).invoke(caller_address=ALICE)

    # no initializer ran, so the reference keeps its default value of 0
    exec_info = await alice_main.getReference().call(caller_address=ALICE)
    assert exec_info.result.reference == 0

    await alice_main.setReference(42).invoke(caller_address=ALICE)
    exec_info = await alice_main.getReference().call(caller_address=ALICE)
    assert exec_info.result.reference == 42

    # bob's contract is independent and does not have the module yet
    await _assert_selector_not_routed(starknet, bob_main, 'getReference')

    # passing initializer args runs `initializer(100)` in bob's contract, so
    # the reference number is initialized to something else
    await bob_main.add_module(under_over, initializer_args=[100]).invoke(caller_address=BOB)
    exec_info = await bob_main.getReference().call(caller_address=ALICE)
    assert exec_info.result.reference == 100

    await bob_main.setReference(313).invoke(caller_address=BOB)
    exec_info = await bob_main.getReference().call(caller_address=ALICE)
    assert exec_info.result.reference == 313

    # alice's value is untouched by bob's changes
    exec_info = await alice_main.getReference().call(caller_address=ALICE)
    assert exec_info.result.reference == 42

    # once removed, the selector is no longer routed
    await bob_main.remove_module(under_over).invoke(caller_address=BOB)
    await _assert_selector_not_routed(starknet, bob_main, 'getReference')


@pytest.mark.asyncio
async def test_module_introspection(starknet_factory: StarknetFactory):
    starknet, [alice_main, _bob_main, module_registry, _under_over, module_introspection] = starknet_factory()

    await _assert_selector_not_routed(starknet, alice_main, 'moduleAddresses')

    await alice_main.add_module(module_introspection).invoke(caller_address=ALICE)

    exec_info = await alice_main.moduleAddresses().call(caller_address=ALICE)
    expected_addresses = {module_registry.contract_address, module_introspection.contract_address}
    assert set(exec_info.result.module_addresses) == expected_addresses

    exec_info = await alice_main.moduleFunctionSelectors(
        module_registry.contract_address).call(caller_address=ALICE)
    assert exec_info.result.selectors == [get_selector_from_name('changeModules')]

    exec_info = await alice_main.moduleAddress(
        get_selector_from_name('moduleAddresses')).call(caller_address=ALICE)
    assert exec_info.result.module_address == module_introspection.contract_address


@pytest.mark.asyncio
async def test_ownership(starknet_factory: StarknetFactory):
    starknet, [alice_main, _bob_main, _module_registry, under_over, _module_introspection] = starknet_factory()

    # only the owner may change modules
    with pytest.raises(StarkException):
        await alice_main.add_module(under_over).invoke(caller_address=BOB)

    exec_info = await alice_main.add_module(under_over).invoke(caller_address=ALICE)
    assert len(exec_info.raw_events) == 1