├── .github └── workflows │ └── test.yaml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── contracts ├── account │ ├── IPluginAccount.cairo │ ├── PluginAccount.cairo │ └── library.cairo ├── plugins │ ├── SessionKey.cairo │ └── signer │ │ └── StarkSigner.cairo ├── test │ ├── Dapp.cairo │ └── FakeAccount.cairo └── upgrade │ ├── IProxy.cairo │ ├── Proxy.cairo │ └── Upgradable.cairo ├── localhost.accounts.json ├── node.json ├── protostar.toml ├── pyproject.toml ├── requirements.txt └── tests ├── test_account.cairo ├── test_account.py ├── test_session_key.py ├── test_stark_signer.py └── utils ├── merkle_utils.py ├── plugin_signer.py ├── session_keys_utils.py └── utils.py /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: protostar 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Install protostar 14 | run: curl -L https://raw.githubusercontent.com/software-mansion/protostar/master/install.sh | bash 15 | - name: Activate protostar 16 | run: echo "/home/runner/.protostar/dist/protostar" >> $GITHUB_PATH 17 | - name: Run tests 18 | run: protostar test 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .temp/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | artifacts 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python 
script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # vscode project settings 136 | .vscode/ 137 | 138 | # nile 139 | accounts.json 140 | node.json 141 | .idea 142 | 143 | # protostar 144 | .protostar_cache -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 argentlabs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Build and test 2 | build :; nile compile 3 | test :; pytest tests/ 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # starknet-plugin-account 2 | 3 | Account abstraction opens a completely new design space for accounts. 4 | 5 | This repository is a community effort led by [Argent](https://www.argent.xyz/), [Cartridge](https://cartridge.gg) and [Ledger](https://www.ledger.com/), to explore the possibility of making accounts more flexible and modular by defining a plugin account architecture which lets users compose functionalities they want to enable when creating their account. The proposed architecture also aims to make the account extendable by letting users add or remove functionalities after the account has been created. 6 | 7 | The idea of modular smart-contracts is not new and several architectures have been proposed for Ethereum [Argent smart-wallet, Diamond Pattern]. However, it is the first time that this is applied to accounts directly by leveraging account abstraction.
8 | 9 | ## Account Abstraction: 10 | 11 | In StarkNet, accounts must comply with the `IAccount` interface: 12 | 13 | ```cairo 14 | @contract_interface 15 | namespace IAccount { 16 | 17 | func supportsInterface(interfaceId: felt) -> (success: felt) { 18 | } 19 | 20 | func isValidSignature(hash: felt, signature_len: felt, signature: felt*) -> (isValid: felt) { 21 | } 22 | 23 | func __validate__( 24 | call_array_len: felt, call_array: CallArray*, calldata_len: felt, calldata: felt* 25 | ) { 26 | } 27 | 28 | func __validate_declare__(class_hash: felt) { 29 | } 30 | 31 | func __validate_deploy__( 32 | class_hash: felt, ctr_args_len: felt, ctr_args: felt*, salt: felt 33 | ) { 34 | } 35 | 36 | func __execute__( 37 | call_array_len: felt, call_array: CallArray*, calldata_len: felt, calldata: felt* 38 | ) -> (response_len: felt, response: felt*) { 39 | } 40 | } 41 | ``` 42 | The two important methods are `__validate__`, which is called by the Starknet OS to verify that the transaction is valid and that the account will pay the fee, and `__execute__`, which is then called by the OS to execute the transaction. 43 | 44 | The `__validate__` method has some constraints to protect the network. In particular, its logic must be implemented in a small number of steps and it cannot access the mutable state of any other contracts (i.e. it can only read the storage of the account). 45 | 46 | ## PluginAccount: 47 | 48 | The `PluginAccount` contract is the main account contract that supports the addition of plugins. 49 | 50 | A plugin is a separate piece of logic that can extend the functionalities of the account. 51 | 52 | In this first version we focus only on the validation of transactions so plugins can implement different validation logic. However, the architecture can be easily extended to let plugins handle the execution of transactions in the future.
53 | 54 | The Plugin Account extends the base account interface with the following interface: 55 | 56 | ```cairo 57 | func addPlugin(plugin: felt, plugin_calldata_len: felt, plugin_calldata: felt*) { 58 | } 59 | 60 | func removePlugin(plugin: felt) { 61 | } 62 | 63 | func isPlugin(plugin: felt) -> (success: felt) { 64 | } 65 | 66 | func readOnPlugin(plugin: felt, selector: felt, calldata_len: felt, calldata: felt*) { 67 | } 68 | 69 | func executeOnPlugin(plugin: felt, selector: felt, calldata_len: felt, calldata: felt*) -> (retdata_len: felt, retdata: felt*){ 70 | } 71 | 72 | ``` 73 | 74 | A plugin must expose the following interface: 75 | 76 | ```cairo 77 | @contract_interface 78 | namespace IPlugin { 79 | func initialize( 80 | calldata_len: felt, 81 | calldata: felt*) {} 82 | 83 | func is_valid_signature( 84 | hash: felt, 85 | sig_len: felt, 86 | sig: felt* 87 | ) -> (isValid: felt) {} 88 | 89 | func validate( 90 | call_array_len: felt, 91 | call_array: CallArray*, 92 | calldata_len: felt, 93 | calldata: felt*) {} 94 | } 95 | ``` 96 | Plugins can be enabled and disabled with the methods `addPlugin` and `removePlugin` respectively. 97 | 98 | The presence of a plugin can be checked with the `isPlugin` method. 99 | 100 | ### Validating with a plugin: 101 | 102 | For every transaction, the caller can instruct the account to validate the multi-call with a given plugin provided that the plugin has been registered in the account. Once the plugin is identified, the account will delegate the validation of the transaction to the plugin by calling the `validate` method of the plugin. 103 | 104 | We note that the plugin must be called with a `library_call` to comply with the constraints of the `__validate__` method, which prevents accessing the storage of other contracts. That is, the logic of the plugin is executed in the context of the account and the state of the plugin, if any, must be stored in the account.
105 | 106 | To instruct the account to use a specific plugin we leverage the transaction signature data. By convention, the first item in the signature data specifies the class hash of the plugin which should be used for validation. Any additional context necessary to validate the transaction, such as the signature itself, should be appended to the signature data. 107 | 108 | So to validate a call using a specific plugin, the signature data should look like `[pluginClassHash, ...]`. 109 | 110 | Similarly, the `isValidSignature` method validates a signature using the plugin specified in the passed signature data. 111 | 112 | ### Changing the state of a plugin: 113 | 114 | To manipulate the state of a plugin, the account has an `executeOnPlugin` method that can only be called from the wallet. 115 | 116 | ### Reading the state of a plugin: 117 | 118 | The view methods of a plugin can be accessed through the `readOnPlugin` method. 119 | 120 | ## Development 121 | 122 | ### Set up a local virtual env 123 | 124 | ``` 125 | python3.9 -m venv ./venv 126 | source ./venv/bin/activate 127 | ``` 128 | 129 | ### Install Cairo dependencies 130 | ``` 131 | brew install gmp 132 | ``` 133 | 134 | You might need this extra step if you are running on a Mac with the M1 chip: 135 | 136 | ``` 137 | CFLAGS=-I`brew --prefix gmp`/include LDFLAGS=-L`brew --prefix gmp`/lib pip install ecdsa fastecdsa sympy 138 | ``` 139 | 140 | ``` 141 | pip install -r requirements.txt 142 | ``` 143 | 144 | For more details, see: 145 | - https://www.cairo-lang.org/docs/quickstart.html 146 | - https://github.com/martriay/nile 147 | 148 | ### Compile the contracts 149 | ``` 150 | nile compile 151 | ``` 152 | 153 | ### Coverage 154 | ``` 155 | nile coverage 156 | ``` -------------------------------------------------------------------------------- /contracts/account/IPluginAccount.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | // Tmp struct introduced while we
wait for Cairo 4 | // to support passing `[Call]` to __execute__ 5 | struct CallArray { 6 | to: felt, 7 | selector: felt, 8 | data_offset: felt, 9 | data_len: felt, 10 | } 11 | 12 | @contract_interface 13 | namespace IPluginAccount { 14 | 15 | ///////////////////// 16 | // Plugin 17 | ///////////////////// 18 | 19 | func addPlugin(plugin: felt, plugin_calldata_len: felt, plugin_calldata: felt*) { 20 | } 21 | 22 | func removePlugin(plugin: felt) { 23 | } 24 | 25 | func isPlugin(plugin: felt) -> (success: felt) { 26 | } 27 | 28 | 29 | func executeOnPlugin( 30 | plugin: felt, selector: felt, calldata_len: felt, calldata: felt* 31 | ) -> (retdata_len: felt, retdata: felt*){ 32 | } 33 | 34 | ///////////////////// 35 | // IAccount 36 | ///////////////////// 37 | 38 | func upgrade(implementation: felt) { 39 | } 40 | 41 | func supportsInterface(interfaceId: felt) -> (success: felt) { 42 | } 43 | 44 | func isValidSignature(hash: felt, signature_len: felt, signature: felt*) -> (isValid: felt) { 45 | } 46 | 47 | func __validate__( 48 | call_array_len: felt, 49 | call_array: CallArray*, 50 | calldata_len: felt, 51 | calldata: felt* 52 | ) { 53 | } 54 | 55 | // Parameter temporarily named `cls_hash` instead of `class_hash` (expected). 56 | // See https://github.com/starkware-libs/cairo-lang/issues/100 for details. 57 | func __validate_declare__(cls_hash: felt) { 58 | } 59 | 60 | // Parameter temporarily named `cls_hash` instead of `class_hash` (expected). 61 | // See https://github.com/starkware-libs/cairo-lang/issues/100 for details. 
62 | func __validate_deploy__( 63 | cls_hash: felt, ctr_args_len: felt, ctr_args: felt*, salt: felt 64 | ) { 65 | } 66 | 67 | func __execute__( 68 | call_array_len: felt, 69 | call_array: CallArray*, 70 | calldata_len: felt, 71 | calldata: felt* 72 | ) -> (response_len: felt, response: felt*) { 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /contracts/account/PluginAccount.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin 4 | from starkware.cairo.common.alloc import alloc 5 | from starkware.starknet.common.syscalls import library_call 6 | from starkware.cairo.common.bool import TRUE 7 | 8 | from contracts.account.library import CallArray, PluginAccount, ERC165_ACCOUNT_INTERFACE_ID 9 | from contracts.upgrade.Upgradable import _set_implementation 10 | 11 | ///////////////////// 12 | // CONSTANTS 13 | ///////////////////// 14 | 15 | const NAME = 'PluginAccount'; 16 | const VERSION = '0.0.1'; 17 | const SUPPORTS_INTERFACE_SELECTOR = 1184015894760294494673613438913361435336722154500302038630992932234692784845; 18 | 19 | ///////////////////// 20 | // EVENTS 21 | ///////////////////// 22 | 23 | @event 24 | func account_upgraded(new_implementation: felt) { 25 | } 26 | 27 | ///////////////////// 28 | // PROTOCOL 29 | ///////////////////// 30 | 31 | @external 32 | func __validate__{ 33 | syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ecdsa_ptr: SignatureBuiltin*, range_check_ptr 34 | }( 35 | call_array_len: felt, call_array: CallArray*, calldata_len: felt, calldata: felt* 36 | ) { 37 | PluginAccount.validate(call_array_len, call_array, calldata_len, calldata); 38 | return (); 39 | } 40 | 41 | @external 42 | @raw_output 43 | func __execute__{ 44 | syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ecdsa_ptr: SignatureBuiltin*, range_check_ptr 45 | } ( 46 | call_array_len: felt, 
call_array: CallArray*, calldata_len: felt, calldata: felt* 47 | ) -> (retdata_size: felt, retdata: felt*) { 48 | let (response_len, response) = PluginAccount.execute(call_array_len, call_array, calldata_len, calldata); 49 | return (retdata_size=response_len, retdata=response); 50 | } 51 | 52 | @external 53 | func __validate_declare__{ 54 | syscall_ptr: felt*, 55 | pedersen_ptr: HashBuiltin*, 56 | ecdsa_ptr: SignatureBuiltin*, 57 | range_check_ptr 58 | } ( 59 | class_hash: felt 60 | ) { 61 | PluginAccount.validate_declare(); 62 | return (); 63 | } 64 | 65 | ///////////////////// 66 | // EXTERNAL FUNCTIONS 67 | ///////////////////// 68 | 69 | @external 70 | func initialize{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 71 | plugin: felt, plugin_calldata_len: felt, plugin_calldata: felt* 72 | ) { 73 | PluginAccount.initializer(plugin, plugin_calldata_len, plugin_calldata); 74 | return (); 75 | } 76 | 77 | @external 78 | func addPlugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(plugin: felt, plugin_calldata_len: felt, plugin_calldata: felt*) { 79 | PluginAccount.add_plugin(plugin, plugin_calldata_len, plugin_calldata); 80 | return (); 81 | } 82 | 83 | @external 84 | func removePlugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(plugin: felt) { 85 | PluginAccount.remove_plugin(plugin); 86 | return (); 87 | } 88 | 89 | @external 90 | func executeOnPlugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 91 | plugin: felt, selector: felt, calldata_len: felt, calldata: felt* 92 | ) -> (retdata_len: felt, retdata: felt*) { 93 | return PluginAccount.execute_on_plugin(plugin, selector, calldata_len, calldata); 94 | } 95 | 96 | @external 97 | func upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 98 | implementation: felt 99 | ) { 100 | // only called via execute 101 | PluginAccount.assert_only_self(); 102 | // make sure the target is an account 103 | with_attr 
error_message("PluginAccount: invalid implementation") { 104 | let (calldata: felt*) = alloc(); 105 | assert calldata[0] = ERC165_ACCOUNT_INTERFACE_ID; 106 | let (retdata_size: felt, retdata: felt*) = library_call( 107 | class_hash=implementation, 108 | function_selector=SUPPORTS_INTERFACE_SELECTOR, 109 | calldata_size=1, 110 | calldata=calldata, 111 | ); 112 | assert retdata_size = 1; 113 | assert [retdata] = TRUE; 114 | } 115 | // change implementation 116 | _set_implementation(implementation); 117 | account_upgraded.emit(new_implementation=implementation); 118 | 119 | return (); 120 | } 121 | 122 | ///////////////////// 123 | // VIEW FUNCTIONS 124 | ///////////////////// 125 | 126 | @view 127 | func isValidSignature{ 128 | syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ecdsa_ptr: SignatureBuiltin*, range_check_ptr 129 | }(hash: felt, sig_len: felt, sig: felt*) -> (isValid: felt) { 130 | let (isValid) = PluginAccount.is_valid_signature(hash, sig_len, sig); 131 | return (isValid=isValid); 132 | } 133 | 134 | @view 135 | func supportsInterface{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 136 | interfaceId: felt 137 | ) -> (success: felt) { 138 | let (res) = PluginAccount.is_interface_supported(interfaceId); 139 | return (res,); 140 | } 141 | 142 | @view 143 | func isPlugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(plugin_id: felt) -> ( 144 | success: felt 145 | ) { 146 | let (res) = PluginAccount.is_plugin(plugin_id); 147 | return (success=res); 148 | } 149 | 150 | @view 151 | func getName() -> (name: felt) { 152 | return (name=NAME); 153 | } 154 | 155 | @view 156 | func getVersion() -> (version: felt) { 157 | return (version=VERSION); 158 | } 159 | -------------------------------------------------------------------------------- /contracts/account/library.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | from starkware.cairo.common.cairo_builtins import 
HashBuiltin, SignatureBuiltin 4 | from starkware.cairo.common.alloc import alloc 5 | from starkware.cairo.common.memcpy import memcpy 6 | from starkware.cairo.common.math import assert_not_zero, assert_not_equal 7 | from starkware.starknet.common.syscalls import ( 8 | library_call, 9 | call_contract, 10 | get_tx_info, 11 | get_contract_address, 12 | get_caller_address, 13 | ) 14 | from starkware.cairo.common.bool import TRUE, FALSE 15 | from contracts.account.IPluginAccount import CallArray 16 | 17 | const ERC165_ACCOUNT_INTERFACE_ID = 0x3943f10f; 18 | const TRANSACTION_VERSION = 1; 19 | const QUERY_VERSION = 2**128 + TRANSACTION_VERSION; 20 | 21 | struct Call { 22 | to: felt, 23 | selector: felt, 24 | calldata_len: felt, 25 | calldata: felt*, 26 | } 27 | 28 | ///////////////////// 29 | // INTERFACES 30 | ///////////////////// 31 | 32 | @contract_interface 33 | namespace IPlugin { 34 | func initialize(data_len: felt, data: felt*) { 35 | } 36 | 37 | func is_valid_signature(hash: felt, sig_len: felt, sig: felt*) -> (isValid: felt) { 38 | } 39 | 40 | func supportsInterface(interfaceId: felt) -> (success: felt) { 41 | } 42 | 43 | func validate( 44 | call_array_len: felt, 45 | call_array: CallArray*, 46 | calldata_len: felt, 47 | calldata: felt*, 48 | ) { 49 | } 50 | } 51 | 52 | ///////////////////// 53 | // EVENTS 54 | ///////////////////// 55 | 56 | @event 57 | func account_created(account: felt) { 58 | } 59 | 60 | @event 61 | func transaction_executed(hash: felt, response_len: felt, response: felt*) { 62 | } 63 | 64 | ///////////////////// 65 | // STORAGE VARIABLES 66 | ///////////////////// 67 | 68 | @storage_var 69 | func PluginAccount_plugins(plugin: felt) -> (res: felt) { 70 | } 71 | 72 | @storage_var 73 | func PluginAccount_initialized() -> (res: felt) { 74 | } 75 | 76 | namespace PluginAccount { 77 | func initializer{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 78 | plugin: felt, plugin_calldata_len: felt, plugin_calldata: felt* 79 | ) { 80 
| let (initialized) = PluginAccount_initialized.read(); 81 | with_attr error_message("PluginAccount: already initialized") { 82 | assert initialized = 0; 83 | } 84 | 85 | with_attr error_message("PluginAccount: plugin cannot be null") { 86 | assert_not_zero(plugin); 87 | } 88 | 89 | PluginAccount_plugins.write(plugin, 1); 90 | PluginAccount_initialized.write(1); 91 | 92 | initialize_plugin(plugin, plugin_calldata_len, plugin_calldata); 93 | 94 | let (self) = get_contract_address(); 95 | account_created.emit(self); 96 | 97 | return (); 98 | } 99 | 100 | func validate{ 101 | syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ecdsa_ptr: SignatureBuiltin*, range_check_ptr 102 | }( 103 | call_array_len: felt, call_array: CallArray*, calldata_len: felt, calldata: felt* 104 | ) { 105 | alloc_locals; 106 | 107 | let (tx_info) = get_tx_info(); 108 | assert_correct_tx_version(tx_info.version); 109 | assert_initialized(); 110 | 111 | let (plugin) = get_plugin_from_signature(tx_info.signature_len, tx_info.signature); 112 | 113 | IPlugin.library_call_validate( 114 | class_hash=plugin, 115 | call_array_len=call_array_len, 116 | call_array=call_array, 117 | calldata_len=calldata_len, 118 | calldata=calldata, 119 | ); 120 | return (); 121 | } 122 | 123 | func validate_deploy{ 124 | syscall_ptr: felt*, 125 | pedersen_ptr: HashBuiltin*, 126 | ecdsa_ptr: SignatureBuiltin*, 127 | range_check_ptr 128 | }() { 129 | alloc_locals; 130 | let (tx_info) = get_tx_info(); 131 | let (is_valid) = is_valid_signature(tx_info.transaction_hash, tx_info.signature_len, tx_info.signature); 132 | with_attr error_message("PluginAccount: invalid deploy") { 133 | assert_not_zero(is_valid); 134 | } 135 | return (); 136 | } 137 | 138 | func validate_declare{ 139 | syscall_ptr: felt*, 140 | pedersen_ptr: HashBuiltin*, 141 | ecdsa_ptr: SignatureBuiltin*, 142 | range_check_ptr 143 | }() { 144 | alloc_locals; 145 | let (tx_info) = get_tx_info(); 146 | let (is_valid) = is_valid_signature(tx_info.transaction_hash, 
tx_info.signature_len, tx_info.signature); 147 | with_attr error_message("PluginAccount: invalid declare") { 148 | assert_not_zero(is_valid); 149 | } 150 | return (); 151 | } 152 | 153 | func execute{ 154 | syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ecdsa_ptr: SignatureBuiltin*, range_check_ptr 155 | }( 156 | call_array_len: felt, 157 | call_array: CallArray*, 158 | calldata_len: felt, 159 | calldata: felt*, 160 | ) -> (response_len: felt, response: felt*) { 161 | alloc_locals; 162 | 163 | let (tx_info) = get_tx_info(); 164 | assert_correct_tx_version(tx_info.version); 165 | assert_non_reentrant(); 166 | 167 | /////////////// TMP ///////////////////// 168 | // parse inputs to an array of 'Call' struct 169 | let (calls: Call*) = alloc(); 170 | from_call_array_to_call(call_array_len, call_array, calldata, calls); 171 | let calls_len = call_array_len; 172 | ////////////////////////////////////////// 173 | 174 | let (response: felt*) = alloc(); 175 | let (response_len) = execute_list(calls_len, calls, response); 176 | transaction_executed.emit( 177 | hash=tx_info.transaction_hash, response_len=response_len, response=response 178 | ); 179 | return (response_len, response); 180 | } 181 | 182 | func add_plugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(plugin: felt, plugin_calldata_len: felt, plugin_calldata: felt*) { 183 | assert_only_self(); 184 | 185 | with_attr error_message("PluginAccount: plugin cannot be null") { 186 | assert_not_zero(plugin); 187 | } 188 | 189 | let (is_plugin) = PluginAccount_plugins.read(plugin); 190 | with_attr error_message("PluginAccount: plugin already registered") { 191 | assert is_plugin = 0; 192 | } 193 | 194 | PluginAccount_plugins.write(plugin, 1); 195 | 196 | initialize_plugin(plugin, plugin_calldata_len, plugin_calldata); 197 | 198 | return (); 199 | } 200 | 201 | func remove_plugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(plugin: felt) { 202 | assert_only_self(); 203 | 204 | let 
(is_plugin) = PluginAccount_plugins.read(plugin); 205 | with_attr error_message("PluginAccount: unknown plugin") { 206 | assert_not_zero(is_plugin); 207 | } 208 | 209 | let (tx_info) = get_tx_info(); 210 | 211 | let (signature_plugin) = get_plugin_from_signature(tx_info.signature_len, tx_info.signature); 212 | with_attr error_message("PluginAccount: plugin can't remove itself") { 213 | assert_not_equal(signature_plugin, plugin); 214 | } 215 | 216 | PluginAccount_plugins.write(plugin, 0); 217 | return (); 218 | } 219 | 220 | func execute_on_plugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 221 | plugin: felt, selector: felt, calldata_len: felt, calldata: felt* 222 | ) -> (retdata_len: felt, retdata: felt*) { 223 | 224 | // only valid plugin 225 | let (is_plugin) = PluginAccount_plugins.read(plugin); 226 | assert_not_zero(is_plugin); 227 | 228 | let (retdata_len: felt, retdata: felt*) = library_call( 229 | class_hash=plugin, 230 | function_selector=selector, 231 | calldata_size=calldata_len, 232 | calldata=calldata, 233 | ); 234 | return (retdata_len=retdata_len, retdata=retdata); 235 | } 236 | 237 | func is_valid_signature{ 238 | syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ecdsa_ptr: SignatureBuiltin*, range_check_ptr 239 | }(hash: felt, sig_len: felt, sig: felt*) -> (is_valid: felt) { 240 | alloc_locals; 241 | 242 | let (plugin) = get_plugin_from_signature(sig_len, sig); 243 | 244 | let (is_valid) = IPlugin.library_call_is_valid_signature( 245 | class_hash=plugin, 246 | hash=hash, 247 | sig_len=sig_len, 248 | sig=sig 249 | ); 250 | 251 | return (is_valid=is_valid); 252 | } 253 | 254 | func is_interface_supported{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 255 | interface_id: felt 256 | ) -> (is_supported: felt) { 257 | // 165 258 | if (interface_id == 0x01ffc9a7) { 259 | return (TRUE,); 260 | } 261 | // IAccount 262 | if (interface_id == ERC165_ACCOUNT_INTERFACE_ID) { 263 | return (TRUE,); 264 | } 265 | 266 | return 
(FALSE,); 267 | } 268 | 269 | func is_plugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(plugin: felt) -> ( 270 | success: felt 271 | ) { 272 | let (res) = PluginAccount_plugins.read(plugin); 273 | return (success=res); 274 | } 275 | 276 | func initialize_plugin{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 277 | plugin: felt, plugin_calldata_len: felt, plugin_calldata: felt* 278 | ) { 279 | if (plugin_calldata_len == 0) { 280 | return (); 281 | } 282 | 283 | IPlugin.library_call_initialize( 284 | class_hash=plugin, 285 | data_len=plugin_calldata_len, 286 | data=plugin_calldata, 287 | ); 288 | 289 | return (); 290 | } 291 | 292 | func get_plugin_from_signature{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 293 | signature_len: felt, signature: felt*, 294 | ) -> (plugin: felt) { 295 | alloc_locals; 296 | 297 | with_attr error_message("PluginAccount: invalid signature") { 298 | assert_not_zero(signature_len); 299 | } 300 | 301 | let plugin = signature[0]; 302 | 303 | let (is_plugin) = PluginAccount_plugins.read(plugin); 304 | with_attr error_message("PluginAccount: unregistered plugin") { 305 | assert_not_zero(is_plugin); 306 | } 307 | return (plugin=plugin); 308 | } 309 | 310 | func assert_only_self{syscall_ptr: felt*}() -> () { 311 | let (self) = get_contract_address(); 312 | let (caller_address) = get_caller_address(); 313 | with_attr error_message("PluginAccount: only self") { 314 | assert self = caller_address; 315 | } 316 | return (); 317 | } 318 | 319 | func assert_non_reentrant{syscall_ptr: felt*}() -> () { 320 | let (caller) = get_caller_address(); 321 | with_attr error_message("PluginAccount: no reentrant call") { 322 | assert caller = 0; 323 | } 324 | return (); 325 | } 326 | 327 | func assert_initialized{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() { 328 | let (initialized) = PluginAccount_initialized.read(); 329 | with_attr error_message("PluginAccount: account not 
initialized") { 330 | assert_not_zero(initialized); 331 | } 332 | return (); 333 | } 334 | 335 | func assert_correct_tx_version{syscall_ptr: felt*}(tx_version: felt) -> () { 336 | with_attr error_message("PluginAccount: invalid tx version") { 337 | assert (tx_version - TRANSACTION_VERSION) * (tx_version - QUERY_VERSION) = 0; 338 | } 339 | return (); 340 | } 341 | 342 | // @notice Executes a list of contract calls recursively. 343 | // @param calls_len The number of calls to execute 344 | // @param calls A pointer to the first call to execute 345 | // @param reponse The array of felt to populate with the returned data 346 | // @return response_len The size of the returned data 347 | func execute_list{syscall_ptr: felt*}( 348 | calls_len: felt, calls: Call*, reponse: felt* 349 | ) -> (response_len: felt) { 350 | alloc_locals; 351 | 352 | // base case: no more calls 353 | if (calls_len == 0) { 354 | return (0,); 355 | } 356 | 357 | // do the current call 358 | let this_call: Call = [calls]; 359 | let res = call_contract( 360 | contract_address=this_call.to, 361 | function_selector=this_call.selector, 362 | calldata_size=this_call.calldata_len, 363 | calldata=this_call.calldata, 364 | ); 365 | 366 | // copy the returned data of this call into the response array 367 | memcpy(reponse, res.retdata, res.retdata_size); 368 | // do the next calls recursively 369 | let (response_len) = execute_list( 370 | calls_len - 1, calls + Call.SIZE, reponse + res.retdata_size 371 | ); 372 | return (response_len + res.retdata_size,); 373 | } 374 | 375 | func from_call_array_to_call{syscall_ptr: felt*}( 376 | call_array_len: felt, call_array: CallArray*, calldata: felt*, calls: Call* 377 | ) { 378 | // base case: no more calls 379 | if (call_array_len == 0) { 380 | return (); 381 | } 382 | 383 | // parse the current call 384 | assert [calls] = Call( 385 | to=[call_array].to, 386 | selector=[call_array].selector, 387 | calldata_len=[call_array].data_len, 388 | calldata=calldata + [call_array].data_offset 389 | ); 390 | 391 | // parse
the remaining calls recursively 392 | from_call_array_to_call( 393 | call_array_len - 1, call_array + CallArray.SIZE, calldata, calls + Call.SIZE 394 | ); 395 | return (); 396 | } 397 | } -------------------------------------------------------------------------------- /contracts/plugins/SessionKey.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin 4 | from starkware.cairo.common.signature import verify_ecdsa_signature 5 | from starkware.cairo.common.hash_state import ( 6 | HashState, 7 | hash_finalize, 8 | hash_init, 9 | hash_update, 10 | hash_update_single, 11 | ) 12 | from starkware.cairo.common.hash import hash2 13 | from starkware.cairo.common.math_cmp import is_le_felt 14 | from starkware.cairo.common.registers import get_fp_and_pc 15 | from starkware.cairo.common.alloc import alloc 16 | from starkware.cairo.common.bool import TRUE, FALSE 17 | from starkware.cairo.common.math import assert_not_zero, assert_nn 18 | from starkware.starknet.common.syscalls import ( 19 | call_contract, 20 | get_tx_info, 21 | get_contract_address, 22 | get_caller_address, 23 | get_block_timestamp, 24 | ) 25 | from contracts.account.IPluginAccount import CallArray 26 | 27 | // H('StarkNetDomain(chainId:felt)') 28 | const STARKNET_DOMAIN_TYPE_HASH = 0x13cda234a04d66db62c06b8e3ad5f91bd0c67286c2c7519a826cf49da6ba478; 29 | // H('Session(key:felt,expires:felt,root:merkletree)') 30 | const SESSION_TYPE_HASH = 0x1aa0e1c56b45cf06a54534fa1707c54e520b842feb21d03b7deddb6f1e340c; 31 | // H(Policy(contractAddress:felt,selector:selector)) 32 | const POLICY_TYPE_HASH = 0x2f0026e78543f036f33e26a8f5891b88c58dc1e20cbbfaf0bb53274da6fa568; 33 | 34 | @contract_interface 35 | namespace IAccount { 36 | func isValidSignature(hash: felt, sig_len: felt, sig: felt*) { 37 | } 38 | } 39 | 40 | @event 41 | func session_key_revoked(session_key: felt) { 42 | } 43 | 44 | 
@storage_var 45 | func SessionKey_revoked_keys(key: felt) -> (res: felt) { 46 | } 47 | 48 | @view 49 | func supportsInterface{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 50 | interfaceId: felt 51 | ) -> (success: felt) { 52 | // 165 53 | if (interfaceId == 0x01ffc9a7) { 54 | return (TRUE,); 55 | } 56 | return (FALSE,); 57 | } 58 | 59 | @view 60 | func is_valid_signature{ 61 | syscall_ptr : felt*, 62 | pedersen_ptr : HashBuiltin*, 63 | range_check_ptr, 64 | ecdsa_ptr: SignatureBuiltin* 65 | }( 66 | hash: felt, 67 | signature_len: felt, 68 | signature: felt* 69 | ) -> (is_valid: felt) { 70 | return (is_valid=FALSE); // This plugin can only validate call 71 | } 72 | @external 73 | func validate{ 74 | syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ecdsa_ptr: SignatureBuiltin*, range_check_ptr 75 | }( 76 | call_array_len: felt, 77 | call_array: CallArray*, 78 | calldata_len: felt, 79 | calldata: felt*, 80 | ) { 81 | alloc_locals; 82 | 83 | // get the tx info 84 | let (tx_info) = get_tx_info(); 85 | 86 | // parse the plugin data 87 | with_attr error_message("SessionKey: invalid plugin data") { 88 | let sig_r = tx_info.signature[1]; 89 | let sig_s = tx_info.signature[2]; 90 | let session_key = tx_info.signature[3]; 91 | let session_expires = tx_info.signature[4]; 92 | let root = tx_info.signature[5]; 93 | let proof_len = tx_info.signature[6]; 94 | let proofs_len = tx_info.signature[7]; 95 | let proofs = tx_info.signature + 8; 96 | let session_token_offset = 8 + proofs_len; 97 | let session_token_len = tx_info.signature[session_token_offset]; 98 | let session_token = tx_info.signature + session_token_offset + 1; 99 | } 100 | 101 | with_attr error_message("SessionKey: invalid proof len") { 102 | assert proofs_len = call_array_len * proof_len; 103 | } 104 | 105 | with_attr error_message("SessionKey: invalid signature length") { 106 | assert tx_info.signature_len = session_token_offset + 1 + session_token_len; 107 | } 108 | 109 | with_attr 
error_message("SessionKey: session expired") { 110 | let (now) = get_block_timestamp(); 111 | assert_nn(session_expires - now); 112 | } 113 | 114 | let (session_hash) = compute_session_hash( 115 | session_key, session_expires, root, tx_info.chain_id, tx_info.account_contract_address 116 | ); 117 | with_attr error_message("SessionKey: unauthorised session") { 118 | IAccount.isValidSignature( 119 | contract_address=tx_info.account_contract_address, 120 | hash=session_hash, 121 | sig_len=session_token_len, 122 | sig=session_token, 123 | ); 124 | } 125 | // check if the session key is revoked 126 | with_attr error_message("SessionKey: session key revoked") { 127 | let (is_revoked) = SessionKey_revoked_keys.read(session_key); 128 | assert is_revoked = 0; 129 | } 130 | // check if the tx is signed by the session key 131 | with_attr error_message("SessionKey: invalid signature") { 132 | verify_ecdsa_signature( 133 | message=tx_info.transaction_hash, 134 | public_key=session_key, 135 | signature_r=sig_r, 136 | signature_s=sig_s, 137 | ); 138 | } 139 | check_policy(call_array_len, call_array, root, proof_len, proofs_len, proofs); 140 | 141 | return (); 142 | } 143 | 144 | @external 145 | func revokeSessionKey{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 146 | session_key: felt 147 | ) { 148 | assert_only_self(); 149 | 150 | SessionKey_revoked_keys.write(session_key, 1); 151 | session_key_revoked.emit(session_key); 152 | return (); 153 | } 154 | 155 | ///////////////////// 156 | // INTERNAL FUNCTIONS 157 | ///////////////////// 158 | 159 | func check_policy{ 160 | syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ecdsa_ptr: SignatureBuiltin*, range_check_ptr 161 | }( 162 | call_array_len: felt, 163 | call_array: CallArray*, 164 | root: felt, 165 | proof_len: felt, 166 | proofs_len: felt, 167 | proofs: felt*, 168 | ) { 169 | alloc_locals; 170 | 171 | if (call_array_len == 0) { 172 | return (); 173 | } 174 | 175 | let hash_ptr = pedersen_ptr; 176 | with 
hash_ptr { 177 | let (hash_state) = hash_init(); 178 | let (hash_state) = hash_update_single(hash_state_ptr=hash_state, item=POLICY_TYPE_HASH); 179 | let (hash_state) = hash_update_single(hash_state_ptr=hash_state, item=[call_array].to); 180 | let (hash_state) = hash_update_single( 181 | hash_state_ptr=hash_state, item=[call_array].selector 182 | ); 183 | let (leaf) = hash_finalize(hash_state_ptr=hash_state); 184 | let pedersen_ptr = hash_ptr; 185 | } 186 | 187 | let (proof_valid) = merkle_verify(leaf, root, proof_len, proofs); 188 | with_attr error_message("SessionKey: not allowed by policy") { 189 | assert proof_valid = TRUE; 190 | } 191 | check_policy( 192 | call_array_len - 1, 193 | call_array + CallArray.SIZE, 194 | root, 195 | proof_len, 196 | proofs_len - proof_len, 197 | proofs + proof_len, 198 | ); 199 | return (); 200 | } 201 | 202 | func compute_session_hash{pedersen_ptr: HashBuiltin*}( 203 | session_key: felt, session_expires: felt, root: felt, chain_id: felt, account: felt 204 | ) -> (hash: felt) { 205 | alloc_locals; 206 | let hash_ptr = pedersen_ptr; 207 | with hash_ptr { 208 | let (hash_state) = hash_init(); 209 | let (hash_state) = hash_update_single(hash_state_ptr=hash_state, item='StarkNet Message'); 210 | let (domain_hash) = hash_domain(chain_id); 211 | let (hash_state) = hash_update_single(hash_state_ptr=hash_state, item=domain_hash); 212 | let (hash_state) = hash_update_single(hash_state_ptr=hash_state, item=account); 213 | let (message_hash) = hash_message(session_key, session_expires, root); 214 | let (hash_state) = hash_update_single(hash_state_ptr=hash_state, item=message_hash); 215 | let (hash) = hash_finalize(hash_state_ptr=hash_state); 216 | let pedersen_ptr = hash_ptr; 217 | } 218 | return (hash=hash); 219 | } 220 | 221 | func hash_domain{hash_ptr: HashBuiltin*}(chain_id: felt) -> (hash: felt) { 222 | let (hash_state) = hash_init(); 223 | let (hash_state) = hash_update_single( 224 | hash_state_ptr=hash_state, 
item=STARKNET_DOMAIN_TYPE_HASH 225 | ); 226 | let (hash_state) = hash_update_single(hash_state_ptr=hash_state, item=chain_id); 227 | let (hash) = hash_finalize(hash_state_ptr=hash_state); 228 | return (hash=hash); 229 | } 230 | 231 | func hash_message{hash_ptr: HashBuiltin*}(session_key: felt, session_expires: felt, root: felt) -> ( 232 | hash: felt 233 | ) { 234 | let (hash_state) = hash_init(); 235 | let (hash_state) = hash_update_single(hash_state_ptr=hash_state, item=SESSION_TYPE_HASH); 236 | let (hash_state) = hash_update_single(hash_state_ptr=hash_state, item=session_key); 237 | let (hash_state) = hash_update_single(hash_state_ptr=hash_state, item=session_expires); 238 | let (hash_state) = hash_update_single(hash_state_ptr=hash_state, item=root); 239 | let (hash) = hash_finalize(hash_state_ptr=hash_state); 240 | return (hash=hash); 241 | } 242 | 243 | func merkle_verify{pedersen_ptr: HashBuiltin*, range_check_ptr}( 244 | leaf: felt, root: felt, proof_len: felt, proof: felt* 245 | ) -> (res: felt) { 246 | let (calc_root) = calc_merkle_root(leaf, proof_len, proof); 247 | // check if calculated root is equal to expected 248 | if (calc_root == root) { 249 | return (TRUE,); 250 | } else { 251 | return (FALSE,); 252 | } 253 | } 254 | 255 | // calculates the merkle root of a given proof 256 | func calc_merkle_root{pedersen_ptr: HashBuiltin*, range_check_ptr}( 257 | curr: felt, proof_len: felt, proof: felt* 258 | ) -> (res: felt) { 259 | alloc_locals; 260 | 261 | if (proof_len == 0) { 262 | return (curr,); 263 | } 264 | 265 | local node; 266 | local proof_elem = [proof]; 267 | let le = is_le_felt(curr, proof_elem); 268 | 269 | if (le == 1) { 270 | let (n) = hash2{hash_ptr=pedersen_ptr}(curr, proof_elem); 271 | node = n; 272 | } else { 273 | let (n) = hash2{hash_ptr=pedersen_ptr}(proof_elem, curr); 274 | node = n; 275 | } 276 | 277 | let (res) = calc_merkle_root(node, proof_len - 1, proof + 1); 278 | return (res,); 279 | } 280 | 281 | 282 | func 
assert_only_self{syscall_ptr: felt*}() -> () { 283 | let (self) = get_contract_address(); 284 | let (caller_address) = get_caller_address(); 285 | with_attr error_message("SessionKey: only self") { 286 | assert self = caller_address; 287 | } 288 | return (); 289 | } -------------------------------------------------------------------------------- /contracts/plugins/signer/StarkSigner.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin 4 | from starkware.cairo.common.math import assert_not_zero 5 | from starkware.cairo.common.bool import TRUE, FALSE 6 | from starkware.cairo.common.signature import verify_ecdsa_signature 7 | from contracts.account.IPluginAccount import CallArray 8 | from starkware.starknet.common.syscalls import ( 9 | get_tx_info, 10 | get_contract_address, 11 | get_caller_address, 12 | ) 13 | 14 | @storage_var 15 | func StarkSigner_public_key() -> (res: felt) { 16 | } 17 | 18 | @external 19 | func initialize{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(plugin_data_len: felt, plugin_data: felt*) { 20 | let (is_initialized) = StarkSigner_public_key.read(); 21 | with_attr error_message("StarkSigner: already initialized") { 22 | assert is_initialized = 0; 23 | } 24 | with_attr error_message("StarkSigner: initialise failed") { 25 | assert plugin_data_len = 1; 26 | } 27 | StarkSigner_public_key.write(plugin_data[0]); 28 | return (); 29 | } 30 | 31 | @external 32 | func setPublicKey{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 33 | public_key: felt 34 | ) { 35 | assert_only_self(); 36 | 37 | with_attr error_message("StarkSigner: public key can not be zero") { 38 | assert_not_zero(public_key); 39 | } 40 | StarkSigner_public_key.write(public_key); 41 | return (); 42 | } 43 | 44 | @view 45 | func getPublicKey{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( 
46 | public_key: felt 47 | ) { 48 | let (public_key) = StarkSigner_public_key.read(); 49 | return (public_key=public_key); 50 | } 51 | 52 | 53 | @view 54 | func supportsInterface{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 55 | interfaceId: felt 56 | ) -> (success: felt) { 57 | // 165 58 | if (interfaceId == 0x01ffc9a7) { 59 | return (TRUE,); 60 | } 61 | return (FALSE,); 62 | } 63 | 64 | @view 65 | func validate{ 66 | syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, ecdsa_ptr: SignatureBuiltin* 67 | }( 68 | call_array_len: felt, 69 | call_array: CallArray*, 70 | calldata_len: felt, 71 | calldata: felt*, 72 | ) { 73 | alloc_locals; 74 | let (tx_info) = get_tx_info(); 75 | is_valid_signature(tx_info.transaction_hash, tx_info.signature_len, tx_info.signature); 76 | return (); 77 | } 78 | 79 | @view 80 | func is_valid_signature{ 81 | syscall_ptr : felt*, 82 | pedersen_ptr : HashBuiltin*, 83 | range_check_ptr, 84 | ecdsa_ptr: SignatureBuiltin* 85 | }( 86 | hash: felt, 87 | signature_len: felt, 88 | signature: felt* 89 | ) -> (is_valid: felt) { 90 | 91 | with_attr error_message("StarkSigner: invalid signature length") { 92 | assert signature_len = 3; 93 | } 94 | 95 | let (public_key) = StarkSigner_public_key.read(); 96 | 97 | let sig_r = signature[1]; 98 | let sig_s = signature[2]; 99 | 100 | verify_ecdsa_signature( 101 | message=hash, 102 | public_key=public_key, 103 | signature_r=sig_r, 104 | signature_s=sig_s 105 | ); 106 | 107 | return (is_valid=TRUE); 108 | } 109 | 110 | func assert_only_self{syscall_ptr: felt*}() -> () { 111 | let (self) = get_contract_address(); 112 | let (caller_address) = get_caller_address(); 113 | with_attr error_message("StarkSigner: only self") { 114 | assert self = caller_address; 115 | } 116 | return (); 117 | } 118 | -------------------------------------------------------------------------------- /contracts/test/Dapp.cairo: 
-------------------------------------------------------------------------------- 1 | %lang starknet 2 | from starkware.cairo.common.math import assert_nn 3 | from starkware.cairo.common.cairo_builtins import HashBuiltin 4 | 5 | @storage_var 6 | func balance() -> (res: felt) { 7 | } 8 | 9 | @external 10 | func increase_balance{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 11 | amount: felt 12 | ) { 13 | with_attr error_message("Amount must be positive. Got: {amount}.") { 14 | assert_nn(amount); 15 | } 16 | 17 | let (res) = balance.read(); 18 | balance.write(res + amount); 19 | return (); 20 | } 21 | 22 | @external 23 | func set_balance{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(amount: felt) { 24 | with_attr error_message("Amount must be positive. Got: {amount}.") { 25 | assert_nn(amount); 26 | } 27 | 28 | balance.write(amount); 29 | return (); 30 | } 31 | 32 | @external 33 | func set_balance_double{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 34 | amount: felt 35 | ) { 36 | with_attr error_message("Amount must be positive. Got: {amount}.") { 37 | assert_nn(amount); 38 | } 39 | 40 | balance.write(amount * 2); 41 | return (); 42 | } 43 | 44 | @external 45 | func set_balance_times3{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 46 | amount: felt 47 | ) { 48 | with_attr error_message("Amount must be positive. 
Got: {amount}.") { 49 | assert_nn(amount); 50 | } 51 | 52 | balance.write(amount * 3); 53 | return (); 54 | } 55 | 56 | @view 57 | func get_balance{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (res: felt) { 58 | let (res) = balance.read(); 59 | return (res,); 60 | } 61 | 62 | @constructor 63 | func constructor{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() { 64 | balance.write(0); 65 | return (); 66 | } 67 | -------------------------------------------------------------------------------- /contracts/test/FakeAccount.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin 4 | from starkware.cairo.common.bool import TRUE, FALSE 5 | 6 | const ERC165_ACCOUNT_INTERFACE_ID = 0x3943f10f; 7 | 8 | @view 9 | func supportsInterface{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 10 | interfaceId: felt 11 | ) -> (success: felt) { 12 | // IAccount 13 | if (interfaceId == ERC165_ACCOUNT_INTERFACE_ID) { 14 | return (TRUE,); 15 | } 16 | 17 | return (FALSE,); 18 | } 19 | -------------------------------------------------------------------------------- /contracts/upgrade/IProxy.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | @contract_interface 4 | namespace IProxy { 5 | func get_implementation() -> (implementation: felt) { 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /contracts/upgrade/Proxy.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin 4 | from starkware.starknet.common.syscalls import library_call, library_call_l1_handler 5 | 6 | from contracts.upgrade.Upgradable import _get_implementation, _set_implementation 7 | from 
contracts.account.library import PluginAccount 8 | 9 | ///////////////////// 10 | // CONSTRUCTOR 11 | ///////////////////// 12 | 13 | @constructor 14 | func constructor{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 15 | implementation: felt, selector: felt, calldata_len: felt, calldata: felt* 16 | ) { 17 | _set_implementation(implementation); 18 | library_call( 19 | class_hash=implementation, 20 | function_selector=selector, 21 | calldata_size=calldata_len, 22 | calldata=calldata, 23 | ); 24 | return (); 25 | } 26 | 27 | ///////////////////// 28 | // EXTERNAL FUNCTIONS 29 | ///////////////////// 30 | 31 | @external 32 | @raw_input 33 | @raw_output 34 | func __default__{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 35 | selector: felt, calldata_size: felt, calldata: felt* 36 | ) -> (retdata_size: felt, retdata: felt*) { 37 | let (implementation) = _get_implementation(); 38 | 39 | let (retdata_size: felt, retdata: felt*) = library_call( 40 | class_hash=implementation, 41 | function_selector=selector, 42 | calldata_size=calldata_size, 43 | calldata=calldata, 44 | ); 45 | return (retdata_size=retdata_size, retdata=retdata); 46 | } 47 | 48 | @l1_handler 49 | @raw_input 50 | func __l1_default__{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 51 | selector: felt, calldata_size: felt, calldata: felt* 52 | ) { 53 | let (implementation) = _get_implementation(); 54 | 55 | library_call_l1_handler( 56 | class_hash=implementation, 57 | function_selector=selector, 58 | calldata_size=calldata_size, 59 | calldata=calldata, 60 | ); 61 | return (); 62 | } 63 | 64 | @external 65 | func __validate_deploy__{ 66 | syscall_ptr: felt*, 67 | pedersen_ptr: HashBuiltin*, 68 | ecdsa_ptr: SignatureBuiltin*, 69 | range_check_ptr 70 | } ( 71 | class_hash: felt, 72 | contract_address_salt: felt, 73 | implementation: felt, 74 | selector: felt, 75 | calldata_len: felt, 76 | calldata: felt* 77 | ) { 78 | PluginAccount.validate_deploy(); 79 | 
return (); 80 | } 81 | 82 | ///////////////////// 83 | // VIEW FUNCTIONS 84 | ///////////////////// 85 | 86 | @view 87 | func get_implementation{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( 88 | implementation: felt 89 | ) { 90 | let (implementation) = _get_implementation(); 91 | return (implementation=implementation); 92 | } 93 | -------------------------------------------------------------------------------- /contracts/upgrade/Upgradable.cairo: -------------------------------------------------------------------------------- 1 | // Taken from: https://github.com/argentlabs/argent-contracts-starknet/blob/develop/contracts/upgrade/Upgradable.cairo 2 | %lang starknet 3 | 4 | from starkware.cairo.common.cairo_builtins import HashBuiltin 5 | from starkware.cairo.common.math import assert_not_zero 6 | 7 | ///////////////////// 8 | // STORAGE VARIABLES 9 | ///////////////////// 10 | 11 | @storage_var 12 | func _implementation() -> (address: felt) { 13 | } 14 | 15 | ///////////////////// 16 | // INTERNAL FUNCTIONS 17 | ///////////////////// 18 | 19 | func _get_implementation{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( 20 | implementation: felt 21 | ) { 22 | let (res) = _implementation.read(); 23 | return (implementation=res); 24 | } 25 | 26 | func _set_implementation{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( 27 | implementation: felt 28 | ) { 29 | assert_not_zero(implementation); 30 | _implementation.write(implementation); 31 | return (); 32 | } 33 | -------------------------------------------------------------------------------- /localhost.accounts.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /node.json: -------------------------------------------------------------------------------- 1 | {"localhost": "http://127.0.0.1:5050/"} 
-------------------------------------------------------------------------------- /protostar.toml: -------------------------------------------------------------------------------- 1 | ["protostar.config"] 2 | protostar_version = "0.7.0" 3 | 4 | ["protostar.project"] 5 | libs_path = "lib" 6 | 7 | ["protostar.contracts"] 8 | account = [ 9 | "contracts/account/PluginAccount.cairo", 10 | ] 11 | signer = [ 12 | "contracts/plugins/signer/StarkSigner.cairo" 13 | ] 14 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | log_cli = true 3 | log_cli_level = "INFO" 4 | log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" 5 | log_cli_date_format = "%Y-%m-%d %H:%M:%S" 6 | filterwarnings = [ 7 | "ignore::DeprecationWarning", 8 | ] 9 | asyncio_mode = "auto" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cairo-lang>=0.10.0 2 | cairo-nile>=0.9.0 3 | nile-coverage>=0.2.0 4 | pytest>=7.1.2 5 | pytest-asyncio>=0.19.0 -------------------------------------------------------------------------------- /tests/test_account.cairo: -------------------------------------------------------------------------------- 1 | %lang starknet 2 | 3 | from starkware.cairo.common.cairo_builtins import HashBuiltin 4 | from starkware.cairo.common.uint256 import Uint256 5 | from starkware.starknet.common.syscalls import get_block_timestamp, get_caller_address 6 | 7 | from contracts.account.IPluginAccount import IPluginAccount 8 | from contracts.upgrade.IProxy import IProxy 9 | 10 | @external 11 | func test_upgrade{syscall_ptr: felt*, range_check_ptr, pedersen_ptr: HashBuiltin*}() { 12 | alloc_locals; 13 | 14 | local initial_implementation: felt; 15 | local fake_implementation: felt; 16 
| local account_address: felt; 17 | %{ 18 | from starkware.starknet.compiler.compile import get_selector_from_name 19 | ids.initial_implementation = declare("./contracts/account/PluginAccount.cairo").class_hash 20 | ids.fake_implementation = declare("./contracts/test/FakeAccount.cairo").class_hash 21 | signer_hash = declare("./contracts/plugins/signer/StarkSigner.cairo").class_hash 22 | ids.account_address = deploy_contract("./contracts/upgrade/Proxy.cairo", [ids.initial_implementation, get_selector_from_name('initialize'), 3, signer_hash, 1, 420]).contract_address 23 | 24 | # prank from self 25 | stop_prank_callable = start_prank(ids.account_address, target_contract_address=ids.account_address) 26 | %} 27 | 28 | let (implementation) = IProxy.get_implementation(account_address); 29 | 30 | assert implementation = initial_implementation; 31 | 32 | %{ expect_revert(error_message="PluginAccount: invalid implementation") %} 33 | IPluginAccount.upgrade(account_address, 0xdead); 34 | 35 | // Successfully upgrades to fake account that masquerades as ERC165. 
36 | IPluginAccount.upgrade(account_address, fake_implementation); 37 | 38 | %{ 39 | stop_prank_callable() 40 | %} 41 | 42 | return (); 43 | } 44 | -------------------------------------------------------------------------------- /tests/test_account.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import asyncio 3 | import logging 4 | from starkware.starknet.testing.starknet import Starknet 5 | from utils.utils import compile, build_contract, StarkKeyPair, ERC165_INTERFACE_ID, ERC165_ACCOUNT_INTERFACE_ID, assert_event_emitted, assert_revert, str_to_felt 6 | from utils.plugin_signer import StarkPluginSigner 7 | from utils.session_keys_utils import SessionPluginSigner 8 | from starkware.starknet.compiler.compile import get_selector_from_name 9 | 10 | 11 | LOGGER = logging.getLogger(__name__) 12 | 13 | signer_key = StarkKeyPair(123456789987654321) 14 | signer_key_2 = StarkKeyPair(123456789987654322) 15 | session_key = StarkKeyPair(666666666666666666) 16 | 17 | VERSION = '0.0.1' 18 | 19 | @pytest.fixture(scope='module') 20 | def event_loop(): 21 | return asyncio.new_event_loop() 22 | 23 | 24 | @pytest.fixture(scope='module') 25 | async def starknet(): 26 | return await Starknet.empty() 27 | 28 | 29 | @pytest.fixture(scope='module') 30 | async def account_setup(starknet: Starknet): 31 | account_cls = compile('contracts/account/PluginAccount.cairo') 32 | sts_plugin_cls = compile("contracts/plugins/signer/StarkSigner.cairo") 33 | sts_plugin_decl = await starknet.declare(contract_class=sts_plugin_cls) 34 | 35 | account = await starknet.deploy(contract_class=account_cls, constructor_calldata=[]) 36 | await account.initialize(sts_plugin_decl.class_hash, [signer_key.public_key]).execute() 37 | 38 | account2 = await starknet.deploy(contract_class=account_cls, constructor_calldata=[]) 39 | await account2.initialize(sts_plugin_decl.class_hash, [signer_key_2.public_key]).execute() 40 | return account, account2, account_cls, 
sts_plugin_decl 41 | 42 | 43 | @pytest.fixture(scope='module') 44 | async def session_plugin_setup(starknet: Starknet): 45 | session_key_cls = compile('contracts/plugins/SessionKey.cairo') 46 | session_key_decl = await starknet.declare(contract_class=session_key_cls) 47 | return session_key_decl 48 | 49 | 50 | @pytest.fixture(scope='module') 51 | async def dapp_setup(starknet: Starknet): 52 | dapp_cls = compile('contracts/test/Dapp.cairo') 53 | await starknet.declare(contract_class=dapp_cls) 54 | return await starknet.deploy(contract_class=dapp_cls, constructor_calldata=[]) 55 | 56 | 57 | @pytest.fixture 58 | async def network(starknet: Starknet, account_setup, session_plugin_setup, dapp_setup): 59 | account, account_2, account_cls, sts_plugin_decl = account_setup 60 | session_key_decl = session_plugin_setup 61 | 62 | clean_state = starknet.state.copy() 63 | account = build_contract(account, state=clean_state) 64 | account_2 = build_contract(account_2, state=clean_state) 65 | 66 | stark_plugin_signer = StarkPluginSigner( 67 | stark_key=signer_key, 68 | account=account, 69 | plugin_class_hash=sts_plugin_decl.class_hash 70 | ) 71 | 72 | stark_plugin_signer_2 = StarkPluginSigner( 73 | stark_key=signer_key_2, 74 | account=account_2, 75 | plugin_class_hash=sts_plugin_decl.class_hash 76 | ) 77 | 78 | session_plugin_signer = SessionPluginSigner( 79 | stark_key=session_key, 80 | account=account, 81 | plugin_class_hash=session_key_decl.class_hash 82 | ) 83 | dapp = build_contract(dapp_setup, state=clean_state) 84 | 85 | return account, stark_plugin_signer, stark_plugin_signer_2, session_plugin_signer, dapp 86 | 87 | 88 | @pytest.mark.asyncio 89 | async def test_addPlugin(network): 90 | account, stark_plugin_signer, stark_plugin_signer_2, session_plugin_signer, dapp = network 91 | plugin_class_hash = session_plugin_signer.plugin_class_hash 92 | assert (await account.isPlugin(plugin_class_hash).call()).result.success == 0 93 | await 
stark_plugin_signer.add_plugin(plugin_class_hash) 94 | assert (await account.isPlugin(plugin_class_hash).call()).result.success == 1 95 | 96 | 97 | @pytest.mark.asyncio 98 | async def test_removePlugin(network): 99 | account, stark_plugin_signer, stark_plugin_signer_2, session_plugin_signer, dapp = network 100 | plugin_class_hash = session_plugin_signer.plugin_class_hash 101 | assert (await account.isPlugin(plugin_class_hash).call()).result.success == 0 102 | await stark_plugin_signer.add_plugin(plugin_class_hash) 103 | assert (await account.isPlugin(plugin_class_hash).call()).result.success == 1 104 | await stark_plugin_signer.remove_plugin(plugin_class_hash) 105 | assert (await account.isPlugin(plugin_class_hash).call()).result.success == 0 106 | 107 | 108 | @pytest.mark.asyncio 109 | async def test_supportsInterface(network): 110 | account, stark_plugin_signer, stark_plugin_signer_2, session_plugin_signer, dapp = network 111 | assert (await account.supportsInterface(ERC165_INTERFACE_ID).call()).result.success == 1 112 | assert (await account.supportsInterface(ERC165_ACCOUNT_INTERFACE_ID).call()).result.success == 1 113 | assert (await account.supportsInterface(0x123).call()).result.success == 0 114 | assert (await account.getVersion().call()).result.version == str_to_felt(VERSION) 115 | 116 | 117 | @pytest.mark.asyncio 118 | async def test_dapp(network): 119 | account, stark_plugin_signer, stark_plugin_signer, session_plugin_signer, dapp = network 120 | assert (await dapp.get_balance().call()).result.res == 0 121 | tx_exec_info = await stark_plugin_signer.send_transaction( 122 | calls=[(dapp.contract_address, 'set_balance', [47])], 123 | ) 124 | assert_event_emitted( 125 | tx_exec_info, 126 | from_address=stark_plugin_signer.account.contract_address, 127 | name='transaction_executed', 128 | data=[] 129 | ) 130 | assert (await dapp.get_balance().call()).result.res == 47 131 | 132 | 133 | @pytest.mark.asyncio 134 | async def test_executeOnPlugin(network): 135 | # 
Account 2 tries to change the signer key on Account 1, via executeOnPlugin and via readOnPlugin 136 | 137 | account, stark_plugin_signer, stark_plugin_signer_2, session_plugin_signer, dapp = network 138 | read_execution_info = await stark_plugin_signer.read_on_plugin("getPublicKey") 139 | assert read_execution_info.result[0] == [signer_key.public_key] 140 | 141 | set_public_key_arguments = [signer_key_2.public_key] 142 | exec_arguments = [ 143 | stark_plugin_signer.plugin_class_hash, 144 | get_selector_from_name("setPublicKey"), 145 | len(set_public_key_arguments), 146 | *set_public_key_arguments 147 | ] 148 | await assert_revert( 149 | stark_plugin_signer_2.send_transaction(calls=[(stark_plugin_signer.account.contract_address, 'executeOnPlugin', exec_arguments)]), 150 | reverted_with="StarkSigner: only self" 151 | ) 152 | 153 | read_execution_info = await stark_plugin_signer.read_on_plugin("getPublicKey") 154 | assert read_execution_info.result[0] == [signer_key.public_key] 155 | -------------------------------------------------------------------------------- /tests/test_session_key.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import asyncio 3 | import logging 4 | from starkware.starknet.testing.starknet import Starknet 5 | from starkware.starknet.definitions.general_config import StarknetChainId 6 | from starkware.starknet.business_logic.state.state import BlockInfo 7 | from utils.utils import assert_revert, compile, cached_contract, assert_event_emitted, StarkKeyPair, build_contract, ERC165_INTERFACE_ID, ERC165_ACCOUNT_INTERFACE_ID 8 | from utils.plugin_signer import StarkPluginSigner 9 | from utils.session_keys_utils import build_session, SessionPluginSigner 10 | from starkware.starknet.compiler.compile import get_selector_from_name 11 | 12 | 13 | LOGGER = logging.getLogger(__name__) 14 | 15 | signer_key = StarkKeyPair(123456789987654321) 16 | signer_key_2 = StarkKeyPair(123456789987654322) 17 | 
# Key pair handed to the session-key plugin signer in the tests below.
session_key = StarkKeyPair(666666666666666666)
wrong_session_key = StarkKeyPair(6767676767)

# Block timestamp the tests pin the chain to; session expirations are expressed relative to it.
DEFAULT_TIMESTAMP = 1640991600


@pytest.fixture(scope='module')
def event_loop():
    """Provide one event loop for the whole module and close it on teardown.

    Overrides pytest-asyncio's ``event_loop`` fixture with module scope so the
    module-scoped async fixtures below can run on it. The loop is yielded
    rather than returned so that it is closed once the module's tests are done
    instead of being leaked.
    """
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope='module')
async def starknet():
    """Create the empty Starknet testing instance shared by every test in the module."""
    return await Starknet.empty()


def update_starknet_block(starknet, block_number=1, block_timestamp=DEFAULT_TIMESTAMP):
    """Replace the block info of ``starknet`` with the given number and timestamp.

    Gas price, Starknet version and sequencer address are carried over from the
    block info currently installed on ``starknet.state.state``.

    NOTE(review): the ``contracts`` fixture binds its contracts to
    ``starknet.state.copy()``, while this helper writes to ``starknet.state``.
    Confirm that the copied state observes the update when this is called
    after the fixture has run.
    """
    old_block_info = starknet.state.state.block_info
    starknet.state.state.block_info = BlockInfo(
        block_number=block_number,
        block_timestamp=block_timestamp,
        gas_price=old_block_info.gas_price,
        starknet_version=old_block_info.starknet_version,
        sequencer_address=old_block_info.sequencer_address
    )


def reset_starknet_block(starknet):
    """Reset the block info to the defaults of ``update_starknet_block``."""
    update_starknet_block(starknet=starknet)


@pytest.fixture(scope='module')
async def account_setup(starknet: Starknet):
    """Declare the plugins and deploy two initialised plugin accounts.

    Returns a tuple ``(account, account_2, session_key_class_hash,
    stark_signer_class_hash)``. ``account`` is initialised with the StarkSigner
    plugin and ``signer_key``; ``account_2`` with the same plugin and
    ``signer_key_2``.
    """
    account_cls = compile('contracts/account/PluginAccount.cairo')
    session_key_cls = compile('contracts/plugins/SessionKey.cairo')
    sts_plugin_cls = compile("contracts/plugins/signer/StarkSigner.cairo")

    session_key_class = await starknet.declare(contract_class=session_key_cls)
    sts_plugin_decl = await starknet.declare(contract_class=sts_plugin_cls)

    account = await starknet.deploy(contract_class=account_cls, constructor_calldata=[])
    account_2 = await starknet.deploy(contract_class=account_cls, constructor_calldata=[])

    await account.initialize(sts_plugin_decl.class_hash, [signer_key.public_key]).execute()
    await account_2.initialize(sts_plugin_decl.class_hash, [signer_key_2.public_key]).execute()

    return account, account_2, session_key_class.class_hash, sts_plugin_decl.class_hash


@pytest.fixture(scope='module')
async def dapp_setup(starknet: Starknet):
    """Declare the test Dapp contract and deploy two independent instances of it."""
    dapp_cls = compile('contracts/test/Dapp.cairo')
    await starknet.declare(contract_class=dapp_cls)
    dapp1 = await starknet.deploy(contract_class=dapp_cls, constructor_calldata=[])
    dapp2 = await starknet.deploy(contract_class=dapp_cls, constructor_calldata=[])
    return dapp1, dapp2


@pytest.fixture
def contracts(starknet: Starknet, account_setup, dapp_setup):
    """Rebind the module-level deployments to a fresh copy of the state for one test.

    Every contract is rebuilt on the same ``starknet.state.copy()``, so the
    transactions a test sends execute against that copy rather than against the
    module-scoped state.

    Returns ``(account, stark_plugin_signer, stark_plugin_signer_2,
    session_plugin_signer, dapp1, dapp2, session_plugin_class_hash)``.
    """
    account, account_2, session_plugin_class_hash, sts_plugin_class_hash = account_setup
    dapp1, dapp2 = dapp_setup
    clean_state = starknet.state.copy()

    account = build_contract(account, state=clean_state)
    account_2 = build_contract(account_2, state=clean_state)

    dapp1 = build_contract(dapp1, state=clean_state)
    dapp2 = build_contract(dapp2, state=clean_state)

    # Owner of `account`, signing through the StarkSigner plugin.
    stark_plugin_signer = StarkPluginSigner(
        stark_key=signer_key,
        account=account,
        plugin_class_hash=sts_plugin_class_hash
    )

    # Owner of the second, unrelated account.
    stark_plugin_signer_2 = StarkPluginSigner(
        stark_key=signer_key_2,
        account=account_2,
        plugin_class_hash=sts_plugin_class_hash
    )

    # Signs for `account` with the session key, through the SessionKey plugin.
    session_plugin_signer = SessionPluginSigner(
        stark_key=session_key,
        account=account,
        plugin_class_hash=session_plugin_class_hash
    )

    return account, stark_plugin_signer, stark_plugin_signer_2, session_plugin_signer, dapp1, dapp2, session_plugin_class_hash


@pytest.mark.asyncio
async def test_call_dapp_with_session_key(starknet: Starknet, contracts):
    """A session key can make the allowed calls, nothing else, and nothing after revocation."""
    account, stark_plugin_signer, stark_plugin_signer_2, session_plugin_signer, dapp1, dapp2, session_key_class = contracts

    # add session key plugin
    await stark_plugin_signer.add_plugin(session_key_class)

    # authorise session key
    session = build_session(
        signer=stark_plugin_signer,
        allowed_calls=[
            (dapp1.contract_address, 'set_balance'),
            (dapp1.contract_address, 'set_balance_double'),
            (dapp2.contract_address, 'set_balance'),
            (dapp2.contract_address, 'set_balance_double'),
            (dapp2.contract_address, 'set_balance_times3'),
        ],
        session_public_key=session_key.public_key,
        session_expiration=DEFAULT_TIMESTAMP + 10,
        chain_id=StarknetChainId.TESTNET.value,
        account_address=account.contract_address
    )

    assert (await dapp1.get_balance().call()).result.res == 0
    update_starknet_block(starknet=starknet, block_timestamp=DEFAULT_TIMESTAMP)
    # call with session key
    tx_exec_info = await session_plugin_signer.send_transaction(
        calls=[
            (dapp1.contract_address, 'set_balance', [47]),
            (dapp2.contract_address, 'set_balance_times3', [20])
        ],
        session=session
    )

    assert_event_emitted(
        tx_exec_info,
        from_address=account.contract_address,
        name='transaction_executed',
        data=[]
    )
    # check it worked
    assert (await dapp1.get_balance().call()).result.res == 47
    assert (await dapp2.get_balance().call()).result.res == 60

    # call that is not in the session's allowed calls (set_balance_times3 on dapp1),
    # sent with the proof of the first allowed call
    await assert_revert(
        session_plugin_signer.send_transaction_with_proofs(
            calls=[(dapp1.contract_address, 'set_balance_times3', [47])],
            proofs=[session.proofs[0]],
            session=session
        ),
        reverted_with="SessionKey: not allowed by policy"
    )

    # revoke session key
    tx_exec_info = await stark_plugin_signer.execute_on_plugin("revokeSessionKey", [session_key.public_key], plugin=session_key_class)
    assert_event_emitted(
        tx_exec_info,
        from_address=account.contract_address,
        name='session_key_revoked',
        data=[session_key.public_key]
    )
    # check the session key is no longer authorised
    await assert_revert(
        session_plugin_signer.send_transaction(
            calls=[(dapp1.contract_address, 'set_balance', [47])],
            session=session
        ),
        reverted_with="SessionKey: session key revoked"
    )
@pytest.mark.asyncio 181 | async def test_supportsInterface(contracts): 182 | account, stark_plugin_signer, stark_plugin_signer_2, session_plugin_signer, dapp1, dapp2, session_key_class = contracts 183 | await stark_plugin_signer.add_plugin(session_key_class) 184 | assert (await stark_plugin_signer.read_on_plugin("supportsInterface", [ERC165_INTERFACE_ID], plugin=session_key_class)).result[0] == [1] 185 | assert (await stark_plugin_signer.read_on_plugin("supportsInterface", [ERC165_ACCOUNT_INTERFACE_ID], plugin=session_key_class)).result[0] == [0] 186 | assert (await stark_plugin_signer.read_on_plugin("supportsInterface", [ERC165_ACCOUNT_INTERFACE_ID], plugin=session_key_class)).result[0] == [0] 187 | 188 | 189 | @pytest.mark.asyncio 190 | async def test_dapp_bad_signature(starknet: Starknet, contracts): 191 | account, stark_plugin_signer, stark_plugin_signer_2, session_plugin_signer, dapp, dapp2, session_key_class = contracts 192 | assert (await dapp.get_balance().call()).result.res == 0 193 | 194 | await stark_plugin_signer.add_plugin(session_key_class) 195 | update_starknet_block(starknet=starknet, block_timestamp=DEFAULT_TIMESTAMP) 196 | 197 | session = build_session( 198 | signer=stark_plugin_signer, 199 | allowed_calls=[(dapp.contract_address, 'set_balance')], 200 | session_public_key=session_key.public_key, 201 | session_expiration=DEFAULT_TIMESTAMP + 10, 202 | chain_id=StarknetChainId.TESTNET.value, 203 | account_address=account.contract_address 204 | ) 205 | 206 | signed_tx = await session_plugin_signer.get_signed_transaction( 207 | calls=[(dapp.contract_address, 'set_balance', [47])], 208 | session=session 209 | ) 210 | signed_tx.signature[2] = 3333 211 | 212 | await assert_revert( 213 | session_plugin_signer.send_signed_tx(signed_tx) 214 | ) 215 | assert (await dapp.get_balance().call()).result.res == 0 216 | 217 | 218 | @pytest.mark.asyncio 219 | async def test_dapp_long_signature(starknet: Starknet, contracts): 220 | account, stark_plugin_signer, 
stark_plugin_signer_2, session_plugin_signer, dapp, dapp2, session_key_class = contracts 221 | assert (await dapp.get_balance().call()).result.res == 0 222 | 223 | await stark_plugin_signer.add_plugin(session_key_class) 224 | update_starknet_block(starknet=starknet, block_timestamp=DEFAULT_TIMESTAMP) 225 | 226 | session = build_session( 227 | signer=stark_plugin_signer, 228 | allowed_calls=[(dapp.contract_address, 'set_balance')], 229 | session_public_key=session_key.public_key, 230 | session_expiration=DEFAULT_TIMESTAMP + 10, 231 | chain_id=StarknetChainId.TESTNET.value, 232 | account_address=account.contract_address 233 | ) 234 | 235 | signed_tx = await session_plugin_signer.get_signed_transaction( 236 | calls=[(dapp.contract_address, 'set_balance', [47])], 237 | session= session 238 | ) 239 | signed_tx.signature.extend([1, 1, 1, 1]) 240 | await assert_revert( 241 | session_plugin_signer.send_signed_tx(signed_tx), 242 | reverted_with="SessionKey: invalid signature length" 243 | ) 244 | 245 | signed_tx = await session_plugin_signer.get_signed_transaction( 246 | calls=[(dapp.contract_address, 'set_balance', [47])], 247 | session=session 248 | ) 249 | index_proofs_len = 7 250 | proofs_len = signed_tx.signature[index_proofs_len] 251 | index_session_token_len = index_proofs_len + proofs_len + 1 252 | assert signed_tx.signature[index_session_token_len] == len(session.session_token) 253 | 254 | signed_tx.signature[index_proofs_len] = proofs_len + 1 255 | signed_tx.signature.insert(index_session_token_len, 3333) 256 | 257 | await assert_revert( 258 | session_plugin_signer.send_signed_tx(signed_tx), 259 | reverted_with="SessionKey: invalid proof len" 260 | ) 261 | 262 | assert (await dapp.get_balance().call()).result.res == 0 263 | 264 | signed_tx = await session_plugin_signer.get_signed_transaction( 265 | calls=[(dapp.contract_address, 'set_balance', [47])], 266 | session=session, 267 | ) 268 | tx_exec_info = await session_plugin_signer.send_signed_tx(signed_tx) 269 | 
270 | assert_event_emitted( 271 | tx_exec_info, 272 | from_address=account.contract_address, 273 | name='transaction_executed', 274 | data=[] 275 | ) 276 | # check it worked 277 | assert (await dapp.get_balance().call()).result.res == 47 278 | 279 | @pytest.mark.asyncio 280 | async def test_executeOnPlugin(starknet: Starknet, contracts): 281 | # Account 2 tries to revoke a session key on Account 1, via executeOnPlugin and via readOnPlugin 282 | 283 | account, stark_plugin_signer, stark_plugin_signer_2, session_plugin_signer, dapp1, dapp2, session_key_class = contracts 284 | 285 | await stark_plugin_signer.add_plugin(session_key_class) 286 | update_starknet_block(starknet=starknet, block_timestamp=DEFAULT_TIMESTAMP) 287 | 288 | session = build_session( 289 | signer=stark_plugin_signer, 290 | allowed_calls=[(dapp1.contract_address, 'set_balance')], 291 | session_public_key=session_key.public_key, 292 | session_expiration=DEFAULT_TIMESTAMP + 10, 293 | chain_id=StarknetChainId.TESTNET.value, 294 | account_address=account.contract_address 295 | ) 296 | 297 | revoke_session_key_arguments = [session.session_public_key] 298 | exec_arguments = [ 299 | session_plugin_signer.plugin_class_hash, 300 | get_selector_from_name("revokeSessionKey"), 301 | len(revoke_session_key_arguments), 302 | *revoke_session_key_arguments 303 | ] 304 | await assert_revert( 305 | stark_plugin_signer_2.send_transaction( 306 | [(stark_plugin_signer.account.contract_address, 'executeOnPlugin', exec_arguments)] 307 | ), 308 | reverted_with="SessionKey: only self" 309 | ) 310 | -------------------------------------------------------------------------------- /tests/test_stark_signer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import asyncio 3 | from starkware.starknet.testing.starknet import Starknet 4 | from utils.utils import str_to_felt, build_contract, compile 5 | from utils.utils import StarkKeyPair, ERC165_INTERFACE_ID, 
ERC165_ACCOUNT_INTERFACE_ID, assert_event_emitted, assert_revert 6 | from utils.plugin_signer import StarkPluginSigner 7 | 8 | key_pair = StarkKeyPair(1234) 9 | new_key_pair = StarkKeyPair(5678) 10 | 11 | 12 | @pytest.fixture(scope='module') 13 | def event_loop(): 14 | return asyncio.new_event_loop() 15 | 16 | 17 | @pytest.fixture(scope='module') 18 | async def starknet(): 19 | return await Starknet.empty() 20 | 21 | 22 | @pytest.fixture(scope='module') 23 | async def account_setup(starknet: Starknet): 24 | account_cls = compile('contracts/account/PluginAccount.cairo') 25 | sts_plugin_cls = compile("contracts/plugins/signer/StarkSigner.cairo") 26 | 27 | sts_plugin_decl = await starknet.declare(contract_class=sts_plugin_cls) 28 | 29 | account = await starknet.deploy( 30 | contract_class=account_cls, 31 | constructor_calldata=[] 32 | ) 33 | 34 | await account.initialize(sts_plugin_decl.class_hash, [key_pair.public_key]).execute() 35 | 36 | return account, sts_plugin_decl.class_hash 37 | 38 | 39 | @pytest.fixture(scope='module') 40 | async def dapp(starknet: Starknet): 41 | dapp_cls = compile('contracts/test/Dapp.cairo') 42 | await starknet.declare(contract_class=dapp_cls) 43 | return await starknet.deploy(contract_class=dapp_cls, constructor_calldata=[]) 44 | 45 | 46 | @pytest.fixture 47 | def contracts(starknet: Starknet, account_setup, dapp): 48 | account, sts_plugin_class_hash = account_setup 49 | clean_state = starknet.state.copy() 50 | 51 | account = build_contract(account, state=clean_state) 52 | 53 | stark_plugin_signer = StarkPluginSigner( 54 | stark_key=key_pair, 55 | account=account, 56 | plugin_class_hash=sts_plugin_class_hash 57 | ) 58 | 59 | dapp = build_contract(dapp, state=clean_state) 60 | 61 | return account, stark_plugin_signer, sts_plugin_class_hash, dapp 62 | 63 | 64 | @pytest.mark.asyncio 65 | async def test_initialise(contracts): 66 | account, stark_plugin_signer, sts_plugin_hash, _ = contracts 67 | 68 | execution_info = await 
account.getName().call() 69 | assert execution_info.result == (str_to_felt('PluginAccount'),) 70 | 71 | execution_info = await account.isPlugin(sts_plugin_hash).call() 72 | assert execution_info.result == (1,) 73 | 74 | execution_info = await stark_plugin_signer.read_on_plugin("getPublicKey") 75 | assert execution_info.result[0] == [stark_plugin_signer.public_key] 76 | 77 | 78 | @pytest.mark.asyncio 79 | async def test_change_public_key(contracts): 80 | account, stark_plugin_signer, sts_plugin_hash, _ = contracts 81 | 82 | await stark_plugin_signer.execute_on_plugin( 83 | selector_name="setPublicKey", 84 | arguments=[new_key_pair.public_key] 85 | ) 86 | 87 | execution_info = await stark_plugin_signer.read_on_plugin("getPublicKey") 88 | assert execution_info.result[0] == [new_key_pair.public_key] 89 | 90 | 91 | @pytest.mark.asyncio 92 | async def test_supportsInterface(contracts): 93 | account, stark_plugin_signer, sts_plugin_hash, _ = contracts 94 | assert (await stark_plugin_signer.read_on_plugin("supportsInterface", [ERC165_INTERFACE_ID])).result[0] == [1] 95 | assert (await stark_plugin_signer.read_on_plugin("supportsInterface", [ERC165_ACCOUNT_INTERFACE_ID])).result[0] == [0] 96 | assert (await stark_plugin_signer.read_on_plugin("supportsInterface", [ERC165_ACCOUNT_INTERFACE_ID])).result[0] == [0] 97 | 98 | 99 | @pytest.mark.asyncio 100 | async def test_dapp(contracts): 101 | account, stark_plugin_signer, sts_plugin_hash, dapp = contracts 102 | assert (await dapp.get_balance().call()).result.res == 0 103 | tx_exec_info = await stark_plugin_signer.send_transaction( 104 | calls=[(dapp.contract_address, 'set_balance', [47])], 105 | ) 106 | assert_event_emitted( 107 | tx_exec_info, 108 | from_address=account.contract_address, 109 | name='transaction_executed', 110 | data=[] 111 | ) 112 | assert (await dapp.get_balance().call()).result.res == 47 113 | 114 | 115 | @pytest.mark.asyncio 116 | async def test_dapp_bad_signature(contracts): 117 | account, 
stark_plugin_signer, sts_plugin_hash, dapp = contracts 118 | assert (await dapp.get_balance().call()).result.res == 0 119 | 120 | signed_tx = await stark_plugin_signer.get_signed_transaction( 121 | calls=[(dapp.contract_address, 'set_balance', [47])], 122 | ) 123 | signed_tx.signature[2] = 3333 124 | 125 | await assert_revert( 126 | stark_plugin_signer.send_signed_tx(signed_tx) 127 | ) 128 | assert (await dapp.get_balance().call()).result.res == 0 129 | 130 | 131 | @pytest.mark.asyncio 132 | async def test_dapp_long_signature(contracts): 133 | account, stark_plugin_signer, sts_plugin_hash, dapp = contracts 134 | assert (await dapp.get_balance().call()).result.res == 0 135 | 136 | signed_tx = await stark_plugin_signer.get_signed_transaction( 137 | calls=[(dapp.contract_address, 'set_balance', [47])], 138 | ) 139 | signed_tx.signature.extend([1, 1, 1, 1]) 140 | await assert_revert( 141 | stark_plugin_signer.send_signed_tx(signed_tx) 142 | ) 143 | assert (await dapp.get_balance().call()).result.res == 0 144 | -------------------------------------------------------------------------------- /tests/utils/merkle_utils.py: -------------------------------------------------------------------------------- 1 | from starkware.crypto.signature.fast_pedersen_hash import pedersen_hash 2 | from starkware.cairo.common.hash_state import compute_hash_on_elements 3 | 4 | # generates merkle root from values list 5 | # each pair of values must be in sorted order 6 | def generate_merkle_root(values: 'list[int]') -> int: 7 | if len(values) == 1: 8 | return values[0] 9 | 10 | if len(values) % 2 != 0: 11 | values.append(0) 12 | 13 | next_level = get_next_level(values) 14 | return generate_merkle_root(next_level) 15 | 16 | # generates merkle proof from an index of the value list 17 | # each pair of values must be in sorted order 18 | def generate_merkle_proof(values: 'list[int]', index: int) -> 'list[int]': 19 | return generate_proof_helper(values, index, []) 20 | 21 | # checks the validity 
of a merkle proof 22 | # the last element of the proof should be the root 23 | def verify_merkle_proof(leaf: int, proof: 'list[int]') -> bool: 24 | root = proof[len(proof)-1] 25 | proof = proof[:-1] 26 | curr = leaf 27 | 28 | for proof_elem in proof: 29 | if curr < proof_elem: 30 | curr = pedersen_hash(curr, proof_elem) 31 | else: 32 | curr = pedersen_hash(proof_elem, curr) 33 | 34 | return curr == root 35 | 36 | # creates the initial merkle leaf values to use 37 | def get_leaves(policy_type_hash: 'int', contracts: 'list[int]', selectors: 'list[int]') -> 'list[tuple[int, int, int]]': 38 | values = [] 39 | for i in range(0, len(contracts)): 40 | leaf = compute_hash_on_elements([policy_type_hash, contracts[i], selectors[i]]) 41 | value = (leaf, contracts[i], selectors[i]) 42 | values.append(value) 43 | 44 | if len(values) % 2 != 0: 45 | last_value = (0, 0, 0) 46 | values.append(last_value) 47 | 48 | return values 49 | 50 | def get_next_level(level: 'list[int]') -> 'list[int]': 51 | next_level = [] 52 | 53 | for i in range(0, len(level), 2): 54 | node = 0 55 | if level[i] < level[i+1]: 56 | node = pedersen_hash(level[i], level[i+1]) 57 | else: 58 | node = pedersen_hash(level[i+1], level[i]) 59 | 60 | next_level.append(node) 61 | 62 | return next_level 63 | 64 | def generate_proof_helper(level: 'list[int]', index: int, proof: 'list[int]') -> 'list[int]': 65 | if len(level) == 1: 66 | return proof 67 | if len(level) % 2 != 0: 68 | level.append(0) 69 | 70 | next_level = get_next_level(level) 71 | index_parent = 0 72 | 73 | for i in range(0, len(level)): 74 | if i == index: 75 | index_parent = i // 2 76 | if i % 2 == 0: 77 | proof.append(level[index+1]) 78 | else: 79 | proof.append(level[index-1]) 80 | 81 | return generate_proof_helper(next_level, index_parent, proof) -------------------------------------------------------------------------------- /tests/utils/plugin_signer.py: -------------------------------------------------------------------------------- 1 | from abc
import abstractmethod 2 | from typing import Optional, List, Tuple 3 | from starkware.crypto.signature.signature import sign 4 | from starkware.starknet.testing.contract import StarknetContract 5 | from starkware.starknet.definitions.general_config import StarknetChainId 6 | from starkware.starknet.core.os.transaction_hash.transaction_hash import calculate_transaction_hash_common, TransactionHashPrefix 7 | from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Declare 8 | from starkware.starknet.business_logic.transaction.objects import InternalTransaction, TransactionExecutionInfo 9 | from starkware.starknet.compiler.compile import get_selector_from_name 10 | from utils.utils import from_call_to_call_array, StarkKeyPair, cached_contract, copy_contract_state 11 | TRANSACTION_VERSION = 1 12 | 13 | 14 | class PluginSigner: 15 | def __init__(self, account: StarknetContract, plugin_class_hash): 16 | self.account = account 17 | self.plugin_class_hash = plugin_class_hash 18 | 19 | @abstractmethod 20 | def sign(self, message_hash: int) -> List[int]: 21 | ... 
22 | 23 | async def send_transaction(self, calls, nonce: Optional[int] = None, max_fee: Optional[int] = 0) -> TransactionExecutionInfo: 24 | return await self.send_signed_tx(await self.get_signed_transaction(calls, nonce, max_fee)) 25 | 26 | async def send_signed_tx(self, signed_tx: InvokeFunction) -> TransactionExecutionInfo : 27 | return await self.account.state.execute_tx( 28 | tx=InternalTransaction.from_external( 29 | external_tx=signed_tx, 30 | general_config=self.account.state.general_config 31 | ) 32 | ) 33 | 34 | async def get_signed_transaction(self, calls, nonce: Optional[int] = None, max_fee: Optional[int] = 0) -> InvokeFunction: 35 | call_array, calldata = from_call_to_call_array(calls) 36 | 37 | account_copy = copy_contract_state(self.account) 38 | raw_invocation = account_copy.__execute__(call_array, calldata) 39 | 40 | if nonce is None: 41 | nonce = await raw_invocation.state.state.get_nonce_at(contract_address=account_copy.contract_address) 42 | 43 | transaction_hash = calculate_transaction_hash_common( 44 | tx_hash_prefix=TransactionHashPrefix.INVOKE, 45 | version=TRANSACTION_VERSION, 46 | contract_address=self.account.contract_address, 47 | entry_point_selector=0, 48 | calldata=raw_invocation.calldata, 49 | max_fee=max_fee, 50 | chain_id=StarknetChainId.TESTNET.value, 51 | additional_data=[nonce], 52 | ) 53 | 54 | signature = self.sign(transaction_hash) 55 | 56 | external_tx = InvokeFunction( 57 | contract_address=self.account.contract_address, 58 | calldata=raw_invocation.calldata, 59 | entry_point_selector=None, 60 | signature=signature, 61 | max_fee=max_fee, 62 | version=TRANSACTION_VERSION, 63 | nonce=nonce, 64 | ) 65 | return external_tx 66 | 67 | async def execute_on_plugin(self, selector_name, arguments=None, plugin=None): 68 | if arguments is None: 69 | arguments = [] 70 | 71 | if plugin is None: 72 | plugin = self.plugin_class_hash 73 | 74 | exec_arguments = [ 75 | plugin, 76 | get_selector_from_name(selector_name), 77 | len(arguments), 
78 | *arguments 79 | ] 80 | return await self.send_transaction([(self.account.contract_address, 'executeOnPlugin', exec_arguments)]) 81 | 82 | async def read_on_plugin(self, selector_name, arguments=None, plugin=None): 83 | if arguments is None: 84 | arguments = [] 85 | 86 | if plugin is None: 87 | plugin = self.plugin_class_hash 88 | 89 | selector = get_selector_from_name(selector_name) 90 | return await self.account.executeOnPlugin(plugin, selector, arguments).call() 91 | 92 | async def add_plugin(self, plugin: int, plugin_arguments=None): 93 | if plugin_arguments is None: 94 | plugin_arguments = [] 95 | return await self.send_transaction([(self.account.contract_address, 'addPlugin', [plugin, len(plugin_arguments), *plugin_arguments])]) 96 | 97 | async def remove_plugin(self, plugin: int): 98 | return await self.send_transaction([(self.account.contract_address, 'removePlugin', [plugin])]) 99 | 100 | async def getVersion(self): 101 | return await self.send_transaction([(self.account.contract_address, 'getVersion', [])]) 102 | 103 | 104 | class StarkPluginSigner(PluginSigner): 105 | def __init__(self, stark_key: StarkKeyPair, account: StarknetContract, plugin_class_hash): 106 | super().__init__(account, plugin_class_hash) 107 | self.stark_key = stark_key 108 | self.public_key = stark_key.public_key 109 | 110 | def sign(self, message_hash: int) -> List[int]: 111 | return [self.plugin_class_hash] + list(self.stark_key.sign(message_hash)) -------------------------------------------------------------------------------- /tests/utils/session_keys_utils.py: -------------------------------------------------------------------------------- 1 | from starkware.cairo.common.hash_state import compute_hash_on_elements 2 | from typing import Optional, List, Tuple 3 | from utils.merkle_utils import get_leaves, generate_merkle_root, generate_merkle_proof 4 | from starkware.starknet.compiler.compile import get_selector_from_name 5 | from utils.utils import str_to_felt 6 | from 
utils.plugin_signer import PluginSigner, TRANSACTION_VERSION 7 | from dataclasses import dataclass 8 | from utils.utils import from_call_to_call_array, copy_contract_state, StarkKeyPair 9 | from starkware.starknet.testing.contract import StarknetContract 10 | from starkware.starknet.core.os.transaction_hash.transaction_hash import calculate_transaction_hash_common, TransactionHashPrefix 11 | from starkware.starknet.business_logic.transaction.objects import InternalTransaction, TransactionExecutionInfo 12 | from starkware.starknet.definitions.general_config import StarknetChainId 13 | from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Declare 14 | 15 | AllowedCall = Tuple[int,str] 16 | # H('StarkNetDomain(chainId:felt)') 17 | STARKNET_DOMAIN_TYPE_HASH = 0x13cda234a04d66db62c06b8e3ad5f91bd0c67286c2c7519a826cf49da6ba478 18 | # H('Session(key:felt,expires:felt,root:merkletree)') 19 | SESSION_TYPE_HASH = 0x1aa0e1c56b45cf06a54534fa1707c54e520b842feb21d03b7deddb6f1e340c 20 | # H(Policy(contractAddress:felt,selector:selector)) 21 | POLICY_TYPE_HASH = 0x2f0026e78543f036f33e26a8f5891b88c58dc1e20cbbfaf0bb53274da6fa568 22 | 23 | 24 | # Returns the tree root and proofs for each allowed call 25 | def generate_policy_tree(allowed_calls : List[AllowedCall]) -> Tuple[int, List[List[int]]]: 26 | merkle_leaves: List[Tuple[int, int, int]] = get_leaves( 27 | policy_type_hash=POLICY_TYPE_HASH, 28 | contracts=[a[0] for a in allowed_calls], 29 | selectors=[get_selector_from_name(a[1]) for a in allowed_calls], 30 | ) 31 | leaves = [leave[0] for leave in merkle_leaves] 32 | root = generate_merkle_root(leaves) 33 | proofs = [generate_merkle_proof(leaves, index) for index, leave in enumerate(leaves)] 34 | return root, proofs 35 | 36 | 37 | @dataclass 38 | class Session: 39 | session_public_key: int 40 | session_expiration: int 41 | root: int 42 | allowed_calls: List[AllowedCall] 43 | proofs: List[List[int]] 44 | session_hash: int 45 | account_address: int 46 | 
session_token: List[int] 47 | 48 | def single_proof_len(self) -> int: 49 | return len(self.proofs[0]) 50 | 51 | 52 | def build_session(signer, allowed_calls: List[AllowedCall], session_public_key: int, session_expiration:int, chain_id:int, account_address: int): 53 | root, proofs = generate_policy_tree(allowed_calls) 54 | domain_hash = compute_hash_on_elements([STARKNET_DOMAIN_TYPE_HASH, chain_id]) 55 | message_hash = compute_hash_on_elements([SESSION_TYPE_HASH, session_public_key, session_expiration, root]) 56 | 57 | session_hash = compute_hash_on_elements([ 58 | str_to_felt('StarkNet Message'), 59 | domain_hash, 60 | account_address, 61 | message_hash 62 | ]) 63 | signed_hash = signer.sign(session_hash) 64 | return Session( 65 | session_public_key=session_public_key, 66 | session_expiration=session_expiration, 67 | root=root, 68 | allowed_calls=allowed_calls, 69 | proofs=proofs, 70 | session_hash=session_hash, 71 | account_address=account_address, 72 | session_token=signed_hash 73 | ) 74 | 75 | 76 | class SessionPluginSigner(PluginSigner): 77 | def __init__(self, stark_key: StarkKeyPair, account: StarknetContract, plugin_class_hash): 78 | super().__init__(account, plugin_class_hash) 79 | self.stark_key = stark_key 80 | self.public_key = stark_key.public_key 81 | 82 | def sign(self, message_hash: int) -> List[int]: 83 | raise Exception("SessionPluginSigner can't sign arbitrary messages") 84 | 85 | async def get_signed_transaction(self, calls, session: Session, nonce: Optional[int] = None, max_fee: Optional[int] = 0) -> InvokeFunction: 86 | proofs = [] 87 | for call in calls: 88 | call_proof_index = session.allowed_calls.index((call[0], call[1])) 89 | proofs.append(session.proofs[call_proof_index]) 90 | return await self.get_signed_transaction_with_proofs(calls, session, proofs, nonce, max_fee) 91 | 92 | async def get_signed_transaction_with_proofs(self, calls, session: Session, proofs: List[List[int]], nonce: Optional[int] = None, max_fee: Optional[int] = 0) -> 
InvokeFunction: 93 | call_array, calldata = from_call_to_call_array(calls) 94 | 95 | account_copy = copy_contract_state(self.account) 96 | 97 | raw_invocation = account_copy.__execute__(call_array, calldata) 98 | 99 | if nonce is None: 100 | nonce = await raw_invocation.state.state.get_nonce_at(contract_address=account_copy.contract_address) 101 | 102 | transaction_hash = calculate_transaction_hash_common( 103 | tx_hash_prefix=TransactionHashPrefix.INVOKE, 104 | version=TRANSACTION_VERSION, 105 | contract_address=self.account.contract_address, 106 | entry_point_selector=0, 107 | calldata=raw_invocation.calldata, 108 | max_fee=max_fee, 109 | chain_id=StarknetChainId.TESTNET.value, 110 | additional_data=[nonce], 111 | ) 112 | 113 | session_signature = self.stark_key.sign(transaction_hash) 114 | proofs_flat = [item for proof in proofs for item in proof] 115 | signature = [ 116 | self.plugin_class_hash, 117 | *session_signature, # session signature 118 | session.session_public_key, # session_key 119 | session.session_expiration, # expiration 120 | session.root, # root 121 | session.single_proof_len(), # single_proof_len 122 | len(proofs_flat), # proofs_len 123 | *proofs_flat, # proofs 124 | len(session.session_token), # session_token_len 125 | *session.session_token # session_token 126 | ] 127 | 128 | return InvokeFunction( 129 | contract_address=self.account.contract_address, 130 | calldata=raw_invocation.calldata, 131 | entry_point_selector=None, 132 | signature=signature, 133 | max_fee=max_fee, 134 | version=TRANSACTION_VERSION, 135 | nonce=nonce, 136 | ) 137 | 138 | async def send_transaction(self, calls, session: Session, nonce: Optional[int] = None, max_fee: Optional[int] = 0) -> TransactionExecutionInfo: 139 | signed_tx = await self.get_signed_transaction(calls, session, nonce, max_fee) 140 | return await self.send_signed_tx(signed_tx) 141 | 142 | async def send_transaction_with_proofs(self, calls, session: Session, proofs: List[List[int]], nonce: Optional[int] 
= None, max_fee: Optional[int] = 0) -> TransactionExecutionInfo : 143 | signed_tx = await self.get_signed_transaction_with_proofs(calls, session, proofs, nonce, max_fee) 144 | return await self.send_signed_tx(signed_tx) 145 | 146 | async def send_signed_tx(self, signed_tx: InvokeFunction) -> TransactionExecutionInfo: 147 | return await self.account.state.execute_tx( 148 | tx=InternalTransaction.from_external( 149 | external_tx=signed_tx, 150 | general_config=self.account.state.general_config 151 | ) 152 | ) 153 | -------------------------------------------------------------------------------- /tests/utils/utils.py: -------------------------------------------------------------------------------- 1 | from starkware.crypto.signature.signature import private_to_stark_key 2 | from starkware.starknet.services.api.contract_class import ContractClass 3 | from starkware.starknet.testing.contract import StarknetContract 4 | from starkware.starknet.testing.state import StarknetState 5 | from starkware.starkware_utils.error_handling import StarkException 6 | from starkware.starknet.compiler.compile import compile_starknet_files 7 | from starkware.starknet.compiler.compile import get_selector_from_name 8 | from starkware.starknet.business_logic.execution.objects import Event 9 | from typing import Optional, List, Tuple 10 | from starkware.crypto.signature.signature import private_to_stark_key, sign 11 | from starkware.starknet.public.abi import AbiType 12 | from starkware.starknet.definitions.error_codes import StarknetErrorCode 13 | 14 | 15 | ERC165_INTERFACE_ID = 0x01ffc9a7 16 | ERC165_ACCOUNT_INTERFACE_ID = 0x3943f10f 17 | 18 | def str_to_felt(text: str) -> int: 19 | b_text = bytes(text, 'UTF-8') 20 | return int.from_bytes(b_text, "big") 21 | 22 | def compile(path: str) -> ContractClass: 23 | contract_cls = compile_starknet_files([path], debug_info=True) 24 | return contract_cls 25 | 26 | 27 | def cached_contract(state: StarknetState, _class: ContractClass, deployed: 
StarknetContract) -> StarknetContract: 28 | return build_contract( 29 | state=state, 30 | contract=deployed, 31 | custom_abi=_class.abi 32 | ) 33 | 34 | 35 | def copy_contract_state(contract: StarknetContract) -> StarknetContract: 36 | return build_contract(contract=contract, state=contract.state.copy()) 37 | 38 | 39 | def build_contract(contract: StarknetContract, state: StarknetState = None, custom_abi: AbiType = None) -> StarknetContract: 40 | return StarknetContract( 41 | state=contract.state if state is None else state, 42 | abi=contract.abi if custom_abi is None else custom_abi, 43 | contract_address=contract.contract_address, 44 | deploy_call_info=contract.deploy_call_info 45 | ) 46 | 47 | 48 | async def assert_revert(fun, reverted_with: Optional[str] = None): 49 | try: 50 | res = await fun 51 | assert False, "Transaction didn't revert as expected" 52 | except StarkException as err: 53 | _, error = err.args 54 | assert error['code'] == StarknetErrorCode.TRANSACTION_FAILED, f"assert expected: {StarknetErrorCode.TRANSACTION_FAILED}, got error: {error['code']}" 55 | if reverted_with is not None: 56 | errors_found = [s.removeprefix("Error message: ") for s in error['message'].splitlines() if s.startswith("Error message: ")] 57 | assert reverted_with in errors_found, f"assert expected: {reverted_with}, found errors: {errors_found}" 58 | 59 | 60 | def assert_event_emitted(tx_exec_info, from_address, name, data = []): 61 | if not data: 62 | raw_events = [Event(from_address=event.from_address, keys=event.keys, data=[]) for event in tx_exec_info.get_sorted_events()] 63 | else: 64 | raw_events = [Event(from_address=event.from_address, keys=event.keys, data=event.data) for event in tx_exec_info.get_sorted_events()] 65 | 66 | assert Event( 67 | from_address=from_address, 68 | keys=[get_selector_from_name(name)], 69 | data=data, 70 | ) in raw_events 71 | 72 | 73 | class StarkKeyPair: 74 | def __init__(self, private_key: int): 75 | self.private_key = private_key 76 | 
self.public_key = private_to_stark_key(private_key) 77 | 78 | def sign(self, message_hash: int) -> Tuple[int, int]: 79 | return sign(msg_hash=message_hash, priv_key=self.private_key) 80 | 81 | 82 | def from_call_to_call_array(calls): 83 | call_array = [] 84 | calldata = [] 85 | for call in calls: 86 | assert len(call) == 3, "Invalid call parameters" 87 | entry = (call[0], get_selector_from_name(call[1]), len(calldata), len(call[2])) 88 | call_array.append(entry) 89 | calldata.extend(call[2]) 90 | return call_array, calldata 91 | --------------------------------------------------------------------------------