├── rs_zktrie ├── src │ ├── utils.rs │ ├── lib.rs │ ├── db.rs │ ├── hash.rs │ ├── trie.rs │ └── types.rs ├── .gitignore ├── README.md └── Cargo.toml ├── .gitignore ├── docs ├── assets │ ├── arch.png │ ├── deletion.png │ └── insertion.png └── zktrie.md ├── go.mod ├── src ├── constants.rs ├── rs_lib.rs ├── go_lib.rs └── lib.rs ├── c.go ├── trie ├── zk_trie_proof_test.go ├── zk_trie_database.go ├── zk_trie_database_test.go ├── zk_trie_proof.go ├── zk_trie_test.go ├── zk_trie.go ├── zk_trie_node_test.go ├── zk_trie_node.go └── zk_trie_impl_test.go ├── types ├── byte32.go ├── byte32_test.go ├── hash_test.go ├── util_test.go ├── util.go ├── hash.go └── README.md ├── go.sum ├── Cargo.toml ├── README.md ├── lib.go └── LICENSE /rs_zktrie/src/utils.rs: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /_obj 3 | Cargo.lock 4 | *.exe 5 | -------------------------------------------------------------------------------- /rs_zktrie/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /_obj 3 | Cargo.lock 4 | *.exe 5 | -------------------------------------------------------------------------------- /docs/assets/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scroll-tech/zktrie/HEAD/docs/assets/arch.png -------------------------------------------------------------------------------- /docs/assets/deletion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scroll-tech/zktrie/HEAD/docs/assets/deletion.png -------------------------------------------------------------------------------- /docs/assets/insertion.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/scroll-tech/zktrie/HEAD/docs/assets/insertion.png -------------------------------------------------------------------------------- /rs_zktrie/README.md: -------------------------------------------------------------------------------- 1 | # rustrie 2 | This is a reimplementation of zktrie from https://github.com/scroll-tech/zktrie 3 | -------------------------------------------------------------------------------- /rs_zktrie/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod db; 2 | pub mod hash; 3 | pub mod raw; 4 | pub mod trie; 5 | pub mod types; 6 | pub mod utils; 7 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/scroll-tech/zktrie 2 | 3 | go 1.21 4 | 5 | require github.com/stretchr/testify v1.7.0 6 | 7 | require ( 8 | github.com/davecgh/go-spew v1.1.0 // indirect 9 | github.com/pmezard/go-difflib v1.0.0 // indirect 10 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect 11 | ) 12 | -------------------------------------------------------------------------------- /src/constants.rs: -------------------------------------------------------------------------------- 1 | pub const HASHLEN: usize = 32; 2 | pub const FIELDSIZE: usize = 32; 3 | pub const ACCOUNTFIELDS: usize = 5; 4 | pub const ACCOUNTSIZE: usize = FIELDSIZE * ACCOUNTFIELDS; 5 | pub type Hash = [u8; HASHLEN]; 6 | pub type StoreData = [u8; FIELDSIZE]; 7 | pub type AccountData = [[u8; FIELDSIZE]; ACCOUNTFIELDS]; 8 | -------------------------------------------------------------------------------- /c.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | /* 4 | typedef char* (*hashF)(unsigned char*, unsigned char*, unsigned char*, 
unsigned char*); 5 | typedef void (*proveWriteF)(unsigned char*, int, void*); 6 | 7 | hashF hash_scheme = NULL; 8 | 9 | char* bridge_hash(unsigned char* a, unsigned char* b, unsigned char* domain, unsigned char* out){ 10 | return hash_scheme(a, b, domain, out); 11 | } 12 | 13 | void init_hash_scheme(hashF f){ 14 | hash_scheme = f; 15 | } 16 | 17 | void bridge_prove_write(proveWriteF f, unsigned char* key, unsigned char* val, int size, void* param){ 18 | f(val, size, param); 19 | } 20 | 21 | 22 | */ 23 | import "C" 24 | -------------------------------------------------------------------------------- /trie/zk_trie_proof_test.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "testing" 5 | 6 | zkt "github.com/scroll-tech/zktrie/types" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestDecodeSMTProof(t *testing.T) { 11 | magicBytes := ProofMagicBytes() 12 | node, err := DecodeSMTProof(magicBytes) 13 | assert.NoError(t, err) 14 | assert.Nil(t, node) 15 | 16 | k1 := zkt.NewHashFromBytes([]byte{1, 2, 3, 4, 5}) 17 | k2 := zkt.NewHashFromBytes([]byte{6, 7, 8, 9, 0}) 18 | origNode := NewParentNode(NodeTypeBranch_0, k1, k2) 19 | node, err = DecodeSMTProof(origNode.Value()) 20 | assert.NoError(t, err) 21 | assert.Equal(t, origNode.Value(), node.Value()) 22 | } 23 | -------------------------------------------------------------------------------- /rs_zktrie/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "zktrie_rust" 3 | authors = ["xgao@zoyoe.com"] 4 | version.workspace = true 5 | edition.workspace = true 6 | 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | 9 | [lib] 10 | name = "zktrie_rust" 11 | 12 | [dependencies] 13 | strum = "0.24" 14 | strum_macros = "0.24" 15 | lazy_static = "1.4" 16 | num-derive = "0.3" 17 | num-traits = "0.2" 18 | num = "0.4" 19 | log = 
"0.4.17" 20 | hex.workspace = true 21 | 22 | [dev-dependencies] 23 | ff = "0.12" 24 | poseidon = { git = "https://github.com/lanbones/poseidon" } 25 | halo2_proofs = { git = "https://github.com/DelphinusLab/halo2-gpu-specific.git", default-features = true } 26 | -------------------------------------------------------------------------------- /types/byte32.go: -------------------------------------------------------------------------------- 1 | package zktrie 2 | 3 | import ( 4 | "math/big" 5 | ) 6 | 7 | type Byte32 [32]byte 8 | 9 | func (b *Byte32) Hash() (*big.Int, error) { 10 | first16 := new(big.Int).SetBytes(b[0:16]) 11 | last16 := new(big.Int).SetBytes(b[16:32]) 12 | hash, err := hashScheme([]*big.Int{first16, last16}, big.NewInt(HASH_DOMAIN_BYTE32)) 13 | if err != nil { 14 | return nil, err 15 | } 16 | return hash, nil 17 | } 18 | 19 | func (b *Byte32) Bytes() []byte { return b[:] } 20 | 21 | // same action as common.Hash (truncate bytes longer than 32 bytes FROM beginning, 22 | // and padding 0 at the beginning for shorter bytes) 23 | func NewByte32FromBytes(b []byte) *Byte32 { 24 | 25 | byte32 := new(Byte32) 26 | 27 | if len(b) > 32 { 28 | b = b[len(b)-32:] 29 | } 30 | 31 | copy(byte32[32-len(b):], b) 32 | return byte32 33 | } 34 | 35 | // create bytes32 with zeropadding to shorter bytes, or truncate it 36 | func NewByte32FromBytesPaddingZero(b []byte) *Byte32 { 37 | byte32 := new(Byte32) 38 | copy(byte32[:], b) 39 | return byte32 40 | } 41 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 | github.com/pmezard/go-difflib v1.0.0/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 6 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 7 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 8 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 9 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 10 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 11 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 12 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = ["rs_zktrie"] 3 | 4 | [workspace.package] 5 | edition = "2021" 6 | version = "0.3.0" 7 | 8 | [workspace.dependencies] 9 | hex = "0.4" 10 | 11 | [package] 12 | name = "zktrie" 13 | version.workspace = true 14 | edition.workspace = true 15 | 16 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 17 | # links = "zktrie" 18 | 19 | [dependencies] 20 | zktrie_rust = { path = "rs_zktrie"} 21 | 22 | 23 | [build-dependencies] 24 | gobuild = { git = "https://github.com/scroll-tech/gobuild.git" } 25 | 26 | [dev-dependencies] 27 | hex.workspace = true 28 | halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v2022_09_10" } 29 | poseidon-circuit = { git = "https://github.com/scroll-tech/poseidon-circuit.git", branch = "main" } 30 | 31 | [patch."https://github.com/privacy-scaling-explorations/halo2.git"] 32 | halo2_proofs = { git = "https://github.com/scroll-tech/halo2.git", branch = "v1.1" } 33 | 34 | 35 | [features] 36 | rs_zktrie = [] 37 | default 
= ["rs_zktrie"] 38 | -------------------------------------------------------------------------------- /trie/zk_trie_database.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "math/big" 5 | "sync" 6 | ) 7 | 8 | type ZktrieDatabase interface { 9 | UpdatePreimage(preimage []byte, hashField *big.Int) 10 | Put(k, v []byte) error 11 | Get(key []byte) ([]byte, error) 12 | } 13 | 14 | type Database struct { 15 | db map[string][]byte 16 | lock sync.RWMutex 17 | } 18 | 19 | func (db *Database) UpdatePreimage([]byte, *big.Int) {} 20 | 21 | func (db *Database) Put(k, v []byte) error { 22 | db.lock.Lock() 23 | defer db.lock.Unlock() 24 | 25 | db.db[string(k)] = v 26 | return nil 27 | } 28 | 29 | func (db *Database) Get(key []byte) ([]byte, error) { 30 | db.lock.RLock() 31 | defer db.lock.RUnlock() 32 | 33 | if entry, ok := db.db[string(key)]; ok { 34 | return entry, nil 35 | } 36 | return nil, ErrKeyNotFound 37 | 38 | } 39 | 40 | // Init flush db with batches of k/v without locking 41 | func (db *Database) Init(k, v []byte) { 42 | db.db[string(k)] = v 43 | } 44 | 45 | func NewZkTrieMemoryDb() *Database { 46 | return &Database{ 47 | db: make(map[string][]byte), 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /trie/zk_trie_database_test.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestDatabase(t *testing.T) { 12 | db := NewZkTrieMemoryDb() 13 | db.UpdatePreimage(nil, nil) 14 | for i := 0; i < 100; i++ { 15 | key := []byte(fmt.Sprintf("key_%d", i)) 16 | value := []byte(fmt.Sprintf("value_%d", i)) 17 | db.Init(key, value) 18 | } 19 | 20 | var wg sync.WaitGroup 21 | wg.Add(3) 22 | 23 | go func() { 24 | defer wg.Done() 25 | for i := 0; i < 100; i++ { 26 | key := 
[]byte(fmt.Sprintf("key_%d", i)) 27 | value := []byte(fmt.Sprintf("value_%d", i)) 28 | err := db.Put(key, value) 29 | assert.NoError(t, err) 30 | } 31 | }() 32 | 33 | go func() { 34 | defer wg.Done() 35 | for i := 0; i < 100; i++ { 36 | key := []byte(fmt.Sprintf("key_%d", i)) 37 | value := []byte(fmt.Sprintf("value_%d", i)) 38 | gotValue, err := db.Get(key) 39 | assert.NoError(t, err) 40 | assert.Equal(t, value, gotValue) 41 | } 42 | }() 43 | 44 | go func() { 45 | defer wg.Done() 46 | for i := 100; i < 200; i++ { 47 | key := []byte(fmt.Sprintf("key_%d", i)) 48 | value, err := db.Get(key) 49 | assert.Equal(t, ErrKeyNotFound, err) 50 | assert.Nil(t, value) 51 | } 52 | }() 53 | 54 | wg.Wait() 55 | } 56 | -------------------------------------------------------------------------------- /rs_zktrie/src/db.rs: -------------------------------------------------------------------------------- 1 | use crate::raw::ImplError; 2 | use std::collections::HashMap; 3 | pub trait ZktrieDatabase { 4 | fn put(&mut self, k: Vec, v: Vec) -> Result<(), ImplError>; 5 | fn get(&self, k: &[u8]) -> Result<&[u8], ImplError>; 6 | } 7 | 8 | #[derive(Clone, Default)] 9 | pub struct SimpleDb { 10 | db: HashMap, Box<[u8]>>, 11 | } 12 | 13 | impl SimpleDb { 14 | pub fn new() -> Self { 15 | Self::default() 16 | } 17 | 18 | pub fn merge(&mut self, other: Self) { 19 | self.db.extend(other.db); 20 | } 21 | } 22 | 23 | impl ZktrieDatabase for SimpleDb { 24 | fn put(&mut self, k: Vec, v: Vec) -> Result<(), ImplError> { 25 | self.db.insert(k.into_boxed_slice(), v.into_boxed_slice()); 26 | Ok(()) 27 | } 28 | 29 | fn get(&self, k: &[u8]) -> Result<&[u8], ImplError> { 30 | self.db 31 | .get(k) 32 | .map(|v| v.as_ref()) 33 | .ok_or(ImplError::ErrKeyNotFound) 34 | } 35 | } 36 | 37 | #[cfg(test)] 38 | mod test { 39 | use super::{SimpleDb, ZktrieDatabase}; 40 | 41 | #[test] 42 | fn test_db() { 43 | let k1 = [1u8; 32].to_vec(); 44 | let k2 = [3u8; 32].to_vec(); 45 | let v1 = [2u8; 256].to_vec(); 46 | let v2 = [4u8; 
256].to_vec(); 47 | let mut d = SimpleDb::new(); 48 | d.put(k1.clone(), v1.clone()).unwrap(); 49 | d.put(k2.clone(), v2.clone()).unwrap(); 50 | let v0 = d.get(&k1).unwrap(); 51 | assert_eq!(v0.as_ref(), v1.as_slice()); 52 | let v0 = d.get(&k2).unwrap(); 53 | assert_eq!(v0.as_ref(), v2.as_slice()); 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /types/byte32_test.go: -------------------------------------------------------------------------------- 1 | package zktrie 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "math/big" 7 | "os" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func setupENV() { 14 | InitHashScheme(func(arr []*big.Int, domain *big.Int) (*big.Int, error) { 15 | lcEff := big.NewInt(65536) 16 | qString := "21888242871839275222246405745257275088548364400416034343698204186575808495617" 17 | Q, ok := new(big.Int).SetString(qString, 10) 18 | if !ok { 19 | panic(fmt.Sprintf("Bad base 10 string %s", qString)) 20 | } 21 | sum := domain 22 | for _, bi := range arr { 23 | nbi := new(big.Int).Mul(bi, bi) 24 | sum.Mul(sum, sum) 25 | sum.Mul(sum, lcEff) 26 | sum.Add(sum, nbi) 27 | } 28 | return sum.Mod(sum, Q), nil 29 | }) 30 | } 31 | 32 | func TestMain(m *testing.M) { 33 | setupENV() 34 | os.Exit(m.Run()) 35 | } 36 | 37 | func TestNewByte32(t *testing.T) { 38 | var tests = []struct { 39 | input []byte 40 | expected []byte 41 | expectedPaddingZero []byte 42 | expectedHash string 43 | expectedHashPadding string 44 | }{ 45 | {bytes.Repeat([]byte{1}, 4), 46 | []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1}, 47 | []byte{1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 48 | "19342813114117753747472897", 49 | "4198633341355723145865718849633731687852896197776343461751712629107518959468", 50 | }, 51 | {bytes.Repeat([]byte{1}, 34), 52 | []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, 53 | []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, 54 | "19162873132136764367682277409313605623778997630491468285254908822491098844002", 55 | "19162873132136764367682277409313605623778997630491468285254908822491098844002", 56 | }, 57 | } 58 | 59 | for _, tt := range tests { 60 | byte32Result := NewByte32FromBytes(tt.input) 61 | byte32PaddingResult := NewByte32FromBytesPaddingZero(tt.input) 62 | assert.Equal(t, tt.expected, byte32Result.Bytes()) 63 | assert.Equal(t, tt.expectedPaddingZero, byte32PaddingResult.Bytes()) 64 | hashResult, err := byte32Result.Hash() 65 | assert.NoError(t, err) 66 | hashPaddingResult, err := byte32PaddingResult.Hash() 67 | assert.NoError(t, err) 68 | assert.Equal(t, tt.expectedHash, hashResult.String()) 69 | assert.Equal(t, tt.expectedHashPadding, hashPaddingResult.String()) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # zktrie 2 | 3 | zktrie is a binary poseidon trie used in Scroll Network. 4 | 5 | Go and Rust implementations are provided inside this repo. 6 | 7 | ## Design Doc 8 | 9 | See the technical [docs here](docs/zktrie.md). 10 | 11 | ## Example codes 12 | 13 | [Rust example code](https://github.com/scroll-tech/stateless-block-verifier/blob/56b4aaf1d89a297a16a2934f579a116de024d213/src/executor.rs#L103) 14 | [Go example code](https://github.com/scroll-tech/go-ethereum/blob/develop/trie/zk_trie.go) 15 | 16 | ## Rust Usage 17 | 18 | We must init the crate with a poseidon hash scheme before any actions. 
[This](https://github.com/scroll-tech/zkevm-circuits/blob/e5c5522d544ce936290ef53e00c2d17a0e9b8d0b/zktrie/src/state/builder.rs#L17) is an example 19 | 20 | 21 | All zktries can share one underlying database, which can be initialized by putting the encoded trie node data into it directly 22 | 23 | ```rust 24 | 25 | let mut db = ZkMemoryDb::new(); 26 | 27 | /* for some trie node data encoded as bytes `buf` */ 28 | db.add_node_data(&buf).unwrap(); 29 | 30 | /* or if we have both trie node data and key encoded as bytes */ 31 | db.add_node_bytes(&buf, Some(&key)).unwrap(); 32 | 33 | ``` 34 | 35 | We must provide the root of a trie to create it; the corresponding root node must already have been added to the database 36 | 37 | ```rust 38 | let root = hex::decode("079a038fbf78f25a2590e5a1d2fa34ce5e5f30e9a332713b43fa0e51b8770ab8") 39 | .unwrap(); 40 | let root: Hash = root.as_slice().try_into().unwrap(); 41 | 42 | let mut trie = db.new_trie(&root).unwrap(); 43 | ``` 44 | 45 | The trie can be updated with a single 32-byte buffer if it is a storage trie, or with a `[[u8;32];4]` array for the account data `{nonce, balance, codehash, storageRoot}` if it is an account trie 46 | 47 | ```rust 48 | let acc_buf = hex::decode("4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").unwrap(); 49 | let code_hash: [u8;32] = hex::decode("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").unwrap().as_slice().try_into().unwrap(); 50 | 51 | /* update an externally-owned account (so its storageRoot is all zero and code_hash equal to keccak256(nil)) */ 52 | let newacc: AccountData = [nonce, balance, code_hash, [0; FIELDSIZE]]; 53 | trie.update_account(&acc_buf, &newacc).unwrap(); 54 | 55 | ``` 56 | 57 | The root and the MPT path for an address can be queried from the trie with `ZkTrie::root` and `ZkTrie::prove` 58 | 59 | 60 | ## License 61 | 62 | Licensed under either of 63 | 64 | - Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 65 | - MIT License (http://opensource.org/licenses/MIT) 66 | 67 | at your
discretion. 68 | -------------------------------------------------------------------------------- /trie/zk_trie_proof.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | 7 | zkt "github.com/scroll-tech/zktrie/types" 8 | ) 9 | 10 | var magicSMTBytes []byte 11 | 12 | func init() { 13 | magicSMTBytes = []byte("THIS IS SOME MAGIC BYTES FOR SMT m1rRXgP2xpDI") 14 | } 15 | 16 | func ProofMagicBytes() []byte { return magicSMTBytes } 17 | 18 | // DecodeProof try to decode a node bytes, return can be nil for any non-node data (magic code) 19 | func DecodeSMTProof(data []byte) (*Node, error) { 20 | 21 | if bytes.Equal(magicSMTBytes, data) { 22 | //skip magic bytes node 23 | return nil, nil 24 | } 25 | 26 | return NewNodeFromBytes(data) 27 | } 28 | 29 | // Prove constructs a merkle proof for SMT, it respect the protocol used by the ethereum-trie 30 | // but save the node data with a compact form 31 | func (mt *ZkTrieImpl) Prove(kHash *zkt.Hash, fromLevel uint, writeNode func(*Node) error) error { 32 | // force root hash calculation if needed 33 | if _, err := mt.Root(); err != nil { 34 | return err 35 | } 36 | 37 | mt.lock.RLock() 38 | defer mt.lock.RUnlock() 39 | 40 | path := getPath(mt.maxLevels, kHash[:]) 41 | var nodes []*Node 42 | var lastN *Node 43 | tn := mt.rootKey 44 | for i := 0; i < mt.maxLevels; i++ { 45 | n, err := mt.getNode(tn) 46 | if err != nil { 47 | fmt.Println("get node fail", err, tn.Hex(), 48 | lastN.ChildL.Hex(), 49 | lastN.ChildR.Hex(), 50 | path, 51 | i, 52 | ) 53 | return err 54 | } 55 | nodeHash := tn 56 | lastN = n 57 | 58 | finished := true 59 | switch n.Type { 60 | case NodeTypeEmpty_New: 61 | case NodeTypeLeaf_New: 62 | // notice even we found a leaf whose entry didn't match the expected k, 63 | // we still include it as the proof of absence 64 | case NodeTypeBranch_0, NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: 65 | finished = false 66 | if 
path[i] { 67 | tn = n.ChildR 68 | } else { 69 | tn = n.ChildL 70 | } 71 | case NodeTypeEmpty, NodeTypeLeaf, NodeTypeParent: 72 | panic("encounter deprecated node types") 73 | default: 74 | return ErrInvalidNodeFound 75 | } 76 | 77 | nCopy := n.Copy() 78 | nCopy.nodeHash = nodeHash 79 | nodes = append(nodes, nCopy) 80 | if finished { 81 | break 82 | } 83 | } 84 | 85 | for _, n := range nodes { 86 | if fromLevel > 0 { 87 | fromLevel-- 88 | continue 89 | } 90 | 91 | // TODO: notice here we may have broken some implicit on the proofDb: 92 | // the key is not kecca(value) and it even can not be derived from 93 | // the value by any means without a actually decoding 94 | if err := writeNode(n); err != nil { 95 | return err 96 | } 97 | } 98 | 99 | return nil 100 | } 101 | -------------------------------------------------------------------------------- /types/hash_test.go: -------------------------------------------------------------------------------- 1 | package zktrie 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "fmt" 7 | "math/big" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestDummyHash(t *testing.T) { 15 | result, err := dummyHash([]*big.Int{}, nil) 16 | assert.Equal(t, big.NewInt(0), result) 17 | assert.Equal(t, hashNotInitErr, err) 18 | } 19 | 20 | func TestCheckBigIntInField(t *testing.T) { 21 | bi := big.NewInt(0) 22 | assert.True(t, CheckBigIntInField(bi)) 23 | 24 | bi = new(big.Int).Sub(Q, big.NewInt(1)) 25 | assert.True(t, CheckBigIntInField(bi)) 26 | 27 | bi = new(big.Int).Set(Q) 28 | assert.False(t, CheckBigIntInField(bi)) 29 | } 30 | 31 | func TestNewHashAndBigIntFromBytes(t *testing.T) { 32 | b := bytes.Repeat([]byte{1, 2}, 16) 33 | h := NewHashFromBytes(b) 34 | assert.Equal(t, "0102010201020102010201020102010201020102010201020102010201020102", h.Hex()) 35 | assert.Equal(t, "45585349...", h.String()) 36 | 37 | h, err := NewHashFromCheckedBytes(b) 38 | assert.NoError(t, 
err) 39 | assert.Equal(t, "0102010201020102010201020102010201020102010201020102010201020102", h.Hex()) 40 | 41 | bi, err := NewBigIntFromHashBytes(b) 42 | assert.NoError(t, err) 43 | assert.Equal(t, "455853498485199945361735166433836579326217380693297711485161465995904286978", bi.String()) 44 | 45 | h1 := NewHashFromBytes(b) 46 | text, err := h1.MarshalText() 47 | assert.NoError(t, err) 48 | assert.Equal(t, "455853498485199945361735166433836579326217380693297711485161465995904286978", h1.BigInt().String()) 49 | h2 := &Hash{} 50 | err = h2.UnmarshalText(text) 51 | assert.NoError(t, err) 52 | assert.Equal(t, h1, h2) 53 | 54 | short := []byte{1, 2, 3, 4, 5} 55 | _, err = NewHashFromCheckedBytes(short) 56 | assert.Error(t, err) 57 | assert.Equal(t, fmt.Sprintf("expected %d bytes, but got %d bytes", HashByteLen, len(short)), err.Error()) 58 | 59 | short = []byte{1, 2, 3, 4, 5} 60 | _, err = NewBigIntFromHashBytes(short) 61 | assert.Error(t, err) 62 | assert.Equal(t, fmt.Sprintf("expected %d bytes, but got %d bytes", HashByteLen, len(short)), err.Error()) 63 | 64 | outOfField := bytes.Repeat([]byte{255}, 32) 65 | _, err = NewBigIntFromHashBytes(outOfField) 66 | assert.Error(t, err) 67 | assert.Equal(t, "NewBigIntFromHashBytes: Value not inside the Finite Field", err.Error()) 68 | } 69 | 70 | func TestNewHashFromBigIntAndString(t *testing.T) { 71 | bi := big.NewInt(12345) 72 | h := NewHashFromBigInt(bi) 73 | assert.Equal(t, "0000000000000000000000000000000000000000000000000000000000003039", h.Hex()) 74 | assert.Equal(t, "12345", h.String()) 75 | 76 | s := "454086624460063511464984254936031011189294057512315937409637584344757371137" 77 | h, err := NewHashFromString(s) 78 | assert.NoError(t, err) 79 | assert.Equal(t, "0101010101010101010101010101010101010101010101010101010101010101", h.Hex()) 80 | assert.Equal(t, "45408662...", h.String()) 81 | } 82 | 83 | func TestNewHashFromBytes(t *testing.T) { 84 | h := HashZero 85 | read, err := rand.Read(h[:]) 86 | require.NoError(t, 
err) 87 | require.Equal(t, HashByteLen, read) 88 | require.Equal(t, h, *NewHashFromBytes(h.Bytes())) 89 | } 90 | -------------------------------------------------------------------------------- /types/util_test.go: -------------------------------------------------------------------------------- 1 | package zktrie 2 | 3 | import ( 4 | "math/big" 5 | "strconv" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestSetBitBigEndian(t *testing.T) { 12 | bitmap := make([]byte, 8) 13 | 14 | SetBitBigEndian(bitmap, 3) 15 | SetBitBigEndian(bitmap, 15) 16 | SetBitBigEndian(bitmap, 27) 17 | SetBitBigEndian(bitmap, 63) 18 | 19 | expected := []byte{0x80, 0x0, 0x0, 0x0, 0x8, 0x0, 0x80, 0x8} 20 | assert.Equal(t, expected, bitmap) 21 | } 22 | 23 | func TestBitManipulations(t *testing.T) { 24 | bitmap := []byte{0b10101010, 0b01010101} 25 | 26 | bitResults := make([]bool, 16) 27 | for i := uint(0); i < 16; i++ { 28 | bitResults[i] = TestBit(bitmap, i) 29 | } 30 | 31 | expectedBitResults := []bool{ 32 | false, true, false, true, false, true, false, true, 33 | true, false, true, false, true, false, true, false, 34 | } 35 | assert.Equal(t, expectedBitResults, bitResults) 36 | 37 | bitResultsBigEndian := make([]bool, 16) 38 | for i := uint(0); i < 16; i++ { 39 | bitResultsBigEndian[i] = TestBitBigEndian(bitmap, i) 40 | } 41 | 42 | expectedBitResultsBigEndian := []bool{ 43 | true, false, true, false, true, false, true, false, 44 | false, true, false, true, false, true, false, true, 45 | } 46 | assert.Equal(t, expectedBitResultsBigEndian, bitResultsBigEndian) 47 | } 48 | 49 | func TestBigEndianBitsToBigInt(t *testing.T) { 50 | bits := []bool{true, false, true, false, true, false, true, false} 51 | result := BigEndianBitsToBigInt(bits) 52 | expected := big.NewInt(170) 53 | assert.Equal(t, expected, result) 54 | } 55 | 56 | func TestToSecureKey(t *testing.T) { 57 | secureKey, err := ToSecureKey([]byte("testKey")) 58 | assert.NoError(t, err) 59 | assert.Equal(t, 
"3998087801436302712617435196225481036627874106324392591598072448097460358227", secureKey.String()) 60 | } 61 | 62 | func TestToSecureKeyBytes(t *testing.T) { 63 | secureKeyBytes, err := ToSecureKeyBytes([]byte("testKey")) 64 | assert.NoError(t, err) 65 | assert.Equal(t, []byte{0x8, 0xd6, 0xd6, 0x66, 0xa4, 0x8, 0xc5, 0x72, 0xa0, 0xc3, 0x71, 0x50, 0x89, 0xa0, 0x2b, 0xe7, 0x59, 0x97, 0x39, 0x5d, 0x2c, 0x37, 0x38, 0x5d, 0x67, 0x22, 0x84, 0xe5, 0xc8, 0xbf, 0xc, 0x53}, secureKeyBytes.Bytes()) 66 | } 67 | 68 | func TestReverseByteOrder(t *testing.T) { 69 | assert.Equal(t, []byte{5, 4, 3, 2, 1}, ReverseByteOrder([]byte{1, 2, 3, 4, 5})) 70 | } 71 | 72 | func TestHashElems(t *testing.T) { 73 | fst := big.NewInt(5) 74 | snd := big.NewInt(3) 75 | elems := make([]*big.Int, 32) 76 | for i := range elems { 77 | elems[i] = big.NewInt(int64(i + 1)) 78 | } 79 | 80 | result, err := HashElems(fst, snd, elems...) 81 | assert.NoError(t, err) 82 | assert.Equal(t, "1613b67f0a90f864bafa14df215f89e0c5a1c128e54561f0d730d112678e981d", result.Hex()) 83 | } 84 | 85 | func TestPreHandlingElems(t *testing.T) { 86 | flagArray := uint32(0b10101010101010101010101010101010) 87 | elems := make([]Byte32, 32) 88 | for i := range elems { 89 | elems[i] = *NewByte32FromBytes([]byte("test" + strconv.Itoa(i+1))) 90 | } 91 | 92 | result, err := HandlingElemsAndByte32(flagArray, elems) 93 | assert.NoError(t, err) 94 | assert.Equal(t, "259503a5495e5e7e83d7e8e3f22b214092f921b7cadba00526aea7485c1997e7", result.Hex()) 95 | 96 | elems = elems[:1] 97 | result, err = HandlingElemsAndByte32(flagArray, elems) 98 | assert.NoError(t, err) 99 | assert.Equal(t, "0000000000000000000000000000000000000000000000000000007465737431", result.Hex()) 100 | } 101 | -------------------------------------------------------------------------------- /trie/zk_trie_test.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "math/big" 5 | "os" 6 | "testing" 7 | 8 | zkt 
"github.com/scroll-tech/zktrie/types" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func setupENV() { 13 | zkt.InitHashScheme(func(arr []*big.Int, domain *big.Int) (*big.Int, error) { 14 | lcEff := big.NewInt(65536) 15 | sum := domain 16 | for _, bi := range arr { 17 | nbi := new(big.Int).Mul(bi, bi) 18 | sum = sum.Mul(sum, sum) 19 | sum = sum.Mul(sum, lcEff) 20 | sum = sum.Add(sum, nbi) 21 | } 22 | return sum.Mod(sum, zkt.Q), nil 23 | }) 24 | } 25 | 26 | func TestMain(m *testing.M) { 27 | setupENV() 28 | os.Exit(m.Run()) 29 | } 30 | 31 | func TestNewZkTrie(t *testing.T) { 32 | root := zkt.Byte32{} 33 | db := NewZkTrieMemoryDb() 34 | zkTrie, err := NewZkTrie(root, db) 35 | assert.NoError(t, err) 36 | assert.Equal(t, zkt.HashZero.Bytes(), zkTrie.Hash()) 37 | assert.Equal(t, zkt.HashZero.Bytes(), zkTrie.Tree().rootKey.Bytes()) 38 | 39 | root = zkt.Byte32{1} 40 | zkTrie, err = NewZkTrie(root, db) 41 | assert.Equal(t, ErrKeyNotFound, err) 42 | assert.Nil(t, zkTrie) 43 | } 44 | 45 | func TestZkTrie_GetUpdateDelete(t *testing.T) { 46 | root := zkt.Byte32{} 47 | db := NewZkTrieMemoryDb() 48 | zkTrie, err := NewZkTrie(root, db) 49 | assert.NoError(t, err) 50 | 51 | val, err := zkTrie.TryGet([]byte("key")) 52 | assert.NoError(t, err) 53 | assert.Nil(t, val) 54 | assert.Equal(t, zkt.HashZero.Bytes(), zkTrie.Hash()) 55 | 56 | err = zkTrie.TryUpdate([]byte("key"), 1, []zkt.Byte32{{1}}) 57 | assert.NoError(t, err) 58 | assert.Equal(t, []byte{0x23, 0x36, 0x5e, 0xbd, 0x71, 0xa7, 0xad, 0x35, 0x65, 0xdd, 0x24, 0x88, 0x47, 0xca, 0xe8, 0xe8, 0x8, 0x21, 0x15, 0x62, 0xc6, 0x83, 0xdb, 0x8, 0x4f, 0x5a, 0xfb, 0xd1, 0xb0, 0x3d, 0x4c, 0xb5}, zkTrie.Hash()) 59 | 60 | val, err = zkTrie.TryGet([]byte("key")) 61 | assert.NoError(t, err) 62 | assert.Equal(t, (&zkt.Byte32{1}).Bytes(), val) 63 | 64 | err = zkTrie.TryDelete([]byte("key")) 65 | assert.NoError(t, err) 66 | assert.Equal(t, zkt.HashZero.Bytes(), zkTrie.Hash()) 67 | 68 | val, err = zkTrie.TryGet([]byte("key")) 69 | 
assert.NoError(t, err) 70 | assert.Nil(t, val) 71 | } 72 | 73 | func TestZkTrie_Copy(t *testing.T) { 74 | root := zkt.Byte32{} 75 | db := NewZkTrieMemoryDb() 76 | zkTrie, err := NewZkTrie(root, db) 77 | assert.NoError(t, err) 78 | 79 | zkTrie.TryUpdate([]byte("key"), 1, []zkt.Byte32{{1}}) 80 | 81 | copyTrie := zkTrie.Copy() 82 | val, err := copyTrie.TryGet([]byte("key")) 83 | assert.NoError(t, err) 84 | assert.Equal(t, (&zkt.Byte32{1}).Bytes(), val) 85 | } 86 | 87 | func TestZkTrie_ProveAndProveWithDeletion(t *testing.T) { 88 | root := zkt.Byte32{} 89 | db := NewZkTrieMemoryDb() 90 | zkTrie, err := NewZkTrie(root, db) 91 | assert.NoError(t, err) 92 | 93 | keys := []string{"key1", "key2", "key3", "key4", "key5"} 94 | for i, keyStr := range keys { 95 | key := make([]byte, 32) 96 | copy(key, []byte(keyStr)) 97 | 98 | err := zkTrie.TryUpdate(key, uint32(i+1), []zkt.Byte32{{byte(uint32(i + 1))}}) 99 | assert.NoError(t, err) 100 | 101 | writeNode := func(n *Node) error { 102 | return nil 103 | } 104 | 105 | onHit := func(n *Node, sib *Node) {} 106 | 107 | k, err := zkt.ToSecureKey(key) 108 | assert.NoError(t, err) 109 | 110 | for j := 0; j <= i; j++ { 111 | err = zkTrie.ProveWithDeletion(zkt.NewHashFromBigInt(k).Bytes(), uint(j), writeNode, onHit) 112 | assert.NoError(t, err) 113 | 114 | err = zkTrie.Prove(zkt.NewHashFromBigInt(k).Bytes(), uint(j), writeNode) 115 | assert.NoError(t, err) 116 | } 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /types/util.go: -------------------------------------------------------------------------------- 1 | package zktrie 2 | 3 | import ( 4 | "math/big" 5 | ) 6 | 7 | // HashElemsWithDomain performs a recursive poseidon hash over the array of ElemBytes, each hash 8 | // reduce 2 fieds into one, with a specified domain field which would be used in 9 | // every recursiving call 10 | func HashElemsWithDomain(domain, fst, snd *big.Int, elems ...*big.Int) (*Hash, error) { 11 | 12 | l := 
len(elems) 13 | baseH, err := hashScheme([]*big.Int{fst, snd}, domain) 14 | if err != nil { 15 | return nil, err 16 | } 17 | if l == 0 { 18 | return NewHashFromBigInt(baseH), nil 19 | } else if l == 1 { 20 | return HashElemsWithDomain(domain, baseH, elems[0]) 21 | } 22 | 23 | tmp := make([]*big.Int, (l+1)/2) 24 | for i := range tmp { 25 | if (i+1)*2 > l { 26 | tmp[i] = elems[i*2] 27 | } else { 28 | h, err := hashScheme(elems[i*2:(i+1)*2], domain) 29 | if err != nil { 30 | return nil, err 31 | } 32 | tmp[i] = h 33 | } 34 | } 35 | 36 | return HashElemsWithDomain(domain, baseH, tmp[0], tmp[1:]...) 37 | } 38 | 39 | // HashElems call HashElemsWithDomain with a domain of HASH_DOMAIN_ELEMS_BASE(256)* 40 | func HashElems(fst, snd *big.Int, elems ...*big.Int) (*Hash, error) { 41 | 42 | return HashElemsWithDomain(big.NewInt(int64(len(elems)*HASH_DOMAIN_ELEMS_BASE)+HASH_DOMAIN_BYTE32), 43 | fst, snd, elems...) 44 | } 45 | 46 | // HandlingElemsAndByte32 hash an arry mixed with field and byte32 elements, turn each byte32 into 47 | // field elements first then calculate the hash with HashElems 48 | func HandlingElemsAndByte32(flagArray uint32, elems []Byte32) (*Hash, error) { 49 | 50 | ret := make([]*big.Int, len(elems)) 51 | var err error 52 | 53 | for i, elem := range elems { 54 | if flagArray&(1< * value field, key preimage ; 25 | 26 | node key = field ; 27 | 28 | compress flag = 3 * byte ; 29 | 30 | value len = byte ; 31 | 32 | value field = field | compressed field, compressed field ; 33 | 34 | compressed field = 16 * hex char ; 35 | 36 | key preimage = '0x0' | preimage bytes ; 37 | 38 | preimage bytes = len, * byte ; 39 | 40 | len = byte ; 41 | ``` 42 | 43 | A `field` is an element in prime field of BN256 represented by **big endian** integer and contained in fixed length (32) bytes; 44 | 45 | A `compressed field` is a field represented by **big endian** integer which could be contained in 16 bytes; 46 | 47 | For the total `value len` items of `value field` (maximum 255), 
the first 24 `value field`s can be recorded as `field` or 2x `compressed field` (i.e. a byte32). The corresponding bit in `compress flag` is set to 1 if it was recorded as byte32, or 0 for a field. 48 | 49 | ## Key scheme 50 | 51 | The key of a data node is obtained from one or more poseidon hash calculations: `poseidon := (field, field) => field`. 52 | 53 | For a parent node: 54 | 55 | ``` 56 | key = poseidon(<left child key>, <right child key>) 57 | ``` 58 | 59 | For a leaf node: 60 | 61 | ``` 62 | key = poseidon(<pre key>, <value hash>)
 63 | 
 64 | pre key = poseidon(field(1), <node key>)
 65 | 
 66 | value hash = poseidon(<leaf element>, <leaf element>) | poseidon(<value hash>, <value hash>)
 67 | 
 68 | leaf element = <value field> | poseidon(<compressed field>, <compressed field>) | field(0)
 69 | 
 70 | ```
 71 | 
 72 | That is, to calculate the key of a leaf node:
 73 | 
 74 | 1. In the sequence of `value field`s, take each one that is recorded as compressed (byte32), compute the Poseidon hash of its 2x `compressed field`, and replace that `value field` in place with the hash;
 75 | 
 76 | 2. Treat the sequence from step 1 as the leaves of a binary Merkle tree (appending a 0 field when the number of leaves is odd) and calculate its root with the Poseidon hash;
 77 | 
 78 | For empty node:
 79 | 
 80 | ```
 81 | key = field(0)
 82 | ```
 83 | 
 84 | ## Account data
 85 | 
 86 | Each account's data is saved in one leaf node of the account zktrie as 4 `value field`s:
 87 | 
 88 | 1. Nonce as `field`
 89 | 2. Balance as `field`
 90 | 3. CodeHash as `compressed field` (byte32)
 91 | 4. Storage root as `field`
 92 | 
 93 | The key for an account's data is calculated from the 20-byte account address as follows:
 94 | 
 95 | ```
 96 | 
 97 | 32-byte-zero-end-padding-addr := address, 12 * bytes (0)
 98 | 
 99 | key = poseidon(<first 16 bytes of 32-byte-zero-end-padding-addr>, <last 16 bytes of 32-byte-zero-end-padding-addr>)
100 | 
101 | ```
102 | 
103 | ## Data examples
104 | 
105 | ### A leaf node in account trie:
106 | 
107 | > 0x017f9d3bbc51d12566ecc6049ca6bf76e32828c22b197405f63a833b566fe7da0a040400000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000029b74e075daad9f17eb39cd893c2dd32f52ecd99084d63964842defd00ebcbe208a2f471d50e56ac5000ab9e82f871e36b5a636b19bd02f70aa666a3bd03142f00
108 | 
109 | It can be decomposed into:
110 | 
111 | + `0x01`: node type prefix for leaf node
112 | + `7f9d3bbc51d12566ecc6049ca6bf76e32828c22b197405f63a833b566fe7da0a`: node key as field
113 | + `04`: value len (4 value fields)
114 | + `040000`: compress flag, a 24 bit array, indicating the third field is compressed
115 | + `0000000000000000000000000000000000000000000000000000000000000001`: value field 0 (nonce)
116 | + `0000000000000000000000000000000000000000000000000000000000000000`: value field 1 (balance)
117 | + `29b74e075daad9f17eb39cd893c2dd32f52ecd99084d63964842defd00ebcbe2`: value field 2 (codeHash, as byte32)
118 | + `08a2f471d50e56ac5000ab9e82f871e36b5a636b19bd02f70aa666a3bd03142f`: value field 3 (storage root)
119 | + `00`: key preimage is not available
120 | 
121 | The key calculation for this node is:
122 | 
123 | ```
124 | 
125 | arr = [<value field 0>, <value field 1>, <value field 2>, <value field 3>]
126 | 
127 | hash_pre = poseidon(<value field 2 high 16 bytes>, <value field 2 low 16 bytes>)
128 | 
129 | arr[2] = hash_pre
130 | 
131 | layer1 = [poseidon(arr[0], arr[1]), poseidon(arr[2], arr[3])]
132 | 
133 | key = poseidon(layer1[0], layer1[1])
134 | 
135 | ```
136 | 
137 | Notice all fields and compressed fields are represented as **big endian** integer.
138 | 
139 | ### A parent node in account trie:
140 | 
141 | > 0x00000000000000000000000000000000000000000000000000000000000000000004470b58d80eeb26da85b2c2db5c254900656fb459c07729f556ff02534ab32a
142 | 
143 | Notice the left child of this node is an empty node (so its key is field(0))
144 | 


--------------------------------------------------------------------------------
/rs_zktrie/src/hash.rs:
--------------------------------------------------------------------------------
  1 | use crate::{raw::ImplError, types::Hashable};
  2 | use std::fmt::Debug;
  3 | 
  4 | pub trait Hash: AsRef<[u8]> + AsMut<[u8]> + Default + Clone + Debug + PartialEq {
  5 |     const LEN: usize;
  6 | 
  7 |     fn is_valid(&self) -> bool {
  8 |         true
  9 |     }
 10 |     fn zero() -> Self {
 11 |         Default::default()
 12 |     }
 13 |     fn simple_hash_scheme(a: [u8; 32], b: [u8; 32], domain: u64) -> Self;
 14 | }
 15 | 
 16 | #[derive(Clone, Debug, Default, PartialEq)]
 17 | pub struct AsHash(T);
 18 | 
 19 | impl AsHash {
 20 |     pub fn take(self) -> T {
 21 |         self.0
 22 |     }
 23 | }
 24 | 
 25 | impl AsRef<[u8]> for AsHash {
 26 |     fn as_ref(&self) -> &[u8] {
 27 |         self.0.as_ref()
 28 |     }
 29 | }
 30 | 
 31 | impl AsMut<[u8]> for AsHash {
 32 |     fn as_mut(&mut self) -> &mut [u8] {
 33 |         self.0.as_mut()
 34 |     }
 35 | }
 36 | 
 37 | impl Hashable for AsHash {
 38 |     fn check_in_field(hash: &Self) -> bool {
 39 |         hash.0.is_valid()
 40 |     }
 41 | 
 42 |     fn test_bit(key: &Self, pos: usize) -> bool {
 43 |         return key.as_ref()[T::LEN - pos / 8 - 1] & (1 << (pos % 8)) != 0;
 44 |     }
 45 | 
 46 |     fn to_bytes(&self) -> Vec {
 47 |         self.as_ref()[0..T::LEN].to_vec()
 48 |     }
 49 | 
 50 |     fn hash_zero() -> Self {
 51 |         Self(T::zero())
 52 |     }
 53 | 
 54 |     fn from_bytes(bytes: &[u8]) -> Result {
 55 |         if bytes.len() > T::LEN {
 56 |             Err(ImplError::ErrNodeBytesBadSize)
 57 |         } else {
 58 |             let padding = T::LEN - bytes.len();
 59 |             let mut h = Self::hash_zero();
 60 |             h.as_mut()[padding..].copy_from_slice(bytes);
 61 |             if Self::check_in_field(&h) {
 62 |                 Ok(h)
 63 |             } else {
 64 |                 Err(ImplError::ErrNodeBytesBadSize)
 65 |             }
 66 |         }
 67 |     }
 68 | 
 69 |     fn hash_elems_with_domain(
 70 |         domain: u64,
 71 |         lbytes: &Self,
 72 |         rbytes: &Self,
 73 |     ) -> Result {
 74 |         let h = Self(T::simple_hash_scheme(
 75 |             lbytes.as_ref().try_into().expect("same length"),
 76 |             rbytes.as_ref().try_into().expect("same length"),
 77 |             domain,
 78 |         ));
 79 |         if Self::check_in_field(&h) {
 80 |             Ok(h)
 81 |         } else {
 82 |             Err(ImplError::ErrNodeBytesBadSize)
 83 |         }
 84 |     }
 85 | }
 86 | 
// Test-only re-export of the concrete `HashImpl` defined in `tests` below,
// making it reachable as `crate::hash::HashImpl` in `cfg(test)` builds.
#[cfg(test)]
pub use tests::HashImpl;
 89 | 
 90 | #[cfg(test)]
 91 | mod tests {
 92 |     use crate::types::{Hashable, Node, TrieHashScheme};
 93 | 
 94 |     use ff::PrimeField;
 95 |     use halo2_proofs::pairing::bn256::Fr;
 96 |     use poseidon::Poseidon;
 97 | 
 98 |     lazy_static::lazy_static! {
 99 |         pub static ref POSEIDON_HASHER: poseidon::Poseidon = Poseidon::::new(8, 63);
100 |     }
101 | 
102 |     const HASH_BYTE_LEN: usize = 32;
103 | 
104 |     #[derive(Clone, Debug, Default, PartialEq)]
105 |     pub struct Hash(pub(crate) [u8; HASH_BYTE_LEN]);
106 | 
107 |     impl AsRef<[u8]> for Hash {
108 |         fn as_ref(&self) -> &[u8] {
109 |             &self.0
110 |         }
111 |     }
112 | 
113 |     impl AsMut<[u8]> for Hash {
114 |         fn as_mut(&mut self) -> &mut [u8] {
115 |             &mut self.0
116 |         }
117 |     }
118 | 
119 |     impl super::Hash for Hash {
120 |         const LEN: usize = HASH_BYTE_LEN;
121 | 
122 |         //todo replace with poseidon hash
123 |         fn simple_hash_scheme(a: [u8; 32], b: [u8; 32], domain: u64) -> Self {
124 |             let mut hasher = POSEIDON_HASHER.clone();
125 |             hasher.update(&[
126 |                 Fr::from_repr(a).unwrap(),
127 |                 Fr::from_repr(b).unwrap(),
128 |                 Fr::from(domain),
129 |             ]);
130 |             Hash(hasher.squeeze().to_repr())
131 |         }
132 | 
133 |         fn is_valid(&self) -> bool {
134 |             Fr::from_repr(self.0).is_some().into()
135 |         }
136 |         fn zero() -> Self {
137 |             Self([0; HASH_BYTE_LEN])
138 |         }
139 |     }
140 | 
141 |     pub type HashImpl = super::AsHash;
142 | 
143 |     #[test]
144 |     fn test_hash_byte() {
145 |         let mut byte = vec![];
146 |         let mut h = HashImpl::hash_zero();
147 |         for i in 0..HASH_BYTE_LEN {
148 |             byte.push(i as u8);
149 |             h.as_mut()[i] = i as u8;
150 |         }
151 |         assert_eq!(h.to_bytes(), byte);
152 |         assert_eq!(HashImpl::from_bytes(&byte).unwrap(), h);
153 |     }
154 | 
155 |     #[test]
156 |     fn test_hash_domain() {
157 |         let domain: u64 = 16;
158 |         let mut bytes = vec![];
159 |         for i in 0..16 {
160 |             bytes.push([i as u8; 32]);
161 |         }
162 |         for i in 0..8 {
163 |             let ret = HashImpl::hash_elems_with_domain(
164 |                 domain,
165 |                 &HashImpl::from_bytes(&bytes[2 * i]).unwrap(),
166 |                 &HashImpl::from_bytes(&bytes[2 * i + 1]).unwrap(),
167 |             );
168 |             assert!(ret.is_ok());
169 |         }
170 |         let ret = Node::::handling_elems_and_bytes32(65535, &bytes);
171 |         assert!(ret.is_ok());
172 |     }
173 | 
174 |     #[test]
175 |     fn test_hash_scheme() {
176 |         //fill poseidon hash result when move to zk
177 |         //todo!();
178 |     }
179 | }
180 | 


--------------------------------------------------------------------------------
/rs_zktrie/src/trie.rs:
--------------------------------------------------------------------------------
  1 | use crate::db::ZktrieDatabase;
  2 | use crate::raw::{ImplError, ZkTrieImpl};
  3 | use crate::types::{Hashable, Node, NodeType, TrieHashScheme};
  4 | 
  5 | pub trait KeyCache {
  6 |     fn get_key(&self, k: &[u8]) -> Option<&H>;
  7 | }
  8 | 
  9 | // ZkTrie wraps a trie with key hashing. In a secure trie, all
 10 | // access operations hash the key using keccak256. This prevents
 11 | // calling code from creating long chains of nodes that
 12 | // increase the access time.
 13 | //
 14 | // Contrary to a regular trie, a ZkTrie can only be created with
 15 | // New and must have an attached database. The database also stores
 16 | // the preimage of each key.
 17 | //
 18 | 
 19 | const MAX_LEVELS: usize = (NODE_KEY_VALID_BYTES * 8) as usize;
 20 | 
 21 | pub struct ZkTrie> {
 22 |     tree: ZkTrieImpl,
 23 | }
 24 | 
 25 | // NODE_KEY_VALID_BYTES is the number of least significant bytes in the node key
 26 | // that are considered valid to addressing the leaf node, and thus limits the
 27 | // maximum trie depth to NODE_KEY_VALID_BYTES * 8.
 28 | // We need to truncate the node key because the key is the output of Poseidon
 29 | // hash and the key space doesn't fully occupy the range of power of two. It can
 30 | // lead to an ambiguous bit representation of the key in the finite field
 31 | // causing a soundness issue in the zk circuit.
 32 | const NODE_KEY_VALID_BYTES: u32 = 31;
 33 | 
 34 | impl> ZkTrie {
 35 |     pub const MAX_LEVELS: usize = MAX_LEVELS;
 36 | 
 37 |     // NewSecure creates a trie
 38 |     // SecureBinaryTrie bypasses all the buffer mechanism in *Database, it directly uses the
 39 |     // underlying diskdb
 40 |     pub fn new_zktrie(root: H, db: DB) -> Result {
 41 |         let tr = ZkTrieImpl::new_zktrie_impl_with_root(db, root);
 42 |         let t = ZkTrie { tree: tr? };
 43 |         Ok(t)
 44 |     }
 45 | 
 46 |     // TryGet returns the value for key stored in the trie.
 47 |     // The value bytes must not be modified by the caller.
 48 |     // If a node was not found in the database, a MissingNodeError is returned.
 49 |     pub fn try_get(&self, key: &[u8]) -> Vec {
 50 |         let node = if let Some(k) = self.tree.get_db().get_key(key) {
 51 |             self.tree.try_get(k)
 52 |         } else {
 53 |             let k = Node::::hash_bytes(key).unwrap();
 54 |             self.tree.try_get(&k)
 55 |         };
 56 |         node.ok().and_then(|n| n.data()).unwrap_or_default()
 57 |     }
 58 | 
 59 |     // Tree exposed underlying ZkTrieImpl
 60 |     pub fn tree(self) -> ZkTrieImpl {
 61 |         self.tree
 62 |     }
 63 | 
 64 |     // Commit flushes the trie to database
 65 |     pub fn commit(&mut self) -> Result<(), ImplError> {
 66 |         self.tree.commit()
 67 |     }
 68 | 
 69 |     pub fn is_trie_dirty(&self) -> bool {
 70 |         self.tree.is_trie_dirty()
 71 |     }
 72 | 
 73 |     // TryUpdate associates key with value in the trie. Subsequent calls to
 74 |     // Get will return value. If value has length zero, any existing value
 75 |     // is deleted from the trie and calls to Get will return nil.
 76 |     //
 77 |     // The value bytes must not be modified by the caller while they are
 78 |     // stored in the trie.
 79 |     //
 80 |     // If a node was not found in the database, a MissingNodeError is returned.
 81 |     //
 82 |     // NOTE: value is restricted to length of bytes32.
 83 |     pub fn try_update(
 84 |         &mut self,
 85 |         key: &[u8],
 86 |         v_flag: u32,
 87 |         v_preimage: Vec<[u8; 32]>,
 88 |     ) -> Result<(), ImplError> {
 89 |         let k = if let Some(k) = self.tree.get_db().get_key(key) {
 90 |             k.clone()
 91 |         } else {
 92 |             Node::::hash_bytes(key).unwrap()
 93 |         };
 94 |         self.tree.try_update(&k, v_flag, v_preimage)
 95 |     }
 96 | 
 97 |     // TryDelete removes any existing value for key from the trie.
 98 |     // If a node was not found in the database, a MissingNodeError is returned.
 99 |     pub fn try_delete(&mut self, key: &[u8]) -> Result<(), ImplError> {
100 |         let k = if let Some(k) = self.tree.get_db().get_key(key) {
101 |             k.clone()
102 |         } else {
103 |             Node::::hash_bytes(key).unwrap()
104 |         };
105 |         self.tree.try_delete(&k)
106 |     }
107 | 
108 |     // Hash returns the root hash of SecureBinaryTrie. It does not write to the
109 |     // database and can be used even if the trie doesn't have one.
110 |     pub fn hash(&self) -> Vec {
111 |         self.tree.root().to_bytes()
112 |     }
113 | 
114 |     pub fn prepare_root(&mut self) -> Result<(), ImplError> {
115 |         self.tree.prepare_root()?;
116 |         Ok(())
117 |     }
118 | 
119 |     // Prove constructs a merkle proof for key. The result contains all encoded nodes
120 |     // on the path to the value at key. The value itself is also included in the last
121 |     // node and can be retrieved by verifying the proof.
122 |     //
123 |     // If the trie does not contain a value for key, the returned proof contains all
124 |     // nodes of the longest existing prefix of the key (at least the root node), ending
125 |     // with the node that proves the absence of the key. and the `bool` in returned
126 |     // tuple is false
127 |     //
128 |     // If the trie contain a non-empty leaf for key, the `bool` in returned tuple is true
129 |     pub fn prove(&self, key_hash_byte: &[u8]) -> Result<(Vec>, bool), ImplError> {
130 |         let key_hash = H::from_bytes(key_hash_byte)?;
131 |         let proof = self.tree.prove(&key_hash)?;
132 |         let mut hit = false;
133 | 
134 |         for n in &proof {
135 |             if n.node_type == NodeType::NodeTypeLeafNew && n.node_key == key_hash {
136 |                 hit = true
137 |             }
138 |         }
139 | 
140 |         Ok((proof, hit))
141 |     }
142 | }
143 | 


--------------------------------------------------------------------------------
/trie/zk_trie.go:
--------------------------------------------------------------------------------
  1 | // Copyright 2015 The go-ethereum Authors
  2 | // This file is part of the go-ethereum library.
  3 | //
  4 | // The go-ethereum library is free software: you can redistribute it and/or modify
  5 | // it under the terms of the GNU Lesser General Public License as published by
  6 | // the Free Software Foundation, either version 3 of the License, or
  7 | // (at your option) any later version.
  8 | //
  9 | // The go-ethereum library is distributed in the hope that it will be useful,
 10 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
 11 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 12 | // GNU Lesser General Public License for more details.
 13 | //
 14 | // You should have received a copy of the GNU Lesser General Public License
 15 | // along with the go-ethereum library. If not, see .
 16 | 
 17 | package trie
 18 | 
 19 | import (
 20 | 	"bytes"
 21 | 	"math/big"
 22 | 
 23 | 	zkt "github.com/scroll-tech/zktrie/types"
 24 | )
 25 | 
// ZkTrie wraps a ZkTrieImpl with key hashing. Every access operation first maps
// the caller's key to a secure key with zkt.ToSecureKey, which is Poseidon-based
// rather than keccak256. This prevents calling code from creating long chains of
// nodes that increase the access time.
//
// Contrary to a regular trie, a ZkTrie can only be created with NewZkTrie and
// must have an attached database. TryUpdate also records the preimage of each
// key in that database (via UpdatePreimage).
//
// ZkTrie is not safe for concurrent use.
type ZkTrie struct {
	tree *ZkTrieImpl
}

// NodeKeyValidBytes is the number of least significant bytes in the node key
// that are used to address the leaf node. It thus limits the maximum trie depth
// to NodeKeyValidBytes * 8.
// We need to truncate the node key because the key is the output of the Poseidon
// hash, and the key space doesn't fully occupy a power-of-two range. An
// untruncated key can have an ambiguous bit representation in the finite field,
// causing a soundness issue in the zk circuit.
const NodeKeyValidBytes = 31
 48 | 
 49 | // NewSecure creates a trie
 50 | // SecureBinaryTrie bypasses all the buffer mechanism in *Database, it directly uses the
 51 | // underlying diskdb
 52 | func NewZkTrie(root zkt.Byte32, db ZktrieDatabase) (*ZkTrie, error) {
 53 | 	maxLevels := NodeKeyValidBytes * 8
 54 | 	tree, err := NewZkTrieImplWithRoot((db), zkt.NewHashFromBytes(root.Bytes()), maxLevels)
 55 | 	if err != nil {
 56 | 		return nil, err
 57 | 	}
 58 | 	return &ZkTrie{
 59 | 		tree: tree,
 60 | 	}, nil
 61 | }
 62 | 
 63 | // TryGet returns the value for key stored in the trie.
 64 | // The value bytes must not be modified by the caller.
 65 | // If a node was not found in the database, a MissingNodeError is returned.
 66 | func (t *ZkTrie) TryGet(key []byte) ([]byte, error) {
 67 | 	k, err := zkt.ToSecureKey(key)
 68 | 	if err != nil {
 69 | 		return nil, err
 70 | 	}
 71 | 
 72 | 	return t.tree.TryGet(zkt.NewHashFromBigInt(k))
 73 | }
 74 | 
// Tree exposes the underlying ZkTrieImpl.
func (t *ZkTrie) Tree() *ZkTrieImpl {
	return t.tree
}

// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
// possible to use keybyte-encoding as the path might contain odd nibbles.
//
// Not implemented for ZkTrie: calling it always panics.
func (t *ZkTrie) TryGetNode(path []byte) ([]byte, int, error) {
	panic("unimplemented")
}

// updatePreimage records preimage as the preimage of hashField in the
// underlying tree's database.
func (t *ZkTrie) updatePreimage(preimage []byte, hashField *big.Int) {
	t.tree.db.UpdatePreimage(preimage, hashField)
}
 89 | 
 90 | // TryUpdate associates key with value in the trie. Subsequent calls to
 91 | // Get will return value. If value has length zero, any existing value
 92 | // is deleted from the trie and calls to Get will return nil.
 93 | //
 94 | // The value bytes must not be modified by the caller while they are
 95 | // stored in the trie.
 96 | //
 97 | // If a node was not found in the database, a MissingNodeError is returned.
 98 | //
 99 | // NOTE: value is restricted to length of bytes32.
100 | func (t *ZkTrie) TryUpdate(key []byte, vFlag uint32, vPreimage []zkt.Byte32) error {
101 | 	k, err := zkt.ToSecureKey(key)
102 | 	if err != nil {
103 | 		return err
104 | 	}
105 | 	t.updatePreimage(key, k)
106 | 	return t.tree.TryUpdate(zkt.NewHashFromBigInt(k), vFlag, vPreimage)
107 | }
108 | 
109 | // TryDelete removes any existing value for key from the trie.
110 | // If a node was not found in the database, a MissingNodeError is returned.
111 | func (t *ZkTrie) TryDelete(key []byte) error {
112 | 	k, err := zkt.ToSecureKey(key)
113 | 	if err != nil {
114 | 		return err
115 | 	}
116 | 
117 | 	kHash := zkt.NewHashFromBigInt(k)
118 | 	//mitigate the create-delete issue: do not delete unexisted key
119 | 	if r, _ := t.tree.TryGet(kHash); r == nil {
120 | 		return nil
121 | 	}
122 | 
123 | 	return t.tree.TryDelete(kHash)
124 | }
125 | 
126 | // Hash returns the root hash of SecureBinaryTrie. It does not write to the
127 | // database and can be used even if the trie doesn't have one.
128 | func (t *ZkTrie) Hash() []byte {
129 | 	root, err := t.tree.Root()
130 | 	if err != nil {
131 | 		panic("root failed in trie.Hash")
132 | 	}
133 | 	return root.Bytes()
134 | }
135 | 
136 | // Commit flushes the trie to database
137 | func (t *ZkTrie) Commit() error {
138 | 	return t.tree.Commit()
139 | }
140 | 
141 | // Copy returns a copy of SecureBinaryTrie.
142 | func (t *ZkTrie) Copy() *ZkTrie {
143 | 	return &ZkTrie{
144 | 		tree: t.tree.Copy(),
145 | 	}
146 | }
147 | 
// Prove is a simplified call of ProveWithDeletion with no onHit callback.
func (t *ZkTrie) Prove(key []byte, fromLevel uint, writeNode func(*Node) error) error {
	return t.ProveWithDeletion(key, fromLevel, writeNode, nil)
}

// ProveWithDeletion constructs a merkle proof for key. The result contains all encoded nodes
// on the path to the value at key. The value itself is also included in the last
// node and can be retrieved by verifying the proof.
//
// key must already be the hashed node key bytes: it is parsed with
// zkt.NewHashFromCheckedBytes, not hashed with zkt.ToSecureKey.
//
// If the trie does not contain a value for key, the returned proof contains all
// nodes of the longest existing prefix of the key (at least the root node), ending
// with the node that proves the absence of the key.
//
// If the trie contains a value for key, onHit is called BEFORE writeNode is called
// for the hit leaf. Both the hit leaf node and its sibling node are provided as
// arguments, so the caller receives enough information to perform a deletion and
// calculate the new root from the proof data.
// Also notice the sibling is nil if the trie has only one leaf, or if the sibling
// node cannot be loaded.
func (t *ZkTrie) ProveWithDeletion(key []byte, fromLevel uint, writeNode func(*Node) error, onHit func(*Node, *Node)) error {
	k, err := zkt.NewHashFromCheckedBytes(key)
	if err != nil {
		return err
	}
	// prev is the node visited immediately before the current one on the path.
	var prev *Node
	return t.tree.Prove(k, fromLevel, func(n *Node) (err error) {
		// Deferred so that writeNode runs after any onHit call below, and only
		// if this callback has not already failed. prev always advances to n.
		defer func() {
			if err == nil {
				err = writeNode(n)
			}
			prev = n
		}()

		if prev != nil {
			switch prev.Type {
			case NodeTypeBranch_0, NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3:
			default:
				// sanity check: every node before the current one must be a branch;
				// the walk should stop after reaching a leaf/empty node
				panic("unexpected behavior in prove")
			}
		}

		if onHit == nil {
			return
		}

		// check whether n is the leaf for k and, if so, call onHit
		if n.Type == NodeTypeLeaf_New && bytes.Equal(n.NodeKey.Bytes(), k.Bytes()) {
			if prev == nil {
				// for sole element trie
				onHit(n, nil)
			} else {
				var sibling, nHash *zkt.Hash
				nHash, err = n.NodeHash()
				if err != nil {
					return
				}

				// the sibling is whichever child of prev is not n
				if bytes.Equal(nHash.Bytes(), prev.ChildL.Bytes()) {
					sibling = prev.ChildR
				} else {
					sibling = prev.ChildL
				}

				// a sibling that cannot be loaded is reported as nil
				if siblingNode, err := t.tree.getNode(sibling); err == nil {
					onHit(n, siblingNode)
				} else {
					onHit(n, nil)
				}
			}

		}
		return
	})
}
222 | 


--------------------------------------------------------------------------------
/docs/zktrie.md:
--------------------------------------------------------------------------------
  1 | # zkTrie Spec
  2 | 
  3 | ## 1. Tree Structure
  4 | 
  5 | 
6 | zkTrie Structure 7 |
Figure 1. zkTrie Structure
8 |
9 | 10 | In essence, zkTrie is a sparse binary Merkle Patricia Trie, depicted in the above figure. 11 | Before diving into the Sparse Binary Merkle Patricia Trie, let's briefly touch on Merkle Trees and Patricia Tries. 12 | * **Merkle Tree**: A Merkle Tree is a tree where each leaf node represents a hash of a data block, and each non-leaf node represents the hash of its child nodes. 13 | * **Patricia Trie**: A Patricia Trie is a type of radix tree or compressed trie used to store key-value pairs efficiently. It encodes the nodes with same prefix of the key to share the common path, where the path is determined by the value of the node key. 14 | 15 | As illustrated in the Figure 1, there are three types of nodes in the zkTrie. 16 | - Parent Node (type: 0): Given the zkTrie is a binary tree, a parent node has two children. 17 | - Leaf Node (type: 1): A leaf node holds the data of a key-value pair. 18 | - Empty Node (type: 2): An empty node is a special type of node, indicating the sub-trie that shares the same prefix is empty. 19 | 20 | In zkTrie, we use Poseidon hash to compute the node hash because it's more friendly and efficient to prove it in the zk circuit. 21 | 22 | ## 2. Tree Construction 23 | 24 | Given a key-value pair, we first compute a *secure key* for the corresponding leaf node by hashing the original key (i.e., account address and storage key) using the Poseidon hash function. This can make the key uniformly distributed over the key space. The node key hashing method is described in the [Node Hashing](#3-node-hashing) section below. 25 | 26 | We then encode the path of a new leaf node by traversing the secure key from Least Significant Bit (LSB) to the Most Significant Bit (MSB). At each step, if the bit is 0, we will traverse to the left child; otherwise, traverse to the right child. 27 | 28 | We limit the maximum depth of zkTrie to 248, meaning that the tree will only traverse the lower 248 bits of the key. 
This is because the secure key space is a finite field used by Poseidon hash that doesn't occupy the full range of power of 2. This leads to an ambiguous bit representation of the key in a finite field and thus causes a soundness issue in the zk circuit. But if we truncate the key to lower 248 bits, the key space can fully occupy the range of $2^{248}$ and won't have the ambiguity in the bit representation. 29 | 30 | We also apply an optimization to reduce the tree depth by contracting a subtree that has only one leaf node to a single leaf node. For example, in the Figure 1, the tree has three nodes in total, with keys `0100`, `0010`, and `1010`. Because there is only one node that has key with suffix `00`, the leaf node for key `0100` only traverses the suffix `00` and doesn't fully expand its key which would have resulted in depth of 4. 31 | 32 | ## 3. Node Hashing 33 | 34 | In this section, we will describe how leaf secure key and node merkle hash are computed. We use Poseidon hash in both hashing computation, denoted as `h` in the doc below. 35 | 36 | 39 | 40 | ### 3.1 Empty Node 41 | 42 | The node hash of an empty node is 0. 43 | 44 | ### 3.2 Parent Node 45 | 46 | The parent node hash is computed as follows 47 | 48 | ```go 49 | parentNodeHash = h(leftChildHash, rightChildHash) 50 | ``` 51 | 52 | ### 3.3 Leaf Node 53 | 54 | The node hash of a leaf node is computed as follows 55 | 56 | ```go 57 | leafNodeHash = h(h(1, nodeKey), valueHash) 58 | ``` 59 | 60 | The leaf node can hold two types of values: Ethereum accounts and storage key-value pairs. Next, we will describe how the node key and value hash are computed for each leaf node type. 61 | 62 | #### Ethereum Account Leaf Node 63 | For an Ethereum Account Leaf Node, it consists of an Ethereum address and a state account struct. The secure key is derived from the Ethereum address. 
64 | ``` 65 | address[0:20] (20 bytes in big-endian) 66 | valHi = address[0:16] 67 | valLo = address[16:20] * 2^96 (padding 12 bytes of 0 at the end) 68 | nodeKey = h(valHi, valLo) 69 | ``` 70 | 71 | A state account struct in the Scroll consists of the following fields (`Fr` indicates the finite field used in Poseidon hash and is a 254-bit value) 72 | 73 | - `Nonce`: u64 74 | - `Balance`: u256, but treated as Fr 75 | - `StorageRoot`: Fr 76 | - `KeccakCodeHash`: u256 77 | - `PoseidonCodeHash`: Fr 78 | - `CodeSize`: u64 79 | 80 | Before computing the value hash, the state account is first marshaled into a list of `u256` values. The marshaling scheme is 81 | 82 | ``` 83 | (The following scheme assumes the big-endian encoding) 84 | [0:32] (bytes in big-endian) 85 | [0:16] Reserved with all 0 86 | [16:24] CodeSize, uint64 in big-endian 87 | [24:32] Nonce, uint64 in big-endian 88 | [32:64] Balance 89 | [64:96] StorageRoot 90 | [96:128] KeccakCodeHash 91 | [128:160] PoseidonCodehash 92 | (total 160 bytes) 93 | ``` 94 | 95 | The marshal function also returns a `flag` value along with a vector of `u256` values. The `flag` is a bitmap that indicates whether a `u256` value CANNOT be treated as a field element (Fr). The `flag` value for state account is 8, shown below. 96 | 97 | ``` 98 | +--------------------+---------+------+----------+----------+ 99 | | 0 | 1 | 2 | 3 | 4 | (index) 100 | +--------------------+---------+------+----------+----------+ 101 | | nonce||codesize||0 | balance | root | keccak | poseidon | (u256) 102 | +--------------------+---------+------+----------+----------+ 103 | | 0 | 0 | 0 | 1 | 0 | (flag bits) 104 | +--------------------+---------+------+----------+----------+ 105 | (LSB) (MSB) 106 | ``` 107 | 108 | The value hash is computed in two steps: 109 | 1. Convert the value that cannot be represented as a field element of the Poseidon hash to the field element. 110 | 2. 
Combine field elements in a binary tree structure till the tree root is treated as the value hash. 111 | 112 | In the first step, when the bit in the `flag` is 1 indicating the `u256` value that cannot be treated as a field element, we split the value into a high-128bit value and a low-128bit value, and then pass them to a Poseidon hash to derive a field element value, `h(valueHi, valueLo)`. 113 | 114 | Based on the definition, the value hash of the state account is computed as follows. 115 | 116 | ``` 117 | valueHash = 118 | h( 119 | h( 120 | h(nonce||codesize||0, balance), 121 | h( 122 | storageRoot, 123 | h(keccakCodeHash[0:16], keccakCodeHash[16:32]), // convert Keccak codehash to a field element 124 | ), 125 | ), 126 | poseidonCodeHash, 127 | ) 128 | ``` 129 | 130 | #### Storage Leaf Node 131 | 132 | For a Storage Leaf Node, it is a key-value pair, which both are a `u256` value. The secure key of this leaf node is derived from the storage key. 133 | 134 | ``` 135 | storageKey[0:32] (32 bytes in big-endian) 136 | valHi = storageKey[0:16] 137 | valLo = storageKey[16:32] 138 | nodeKey = h(valHi, valLo) 139 | ``` 140 | 141 | The storage value is a `u256` value. The `flag` for the storage value is 1, showed below. 142 | 143 | ``` 144 | +--------------+ 145 | | 0 | (index) 146 | +--------------+ 147 | | storageValue | (u256) 148 | +--------------+ 149 | | 1 | (flag bits) 150 | +--------------+ 151 | ``` 152 | 153 | The value hash is computed as follows 154 | 155 | ```go 156 | valueHash = h(storageValue[0:16], storageValue[16:32]) 157 | ``` 158 | 159 | ## 4. Tree Operations 160 | 161 | ### 4.1 Insertion 162 | 163 |
164 | zkTrie Structure 165 |
Figure 2. Insert a new leaf node to zkTrie
166 |
167 | 168 | When we insert a new leaf node to the existing zkTrie, there could be two cases illustrated in the Figure 2. 169 | 170 | 1. When traversing the path of the node key, it reaches an empty node (Figure 2(b)). In this case, we just need to replace this empty node by this leaf node and backtrace the path to update the merkle hash of parent nodes till the root. 171 | 2. When traversing the path of the node key, it reaches another leaf node `b` (Figure 2(c)). In this case, we need to push down the existing leaf node `b` until the next bit in the node keys of two leaf nodes differs. At each push-down step, we need to insert an empty sibling node when necessary. When we reach the level where the bits differ, we then place two leaf nodes `b` and `c` as the left child and the right child depending on their bits. At last, we backtrace the path and update the merkle hash of all parent nodes. 172 | 173 | ### 4.2 Deletion 174 | 175 |
176 | zkTrie Structure 177 |
Figure 3. Delete a leaf node from the zkTrie
178 |
179 | 180 | 181 | The deletion of a leaf node is similar to the insertion. There are two cases, illustrated in Figure 3. 182 | 183 | 1. The sibling node of the to-be-deleted leaf node is a parent node (Figure 3(b)). In this case, we can just replace node `a` with an empty node and update the node hashes of its ancestors up to the root node. 184 | 2. The sibling node of the to-be-deleted leaf node is a leaf node (Figure 3(c)). In this case, we first replace the to-be-deleted leaf node with an empty node and then contract its sibling leaf node upwards until that node's sibling is no longer an empty node. For example, in Figure 3(c), we first replace the leaf node `b` with an empty node. During the contraction, since the sibling of node `c` has now become an empty node, we move node `c` one level upward to replace its parent node. The new sibling of node `c`, node `e`, is still an empty node, so we move node `c` upward again. Now that the sibling of node `c` is node `a`, the deletion process is finished. 185 | 186 | Note that the sibling of a leaf node in a valid zkTrie cannot be an empty node. Whenever an operation would leave a leaf node with an empty sibling, the subtree must be pruned by moving the leaf node upwards.
187 | -------------------------------------------------------------------------------- /trie/zk_trie_node_test.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "testing" 7 | 8 | zkt "github.com/scroll-tech/zktrie/types" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestNewNode(t *testing.T) { 13 | t.Run("Test NewEmptyNode", func(t *testing.T) { 14 | node := NewEmptyNode() 15 | assert.Equal(t, NodeTypeEmpty_New, node.Type) 16 | 17 | hash, err := node.NodeHash() 18 | assert.NoError(t, err) 19 | assert.Equal(t, &zkt.HashZero, hash) 20 | 21 | hash, err = node.ValueHash() 22 | assert.NoError(t, err) 23 | assert.Equal(t, &zkt.HashZero, hash) 24 | }) 25 | 26 | t.Run("Test NewLeafNode", func(t *testing.T) { 27 | k := zkt.NewHashFromBytes(bytes.Repeat([]byte("a"), 32)) 28 | vp := []zkt.Byte32{*zkt.NewByte32FromBytes(bytes.Repeat([]byte("b"), 32))} 29 | node := NewLeafNode(k, 1, vp) 30 | assert.Equal(t, NodeTypeLeaf_New, node.Type) 31 | assert.Equal(t, uint32(1), node.CompressedFlags) 32 | assert.Equal(t, vp, node.ValuePreimage) 33 | 34 | hash, err := node.NodeHash() 35 | assert.NoError(t, err) 36 | assert.Equal(t, "11b483a69ab36a33af13fb95a51a5788fe9936de933595ba30be0a435cba3df1", hash.Hex()) 37 | 38 | hash, err = node.ValueHash() 39 | assert.NoError(t, err) 40 | hashFromVp, err := vp[0].Hash() 41 | assert.NoError(t, err) 42 | assert.Equal(t, hashFromVp.Text(16), hash.Hex()) 43 | }) 44 | 45 | t.Run("Test NewParentNode", func(t *testing.T) { 46 | k := zkt.NewHashFromBytes(bytes.Repeat([]byte("a"), 32)) 47 | node := NewParentNode(NodeTypeBranch_3, k, k) 48 | assert.Equal(t, NodeTypeBranch_3, node.Type) 49 | assert.Equal(t, k, node.ChildL) 50 | assert.Equal(t, k, node.ChildR) 51 | 52 | hash, err := node.NodeHash() 53 | assert.NoError(t, err) 54 | assert.Equal(t, "11391717288411fe4c995d6e5793713939e0f59550cd4da96b419905f41d9b80", hash.Hex()) 55 | 56 | hash, err = 
node.ValueHash() 57 | assert.NoError(t, err) 58 | assert.Equal(t, &zkt.HashZero, hash) 59 | }) 60 | 61 | t.Run("Test NewParentNodeWithEmptyChild", func(t *testing.T) { 62 | k := zkt.NewHashFromBytes(bytes.Repeat([]byte("a"), 32)) 63 | r, err := NewEmptyNode().NodeHash() 64 | assert.NoError(t, err) 65 | node := NewParentNode(NodeTypeBranch_2, k, r) 66 | 67 | assert.Equal(t, NodeTypeBranch_2, node.Type) 68 | assert.Equal(t, k, node.ChildL) 69 | assert.Equal(t, r, node.ChildR) 70 | 71 | hash, err := node.NodeHash() 72 | assert.NoError(t, err) 73 | assert.Equal(t, "1cde45e680e99be8b276837124884739ca8dadf216e18b65228948fd74edaa7c", hash.Hex()) 74 | 75 | hash, err = node.ValueHash() 76 | assert.NoError(t, err) 77 | assert.Equal(t, &zkt.HashZero, hash) 78 | }) 79 | 80 | t.Run("Test Invalid Node", func(t *testing.T) { 81 | node := &Node{Type: 99} 82 | 83 | invalidNodeHash, err := node.NodeHash() 84 | assert.NoError(t, err) 85 | assert.Equal(t, &zkt.HashZero, invalidNodeHash) 86 | }) 87 | } 88 | 89 | func TestNewNodeFromBytes(t *testing.T) { 90 | t.Run("ParentNode", func(t *testing.T) { 91 | k1 := zkt.NewHashFromBytes(bytes.Repeat([]byte("a"), 32)) 92 | k2 := zkt.NewHashFromBytes(bytes.Repeat([]byte("b"), 32)) 93 | node := NewParentNode(NodeTypeBranch_0, k1, k2) 94 | b := node.Value() 95 | 96 | node, err := NewNodeFromBytes(b) 97 | assert.NoError(t, err) 98 | 99 | assert.Equal(t, NodeTypeBranch_0, node.Type) 100 | assert.Equal(t, k1, node.ChildL) 101 | assert.Equal(t, k2, node.ChildR) 102 | 103 | hash, err := node.NodeHash() 104 | assert.NoError(t, err) 105 | assert.Equal(t, "187b8e3ca6b878c71b04a312333ed0e9a82e1354cd814f2c576046ed95dd894c", hash.Hex()) 106 | 107 | hash, err = node.ValueHash() 108 | assert.NoError(t, err) 109 | assert.Equal(t, &zkt.HashZero, hash) 110 | }) 111 | 112 | t.Run("LeafNode", func(t *testing.T) { 113 | k := zkt.NewHashFromBytes(bytes.Repeat([]byte("a"), 32)) 114 | vp := make([]zkt.Byte32, 1) 115 | node := NewLeafNode(k, 1, vp) 116 | 117 | 
node.KeyPreimage = zkt.NewByte32FromBytes(bytes.Repeat([]byte("b"), 32)) 118 | 119 | nodeBytes := node.Value() 120 | newNode, err := NewNodeFromBytes(nodeBytes) 121 | assert.NoError(t, err) 122 | 123 | assert.Equal(t, node.Type, newNode.Type) 124 | assert.Equal(t, node.NodeKey, newNode.NodeKey) 125 | assert.Equal(t, node.ValuePreimage, newNode.ValuePreimage) 126 | assert.Equal(t, node.KeyPreimage, newNode.KeyPreimage) 127 | 128 | hash, err := node.NodeHash() 129 | assert.NoError(t, err) 130 | assert.Equal(t, "0409b2168569a5af12877689ac08263274fe6cb522898f6964647fed88b861b2", hash.Hex()) 131 | 132 | hash, err = node.ValueHash() 133 | assert.NoError(t, err) 134 | hashFromVp, err := vp[0].Hash() 135 | 136 | assert.Equal(t, zkt.NewHashFromBigInt(hashFromVp), hash) 137 | }) 138 | 139 | t.Run("EmptyNode", func(t *testing.T) { 140 | node := NewEmptyNode() 141 | b := node.Value() 142 | 143 | node, err := NewNodeFromBytes(b) 144 | assert.NoError(t, err) 145 | 146 | assert.Equal(t, NodeTypeEmpty_New, node.Type) 147 | 148 | hash, err := node.NodeHash() 149 | assert.NoError(t, err) 150 | assert.Equal(t, &zkt.HashZero, hash) 151 | 152 | hash, err = node.ValueHash() 153 | assert.NoError(t, err) 154 | assert.Equal(t, &zkt.HashZero, hash) 155 | }) 156 | 157 | t.Run("BadSize", func(t *testing.T) { 158 | testCases := [][]byte{ 159 | {}, 160 | {0, 1, 2}, 161 | func() []byte { 162 | b := make([]byte, zkt.HashByteLen+3) 163 | b[0] = byte(NodeTypeLeaf) 164 | return b 165 | }(), 166 | func() []byte { 167 | k := zkt.NewHashFromBytes([]byte{1, 2, 3, 4, 5}) 168 | vp := make([]zkt.Byte32, 1) 169 | node := NewLeafNode(k, 1, vp) 170 | b := node.Value() 171 | return b[:len(b)-32] 172 | }(), 173 | func() []byte { 174 | k := zkt.NewHashFromBytes([]byte{1, 2, 3, 4, 5}) 175 | vp := make([]zkt.Byte32, 1) 176 | node := NewLeafNode(k, 1, vp) 177 | node.KeyPreimage = zkt.NewByte32FromBytes([]byte{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 
34, 35, 36, 37}) 178 | 179 | b := node.Value() 180 | return b[:len(b)-1] 181 | }(), 182 | } 183 | 184 | for _, b := range testCases { 185 | node, err := NewNodeFromBytes(b) 186 | assert.ErrorIs(t, err, ErrNodeBytesBadSize) 187 | assert.Nil(t, node) 188 | } 189 | }) 190 | 191 | t.Run("InvalidType", func(t *testing.T) { 192 | b := []byte{255} 193 | 194 | node, err := NewNodeFromBytes(b) 195 | assert.ErrorIs(t, err, ErrInvalidNodeFound) 196 | assert.Nil(t, node) 197 | }) 198 | } 199 | 200 | func TestNodeValueAndData(t *testing.T) { 201 | k := zkt.NewHashFromBytes(bytes.Repeat([]byte("a"), 32)) 202 | vp := []zkt.Byte32{*zkt.NewByte32FromBytes(bytes.Repeat([]byte("b"), 32))} 203 | 204 | node := NewLeafNode(k, 1, vp) 205 | canonicalValue := node.CanonicalValue() 206 | assert.Equal(t, []byte{0x4, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x1, 0x1, 0x0, 0x0, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x0}, canonicalValue) 207 | assert.Equal(t, []byte{0x4, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x1, 0x1, 0x0, 0x0, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x0}, node.Value()) 208 | node.KeyPreimage = zkt.NewByte32FromBytes(bytes.Repeat([]byte("c"), 32)) 209 | assert.Equal(t, []byte{0x4, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x1, 0x1, 
0x0, 0x0, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x20, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63, 0x63}, node.Value()) 210 | assert.Equal(t, []byte{0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62, 0x62}, node.Data()) 211 | 212 | parentNode := NewParentNode(NodeTypeBranch_3, k, k) 213 | canonicalValue = parentNode.CanonicalValue() 214 | assert.Equal(t, []byte{0x9, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61}, canonicalValue) 215 | assert.Nil(t, parentNode.Data()) 216 | 217 | emptyNode := &Node{Type: NodeTypeEmpty_New} 218 | assert.Equal(t, []byte{byte(emptyNode.Type)}, emptyNode.CanonicalValue()) 219 | assert.Nil(t, emptyNode.Data()) 220 | 221 | invalidNode := &Node{Type: 99} 222 | assert.Equal(t, []byte{}, invalidNode.CanonicalValue()) 223 | assert.Nil(t, invalidNode.Data()) 224 | } 225 | 226 | func TestNodeString(t *testing.T) { 227 | k := zkt.NewHashFromBytes(bytes.Repeat([]byte("a"), 32)) 228 | vp := []zkt.Byte32{*zkt.NewByte32FromBytes(bytes.Repeat([]byte("b"), 32))} 229 | 230 | leafNode := NewLeafNode(k, 1, vp) 231 | assert.Equal(t, fmt.Sprintf("Leaf I:%v Items: %d, First:%v", leafNode.NodeKey, len(leafNode.ValuePreimage), leafNode.ValuePreimage[0]), leafNode.String()) 232 | 233 | parentNode := 
NewParentNode(NodeTypeBranch_3, k, k) 234 | assert.Equal(t, fmt.Sprintf("Parent L:%s R:%s", parentNode.ChildL, parentNode.ChildR), parentNode.String()) 235 | 236 | emptyNode := NewEmptyNode() 237 | assert.Equal(t, "Empty", emptyNode.String()) 238 | 239 | invalidNode := &Node{Type: 99} 240 | assert.Equal(t, "Invalid Node", invalidNode.String()) 241 | } 242 | -------------------------------------------------------------------------------- /src/rs_lib.rs: -------------------------------------------------------------------------------- 1 | use super::constants::*; 2 | use std::{collections::HashMap, rc::Rc}; 3 | use zktrie_rust::{ 4 | db::ZktrieDatabase, 5 | types::{Hashable, TrieHashScheme}, 6 | *, 7 | }; 8 | #[derive(Clone, Debug, Default, PartialEq)] 9 | pub struct HashField([u8; HASHLEN]); 10 | 11 | impl AsRef<[u8]> for HashField { 12 | fn as_ref(&self) -> &[u8] { 13 | &self.0 14 | } 15 | } 16 | 17 | impl AsMut<[u8]> for HashField { 18 | fn as_mut(&mut self) -> &mut [u8] { 19 | &mut self.0 20 | } 21 | } 22 | 23 | impl hash::Hash for HashField { 24 | const LEN: usize = HASHLEN; 25 | 26 | // notice: we have skipped the "field range checking" since 27 | // we have wrapped zktrie in a form that never accept value 28 | // which would be an invalid field poteinally 29 | 30 | fn simple_hash_scheme(mut a: [u8; 32], mut b: [u8; 32], domain: u64) -> Self { 31 | a.reverse(); 32 | b.reverse(); 33 | 34 | let mut domain_byte32 = [0u8; 32]; 35 | domain_byte32[..8].copy_from_slice(&domain.to_le_bytes()); 36 | 37 | let mut ret = super::HASHSCHEME 38 | .get() 39 | .expect("init_hash_scheme_simple should have been called")( 40 | &a, &b, &domain_byte32 41 | ) 42 | .unwrap_or_default(); 43 | ret.reverse(); 44 | 45 | Self(ret) 46 | } 47 | } 48 | 49 | type HashImpl = hash::AsHash; 50 | 51 | pub struct ZkTrieNode { 52 | trie_node: types::Node, 53 | } 54 | 55 | impl ZkTrieNode { 56 | pub fn parse(data: &[u8]) -> Result { 57 | types::Node::new_node_from_bytes(data) 58 | // notice the go 
routine also calculated nodehash while parsing 59 | // see the code inside `NewTrieNode` 60 | .and_then(|n| n.calc_node_hash()) 61 | .map(|n| Self { trie_node: n }) 62 | .map_err(|e| e.to_string()) 63 | } 64 | 65 | pub fn parse_with_key(data: &[u8], key: &[u8]) -> Result { 66 | types::Node::new_node_from_bytes(data) 67 | .and_then(|mut n| { 68 | let h = HashImpl::from_bytes(key)?; 69 | n.set_node_hash(h); 70 | Ok(n) 71 | }) 72 | .map(|n| Self { trie_node: n }) 73 | .map_err(|e| e.to_string()) 74 | } 75 | 76 | pub fn node_hash(&self) -> Hash { 77 | self.trie_node 78 | .clone() 79 | .node_hash() 80 | .expect("has caluclated") 81 | .as_ref() 82 | .try_into() 83 | .expect("same length") 84 | } 85 | 86 | pub fn value_hash(&self) -> Option { 87 | self.trie_node 88 | .clone() 89 | .value_hash() 90 | .map(|h| h.as_ref().try_into().expect("same length")) 91 | } 92 | 93 | pub fn is_tip(&self) -> bool { 94 | self.trie_node.is_terminal() 95 | } 96 | 97 | pub fn as_account(&self) -> Option { 98 | if self.is_tip() { 99 | self.trie_node 100 | .data() 101 | .map(|data| { 102 | data.chunks(FIELDSIZE) 103 | .map(TryInto::<[u8; FIELDSIZE]>::try_into) 104 | .map(|v| v.expect("same length")) 105 | .collect::>() 106 | }) 107 | .map(|datas| datas.try_into().expect("should be same items")) 108 | } else { 109 | None 110 | } 111 | } 112 | 113 | pub fn as_storage(&self) -> Option { 114 | if self.is_tip() { 115 | self.trie_node 116 | .data() 117 | .map(|data| data.try_into().expect("should be same length")) 118 | } else { 119 | None 120 | } 121 | } 122 | } 123 | 124 | #[derive(Clone)] 125 | pub struct ZkMemoryDb { 126 | db: db::SimpleDb, 127 | key_db: HashMap, HashImpl>, 128 | } 129 | 130 | #[derive(Clone)] 131 | pub struct SharedMemoryDb(Rc); 132 | 133 | impl db::ZktrieDatabase for SharedMemoryDb { 134 | fn put(&mut self, _: Vec, _: Vec) -> Result<(), raw::ImplError> { 135 | Err(raw::ImplError::ErrNotWritable) 136 | } 137 | fn get(&self, k: &[u8]) -> Result<&[u8], raw::ImplError> { 138 | 
self.0.db.get(k) 139 | } 140 | } 141 | 142 | impl trie::KeyCache for SharedMemoryDb { 143 | fn get_key(&self, k: &[u8]) -> Option<&HashImpl> { 144 | self.0.key_db.get(k) 145 | } 146 | } 147 | 148 | #[derive(Clone)] 149 | pub struct UpdateDb(db::SimpleDb, Rc); 150 | 151 | impl UpdateDb { 152 | pub fn updated_db(self) -> db::SimpleDb { 153 | self.0 154 | } 155 | } 156 | 157 | impl db::ZktrieDatabase for UpdateDb { 158 | fn put(&mut self, k: Vec, v: Vec) -> Result<(), raw::ImplError> { 159 | self.0.put(k, v) 160 | } 161 | fn get(&self, k: &[u8]) -> Result<&[u8], raw::ImplError> { 162 | let ret = self.0.get(k); 163 | if ret.is_ok() { 164 | ret 165 | } else { 166 | self.1.db.get(k) 167 | } 168 | } 169 | } 170 | 171 | impl trie::KeyCache for UpdateDb { 172 | fn get_key(&self, k: &[u8]) -> Option<&HashImpl> { 173 | self.1.key_db.get(k) 174 | } 175 | } 176 | 177 | use trie::ZkTrie as ZktrieRs; 178 | 179 | pub struct ZkTrie>(ZktrieRs); 180 | 181 | pub type ErrString = String; 182 | 183 | const MAGICSMTBYTES: &[u8] = "THIS IS SOME MAGIC BYTES FOR SMT m1rRXgP2xpDI".as_bytes(); 184 | 185 | impl Default for ZkMemoryDb { 186 | fn default() -> Self { 187 | Self::new() 188 | } 189 | } 190 | 191 | impl ZkMemoryDb { 192 | pub fn new() -> Self { 193 | Self { 194 | db: db::SimpleDb::new(), 195 | key_db: HashMap::new(), 196 | } 197 | } 198 | 199 | pub fn with_key_cache<'a>(&mut self, data: impl Iterator) { 200 | for (k, v) in data { 201 | // TODO: here we silently omit any invalid hash value 202 | if let Ok(h) = HashImpl::from_bytes(v) { 203 | self.key_db.insert(Vec::from(k), h); 204 | } 205 | } 206 | } 207 | 208 | pub fn add_node_bytes(&mut self, data: &[u8], key: Option<&[u8]>) -> Result<(), ErrString> { 209 | if data == MAGICSMTBYTES { 210 | return Ok(()); 211 | } 212 | let n = if let Some(key) = key { 213 | ZkTrieNode::parse_with_key(data, key) 214 | } else { 215 | ZkTrieNode::parse(data) 216 | }?; 217 | self.db 218 | .put(n.node_hash().to_vec(), n.trie_node.canonical_value()) 219 
| .map_err(|e| e.to_string()) 220 | } 221 | 222 | pub fn add_node_data(&mut self, data: &[u8]) -> Result<(), ErrString> { 223 | self.add_node_bytes(data, None) 224 | } 225 | 226 | pub fn update(&mut self, updated_db: db::SimpleDb) { 227 | self.db.merge(updated_db); 228 | } 229 | 230 | /// the zktrie can be created only if the corresponding root node has been added 231 | pub fn new_trie(self: &Rc, root: &Hash) -> Option> { 232 | HashImpl::from_bytes(root.as_slice()) 233 | .ok() 234 | .and_then(|h| ZktrieRs::new_zktrie(h, UpdateDb(Default::default(), self.clone())).ok()) 235 | .map(ZkTrie) 236 | } 237 | 238 | /// the zktrie can be created only if the corresponding root node has been added 239 | pub fn new_ref_trie(self: &Rc, root: &Hash) -> Option> { 240 | HashImpl::from_bytes(root.as_slice()) 241 | .ok() 242 | .and_then(|h| ZktrieRs::new_zktrie(h, SharedMemoryDb(self.clone())).ok()) 243 | .map(ZkTrie) 244 | } 245 | } 246 | 247 | impl ZkTrie { 248 | pub fn updated_db(self) -> db::SimpleDb { 249 | self.0.tree().into_db().updated_db() 250 | } 251 | 252 | fn update(&mut self, key: &[u8], value: &[[u8; FIELDSIZE]]) -> Result<(), ErrString> { 253 | let v_flag = match value.len() { 254 | 1 => 1, 255 | 4 => 4, 256 | 5 => 8, 257 | _ => return Err("unexpected buffer type".to_string()), 258 | }; 259 | 260 | self.0 261 | .try_update(key, v_flag, value.to_vec()) 262 | .map_err(|e| e.to_string()) 263 | } 264 | 265 | pub fn update_store(&mut self, key: &[u8], value: &StoreData) -> Result<(), ErrString> { 266 | self.update(key, &[*value]) 267 | } 268 | 269 | pub fn update_account( 270 | &mut self, 271 | key: &[u8], 272 | acc_fields: &AccountData, 273 | ) -> Result<(), ErrString> { 274 | self.update(key, acc_fields) 275 | } 276 | 277 | pub fn delete(&mut self, key: &[u8]) { 278 | self.0.try_delete(key).ok(); 279 | } 280 | } 281 | 282 | impl> ZkTrie { 283 | pub fn root(&self) -> Hash { 284 | self.0.hash().as_slice().try_into().expect("same length") 285 | } 286 | 287 | pub fn 
commit(&mut self) -> Result<(), ErrString> { 288 | self.0.commit().map_err(|e| e.to_string()) 289 | } 290 | 291 | pub fn is_trie_dirty(&self) -> bool { 292 | self.0.is_trie_dirty() 293 | } 294 | 295 | pub fn prepare_root(&mut self) { 296 | self.0.prepare_root().expect("prepare root failed"); 297 | } 298 | 299 | // all errors are reduced to "not found" 300 | fn get(&self, key: &[u8]) -> Option<[u8; T]> { 301 | let ret = self.0.try_get(key); 302 | if ret.len() != T { 303 | None 304 | } else { 305 | Some(ret.as_slice().try_into().expect("same length")) 306 | } 307 | } 308 | 309 | // get value from storage trie 310 | pub fn get_store(&self, key: &[u8]) -> Option { 311 | self.get::<32>(key) 312 | } 313 | 314 | // get account data from account trie 315 | pub fn get_account(&self, key: &[u8]) -> Option { 316 | self.get::(key).map(|arr| unsafe { 317 | std::mem::transmute::<[u8; FIELDSIZE * ACCOUNTFIELDS], AccountData>(arr) 318 | }) 319 | } 320 | 321 | // build prove array for mpt path 322 | pub fn prove(&self, key: &[u8]) -> Result>, ErrString> { 323 | use types::Node; 324 | 325 | let s_key = Node::::hash_bytes(key).map_err(|e| e.to_string())?; 326 | 327 | let (proof, _) = self.0.prove(s_key.as_ref()).map_err(|e| e.to_string())?; 328 | 329 | Ok(proof 330 | .into_iter() 331 | .map(|n| n.value()) 332 | .chain(std::iter::once(MAGICSMTBYTES.to_vec())) 333 | .collect()) 334 | } 335 | } 336 | -------------------------------------------------------------------------------- /src/go_lib.rs: -------------------------------------------------------------------------------- 1 | use super::constants::*; 2 | use std::ffi::{self, c_char, c_int, c_void}; 3 | use std::marker::{PhantomData, PhantomPinned}; 4 | use std::{fmt, rc::Rc}; 5 | 6 | #[repr(C)] 7 | struct MemoryDb { 8 | _data: [u8; 0], 9 | _marker: PhantomData<(*mut u8, PhantomPinned)>, 10 | } 11 | #[repr(C)] 12 | struct Trie { 13 | _data: [u8; 0], 14 | _marker: PhantomData<(*mut u8, PhantomPinned)>, 15 | } 16 | #[repr(C)] 17 | 
struct TrieNode { 18 | _data: [u8; 0], 19 | _marker: PhantomData<(*mut u8, PhantomPinned)>, 20 | } 21 | 22 | pub type HashScheme = extern "C" fn(*const u8, *const u8, *const u8, *mut u8) -> *const i8; 23 | type ProveCallback = extern "C" fn(*const u8, c_int, *mut c_void); 24 | 25 | #[link(name = "zktrie")] 26 | extern "C" { 27 | fn InitHashScheme(f: HashScheme); 28 | fn NewMemoryDb() -> *mut MemoryDb; 29 | fn InitDbByNode(db: *mut MemoryDb, data: *const u8, sz: c_int) -> *const c_char; 30 | fn NewZkTrie(root: *const u8, db: *const MemoryDb) -> *mut Trie; 31 | fn FreeMemoryDb(db: *mut MemoryDb); 32 | fn FreeZkTrie(trie: *mut Trie); 33 | fn FreeBuffer(p: *const c_void); 34 | fn TrieGetSize(trie: *const Trie, key: *const u8, key_sz: c_int, value_sz: c_int) -> *const u8; 35 | fn TrieRoot(trie: *const Trie) -> *const u8; 36 | fn TrieUpdate( 37 | trie: *mut Trie, 38 | key: *const u8, 39 | key_sz: c_int, 40 | val: *const u8, 41 | val_sz: c_int, 42 | ) -> *const c_char; 43 | fn TrieDelete(trie: *mut Trie, key: *const u8, key_sz: c_int); 44 | fn TrieProve( 45 | trie: *const Trie, 46 | key: *const u8, 47 | key_sz: c_int, 48 | cb: ProveCallback, 49 | param: *mut c_void, 50 | ) -> *const c_char; 51 | fn NewTrieNode(data: *const u8, data_sz: c_int) -> *const TrieNode; 52 | fn FreeTrieNode(node: *const TrieNode); 53 | fn TrieNodeHash(node: *const TrieNode) -> *const u8; 54 | fn TrieLeafNodeValueHash(node: *const TrieNode) -> *const u8; 55 | fn TrieNodeIsTip(node: *const TrieNode) -> c_int; 56 | fn TrieNodeData(node: *const TrieNode, value_sz: c_int) -> *const u8; 57 | } 58 | 59 | pub(crate) fn init_hash_scheme(f: HashScheme) { 60 | unsafe { InitHashScheme(f) } 61 | } 62 | 63 | pub struct ErrString(*const c_char); 64 | 65 | impl Drop for ErrString { 66 | fn drop(&mut self) { 67 | unsafe { FreeBuffer(self.0.cast()) }; 68 | } 69 | } 70 | 71 | impl fmt::Debug for ErrString { 72 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 73 | self.to_string().fmt(f) 74 | } 75 | } 76 
| 77 | impl From<*const c_char> for ErrString { 78 | fn from(src: *const c_char) -> Self { 79 | Self(src) 80 | } 81 | } 82 | 83 | impl ToString for ErrString { 84 | fn to_string(&self) -> String { 85 | let ret = unsafe { ffi::CStr::from_ptr(self.0).to_str() }; 86 | ret.map(String::from).unwrap_or_else(|_| { 87 | String::from("error string include invalid char and can not be displayed") 88 | }) 89 | } 90 | } 91 | 92 | fn must_get_const_bytes(p: *const u8) -> [u8; T] { 93 | let bytes = unsafe { std::slice::from_raw_parts(p, T) }; 94 | let bytes = bytes 95 | .try_into() 96 | .expect("the buf has been set to specified bytes"); 97 | unsafe { FreeBuffer(p.cast()) } 98 | bytes 99 | } 100 | 101 | fn must_get_hash(p: *const u8) -> Hash { 102 | must_get_const_bytes::(p) 103 | } 104 | 105 | pub struct ZkMemoryDb { 106 | db: *mut MemoryDb, 107 | } 108 | 109 | impl Drop for ZkMemoryDb { 110 | fn drop(&mut self) { 111 | unsafe { FreeMemoryDb(self.db) }; 112 | } 113 | } 114 | 115 | pub struct ZkTrieNode { 116 | trie_node: *const TrieNode, 117 | } 118 | 119 | impl Drop for ZkTrieNode { 120 | fn drop(&mut self) { 121 | unsafe { FreeTrieNode(self.trie_node) }; 122 | } 123 | } 124 | 125 | impl ZkTrieNode { 126 | pub fn parse(data: &[u8]) -> Result { 127 | let trie_node = unsafe { NewTrieNode(data.as_ptr(), c_int::try_from(data.len()).unwrap()) }; 128 | if trie_node.is_null() { 129 | Err(format!("Can not parse {data:#x?}")) 130 | } else { 131 | Ok(Self { trie_node }) 132 | } 133 | } 134 | 135 | pub fn node_hash(&self) -> Hash { 136 | must_get_hash(unsafe { TrieNodeHash(self.trie_node) }) 137 | } 138 | 139 | pub fn is_tip(&self) -> bool { 140 | let is_tip = unsafe { TrieNodeIsTip(self.trie_node) }; 141 | is_tip != 0 142 | } 143 | 144 | pub fn as_account(&self) -> Option { 145 | if self.is_tip() { 146 | let ret = unsafe { TrieNodeData(self.trie_node, ACCOUNTSIZE as i32) }; 147 | if ret.is_null() { 148 | None 149 | } else { 150 | let ret_byte = must_get_const_bytes(ret); 151 | unsafe { 
152 | Some(std::mem::transmute::< 153 | [u8; FIELDSIZE * ACCOUNTFIELDS], 154 | AccountData, 155 | >(ret_byte)) 156 | } 157 | } 158 | } else { 159 | None 160 | } 161 | } 162 | 163 | pub fn as_storage(&self) -> Option { 164 | if self.is_tip() { 165 | let ret = unsafe { TrieNodeData(self.trie_node, 32) }; 166 | if ret.is_null() { 167 | None 168 | } else { 169 | Some(must_get_const_bytes::<32>(ret)) 170 | } 171 | } else { 172 | None 173 | } 174 | } 175 | 176 | pub fn value_hash(&self) -> Option { 177 | let key_p = unsafe { TrieLeafNodeValueHash(self.trie_node) }; 178 | if key_p.is_null() { 179 | None 180 | } else { 181 | Some(must_get_hash(key_p)) 182 | } 183 | } 184 | } 185 | 186 | pub struct ZkTrie { 187 | trie: *mut Trie, 188 | binding_db: Rc, 189 | } 190 | 191 | impl Drop for ZkTrie { 192 | fn drop(&mut self) { 193 | unsafe { FreeZkTrie(self.trie) }; 194 | } 195 | } 196 | 197 | impl Clone for ZkTrie { 198 | fn clone(&self) -> Self { 199 | self.binding_db 200 | .new_trie(&self.root()) 201 | .expect("valid under clone") 202 | } 203 | } 204 | 205 | impl ZkMemoryDb { 206 | pub fn new() -> Rc { 207 | Rc::new(Self { 208 | db: unsafe { NewMemoryDb() }, 209 | }) 210 | } 211 | 212 | pub fn add_node_bytes(self: &mut Rc, data: &[u8]) -> Result<(), ErrString> { 213 | let ret_ptr = unsafe { InitDbByNode(self.db, data.as_ptr(), data.len() as c_int) }; 214 | if ret_ptr.is_null() { 215 | Ok(()) 216 | } else { 217 | Err(ret_ptr.into()) 218 | } 219 | } 220 | 221 | // the zktrie can be created only if the corresponding root node has been added 222 | pub fn new_trie(self: &Rc, root: &Hash) -> Option { 223 | let ret = unsafe { NewZkTrie(root.as_ptr(), self.db) }; 224 | 225 | if ret.is_null() { 226 | None 227 | } else { 228 | Some(ZkTrie { 229 | trie: ret, 230 | binding_db: self.clone(), 231 | }) 232 | } 233 | } 234 | } 235 | 236 | impl ZkTrie { 237 | extern "C" fn prove_callback(data: *const u8, data_sz: c_int, out_p: *mut c_void) { 238 | let output = unsafe { 239 | out_p 240 | 
.cast::>>() 241 | .as_mut() 242 | .expect("callback parameter can not be zero") 243 | }; 244 | let buf = unsafe { std::slice::from_raw_parts(data, data_sz as usize) }; 245 | output.push(Vec::from(buf)) 246 | } 247 | 248 | pub fn root(&self) -> Hash { 249 | must_get_hash(unsafe { TrieRoot(self.trie) }) 250 | } 251 | 252 | pub fn get_db(&self) -> Rc { 253 | self.binding_db.clone() 254 | } 255 | 256 | // all errors are reduced to "not found" 257 | fn get(&self, key: &[u8]) -> Option<[u8; T]> { 258 | let ret = unsafe { TrieGetSize(self.trie, key.as_ptr(), key.len() as c_int, T as c_int) }; 259 | 260 | if ret.is_null() { 261 | None 262 | } else { 263 | Some(must_get_const_bytes::(ret)) 264 | } 265 | } 266 | 267 | // get value from storage trie 268 | pub fn get_store(&self, key: &[u8]) -> Option { 269 | self.get::<32>(key) 270 | } 271 | 272 | // get account data from account trie 273 | pub fn get_account(&self, key: &[u8]) -> Option { 274 | self.get::(key).map(|arr| unsafe { 275 | std::mem::transmute::<[u8; FIELDSIZE * ACCOUNTFIELDS], AccountData>(arr) 276 | }) 277 | } 278 | 279 | // build prove array for mpt path 280 | pub fn prove(&self, key: &[u8]) -> Result>, ErrString> { 281 | let mut output: Vec> = Vec::new(); 282 | let ptr: *mut Vec> = &mut output; 283 | 284 | let ret_ptr = unsafe { 285 | TrieProve( 286 | self.trie, 287 | key.as_ptr(), 288 | key.len() as c_int, 289 | Self::prove_callback, 290 | ptr.cast(), 291 | ) 292 | }; 293 | if ret_ptr.is_null() { 294 | Ok(output) 295 | } else { 296 | Err(ret_ptr.into()) 297 | } 298 | } 299 | 300 | fn update(&mut self, key: &[u8], value: &[u8; T]) -> Result<(), ErrString> { 301 | let ret_ptr = unsafe { 302 | TrieUpdate( 303 | self.trie, 304 | key.as_ptr(), 305 | key.len() as c_int, 306 | value.as_ptr(), 307 | T as c_int, 308 | ) 309 | }; 310 | if ret_ptr.is_null() { 311 | Ok(()) 312 | } else { 313 | Err(ret_ptr.into()) 314 | } 315 | } 316 | 317 | pub fn update_store(&mut self, key: &[u8], value: &StoreData) -> Result<(), 
ErrString> { 318 | self.update(key, value) 319 | } 320 | 321 | pub fn update_account( 322 | &mut self, 323 | key: &[u8], 324 | acc_fields: &AccountData, 325 | ) -> Result<(), ErrString> { 326 | let acc_buf: &[u8; FIELDSIZE * ACCOUNTFIELDS] = unsafe { 327 | let ptr = acc_fields.as_ptr(); 328 | ptr.cast::<[u8; FIELDSIZE * ACCOUNTFIELDS]>() 329 | .as_ref() 330 | .expect("casted ptr can not be null") 331 | }; 332 | 333 | self.update(key, acc_buf) 334 | } 335 | 336 | pub fn delete(&mut self, key: &[u8]) { 337 | unsafe { 338 | TrieDelete(self.trie, key.as_ptr(), key.len() as c_int); 339 | } 340 | } 341 | } 342 | -------------------------------------------------------------------------------- /lib.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | /* 4 | #include 5 | #include 6 | 7 | typedef char* (*hashF)(unsigned char*, unsigned char*, unsigned char*); 8 | typedef void (*proveWriteF)(unsigned char*, int, void*); 9 | 10 | extern hashF hash_scheme; 11 | 12 | char* bridge_hash(unsigned char* a, unsigned char* b, unsigned char* domain, unsigned char* out); 13 | void init_hash_scheme(hashF f); 14 | void bridge_prove_write(proveWriteF f, unsigned char* key, unsigned char* val, int size, void* param); 15 | 16 | */ 17 | import "C" 18 | import ( 19 | "errors" 20 | "fmt" 21 | "math/big" 22 | "runtime/cgo" 23 | "unsafe" 24 | 25 | "github.com/scroll-tech/zktrie/trie" 26 | zkt "github.com/scroll-tech/zktrie/types" 27 | ) 28 | 29 | var zeros = [32]byte{} 30 | 31 | func hash_external(inp []*big.Int, domain *big.Int) (*big.Int, error) { 32 | if len(inp) != 2 { 33 | return big.NewInt(0), errors.New("invalid input size") 34 | } 35 | a := zkt.ReverseByteOrder(inp[0].Bytes()) 36 | b := zkt.ReverseByteOrder(inp[1].Bytes()) 37 | dm := zkt.ReverseByteOrder(domain.Bytes()) 38 | 39 | a = append(a, zeros[0:(32-len(a))]...) 40 | b = append(b, zeros[0:(32-len(b))]...) 41 | dm = append(dm, zeros[0:(32-len(dm))]...) 
42 | 43 | c := make([]byte, 32) 44 | 45 | err := C.bridge_hash((*C.uchar)(&a[0]), (*C.uchar)(&b[0]), (*C.uchar)(&dm[0]), (*C.uchar)(&c[0])) 46 | 47 | if err != nil { 48 | return big.NewInt(0), errors.New(C.GoString(err)) 49 | } 50 | 51 | return big.NewInt(0).SetBytes(zkt.ReverseByteOrder(c)), nil 52 | } 53 | 54 | //export TestHashScheme 55 | func TestHashScheme() { 56 | h1, err := hash_external([]*big.Int{big.NewInt(1), big.NewInt(2)}, big.NewInt(0)) 57 | if err != nil { 58 | panic(err) 59 | } 60 | expected := big.NewInt(0) 61 | expected.UnmarshalText([]byte("7853200120776062878684798364095072458815029376092732009249414926327459813530")) 62 | if h1.Cmp(expected) != 0 { 63 | panic(fmt.Errorf("unexpected poseidon hash value: %s", h1)) 64 | } 65 | 66 | h2, err := hash_external([]*big.Int{big.NewInt(1), big.NewInt(2)}, big.NewInt(256)) 67 | if err != nil { 68 | panic(err) 69 | } 70 | expected.UnmarshalText([]byte("2362370911616048355006851495576377379220050231129891536935411970097789775493")) 71 | if h2.Cmp(expected) != 0 { 72 | panic(fmt.Errorf("unexpected poseidon hash value: %s", h1)) 73 | } 74 | } 75 | 76 | // notice the function must use C calling convention 77 | // 78 | //export InitHashScheme 79 | func InitHashScheme(f unsafe.Pointer) { 80 | hash_f := C.hashF(f) 81 | C.init_hash_scheme(hash_f) 82 | zkt.InitHashScheme(hash_external) 83 | } 84 | 85 | // parse raw bytes and create the trie node 86 | // 87 | //export NewTrieNode 88 | func NewTrieNode(data *C.char, sz C.int) C.uintptr_t { 89 | bt := C.GoBytes(unsafe.Pointer(data), sz) 90 | n, err := trie.NewNodeFromBytes(bt) 91 | if err != nil { 92 | return 0 93 | } 94 | 95 | // calculate key for caching 96 | if _, err := n.NodeHash(); err != nil { 97 | return 0 98 | } 99 | 100 | return C.uintptr_t(cgo.NewHandle(n)) 101 | } 102 | 103 | // obtain the key hash, must be free by caller 104 | // 105 | //export TrieNodeHash 106 | func TrieNodeHash(pN C.uintptr_t) unsafe.Pointer { 107 | h := cgo.Handle(pN) 108 | n := 
h.Value().(*trie.Node) 109 | 110 | hash, _ := n.NodeHash() 111 | return C.CBytes(hash.Bytes()) 112 | } 113 | 114 | // obtain the data of node if it is leaf, must be free by caller 115 | // or nil for other type 116 | // if val_sz is not 0 and the value size is not equal to val_sz, 117 | // it is also return nil 118 | // 119 | //export TrieNodeData 120 | func TrieNodeData(pN C.uintptr_t, val_sz C.int) unsafe.Pointer { 121 | h := cgo.Handle(pN) 122 | n := h.Value().(*trie.Node) 123 | 124 | if d := n.Data(); d != nil { 125 | // safety check 126 | if expected_sz := int(val_sz); expected_sz != 0 && len(d) != int(val_sz) { 127 | return nil 128 | } 129 | 130 | return C.CBytes(d) 131 | } else { 132 | return nil 133 | } 134 | } 135 | 136 | // test if the node is tip type (i.e. leaf or empty) 137 | // 138 | //export TrieNodeIsTip 139 | func TrieNodeIsTip(pN C.uintptr_t) C.int { 140 | h := cgo.Handle(pN) 141 | n := h.Value().(*trie.Node) 142 | 143 | if n.IsTerminal() { 144 | return 1 145 | } else { 146 | return 0 147 | } 148 | } 149 | 150 | // obtain the value hash for leaf node (must be free by caller), or nil for other 151 | // 152 | //export TrieLeafNodeValueHash 153 | func TrieLeafNodeValueHash(pN C.uintptr_t) unsafe.Pointer { 154 | h := cgo.Handle(pN) 155 | n := h.Value().(*trie.Node) 156 | 157 | if n.Type != trie.NodeTypeLeaf_New { 158 | return nil 159 | } 160 | 161 | valueHash, _ := n.ValueHash() 162 | return C.CBytes(valueHash.Bytes()) 163 | } 164 | 165 | // free created trie node 166 | // 167 | //export FreeTrieNode 168 | func FreeTrieNode(p C.uintptr_t) { freeObject(p) } 169 | 170 | // create memory db 171 | // 172 | //export NewMemoryDb 173 | func NewMemoryDb() C.uintptr_t { 174 | // it break the cgo's enforcement (C code can not store Go pointer after return) 175 | // but it should be ok for we have kept reference in the global object 176 | ret := trie.NewZkTrieMemoryDb() 177 | 178 | return C.uintptr_t(cgo.NewHandle(ret)) 179 | } 180 | 181 | func freeObject(p 
C.uintptr_t) { 182 | h := cgo.Handle(p) 183 | h.Delete() 184 | } 185 | 186 | // free created memory db 187 | // 188 | //export FreeMemoryDb 189 | func FreeMemoryDb(p C.uintptr_t) { freeObject(p) } 190 | 191 | // free created trie 192 | // 193 | //export FreeZkTrie 194 | func FreeZkTrie(p C.uintptr_t) { freeObject(p) } 195 | 196 | // free buffers being returned, like error strings or trie value 197 | // 198 | //export FreeBuffer 199 | func FreeBuffer(p unsafe.Pointer) { 200 | C.free(p) 201 | } 202 | 203 | // flush db with encoded trie-node bytes 204 | // used to initialize the database, in a thread-unsafe fashion 205 | // 206 | //export InitDbByNode 207 | func InitDbByNode(pDb C.uintptr_t, data *C.uchar, sz C.int) *C.char { 208 | h := cgo.Handle(pDb) 209 | db := h.Value().(*trie.Database) 210 | 211 | bt := C.GoBytes(unsafe.Pointer(data), sz) 212 | n, err := trie.DecodeSMTProof(bt) 213 | if err != nil { 214 | return C.CString(err.Error()) 215 | } else if n == nil { 216 | //skip magic string 217 | return nil 218 | } 219 | 220 | hash, err := n.NodeHash() 221 | if err != nil { 222 | return C.CString(err.Error()) 223 | } 224 | 225 | db.Init(hash[:], n.CanonicalValue()) 226 | return nil 227 | } 228 | 229 | // the input root must be 32bytes (or more, but only the first 32bytes are read) 230 | // 231 | //export NewZkTrie 232 | func NewZkTrie(root_c *C.uchar, pDb C.uintptr_t) C.uintptr_t { 233 | h := cgo.Handle(pDb) 234 | db := h.Value().(*trie.Database) 235 | root := C.GoBytes(unsafe.Pointer(root_c), 32) 236 | 237 | zktrie, err := trie.NewZkTrie(*zkt.NewByte32FromBytes(root), db) 238 | if err != nil { 239 | return 0 240 | } 241 | 242 | return C.uintptr_t(cgo.NewHandle(zktrie)) 243 | } 244 | 245 | // currently it is the caller's responsibility to distinguish whether 246 | // the returned buffer is a byte32 or encoded account data (4x32bytes fields for original account 247 | // or 5x32bytes fields for 'dual-codehash' extended account) 248 | // 249 | //export TrieGet 250 | func
TrieGet(p C.uintptr_t, key_c *C.uchar, key_sz C.int) unsafe.Pointer { 251 | h := cgo.Handle(p) 252 | tr := h.Value().(*trie.ZkTrie) 253 | key := C.GoBytes(unsafe.Pointer(key_c), key_sz) 254 | 255 | v, err := tr.TryGet(key) 256 | if v == nil || err != nil { 257 | return nil 258 | } 259 | //sanity check 260 | if val_sz := len(v); val_sz != 32 && val_sz != 32*4 && val_sz != 32*5 { 261 | // unexpected val size which is to be recognized by caller, so just filter it 262 | return nil 263 | } 264 | 265 | return C.CBytes(v) 266 | } 267 | 268 | // variant of TrieGet that specifies the expected value size for safety; if the actual value 269 | // size does not match the expected value size, it returns nil instead of leading to undefined 270 | // behavior. 271 | // 272 | //export TrieGetSize 273 | func TrieGetSize(p C.uintptr_t, key_c *C.uchar, key_sz C.int, val_sz C.int) unsafe.Pointer { 274 | h := cgo.Handle(p) 275 | tr := h.Value().(*trie.ZkTrie) 276 | key := C.GoBytes(unsafe.Pointer(key_c), key_sz) 277 | 278 | v, err := tr.TryGet(key) 279 | if v == nil || err != nil { 280 | return nil 281 | } 282 | 283 | // safety check 284 | if len(v) != int(val_sz) { 285 | return nil 286 | } 287 | 288 | return C.CBytes(v) 289 | } 290 | 291 | // update only accept encoded buffer, and flag is derived automatically from buffer size (account data or store val) 292 | // 293 | //export TrieUpdate 294 | func TrieUpdate(p C.uintptr_t, key_c *C.uchar, key_sz C.int, val_c *C.uchar, val_sz C.int) *C.char { 295 | 296 | if val_sz != 32 && val_sz != 128 && val_sz != 160 { 297 | return C.CString("unexpected buffer type") 298 | } 299 | 300 | var vFlag uint32 301 | if val_sz == 160 { 302 | vFlag = 8 303 | } else if val_sz == 128 { 304 | vFlag = 4 305 | } else { 306 | vFlag = 1 307 | } 308 | 309 | h := cgo.Handle(p) 310 | tr := h.Value().(*trie.ZkTrie) 311 | key := C.GoBytes(unsafe.Pointer(key_c), key_sz) 312 | var vals []zkt.Byte32 313 | start_ptr := uintptr(unsafe.Pointer(val_c)) 314 | for i := 0; i < 
int(val_sz); i += 32 { 315 | vals = append(vals, *zkt.NewByte32FromBytes(C.GoBytes(unsafe.Pointer(start_ptr), 32))) 316 | start_ptr += 32 317 | } 318 | 319 | err := tr.TryUpdate(key, vFlag, vals) 320 | if err != nil { 321 | return C.CString(err.Error()) 322 | } 323 | return nil 324 | } 325 | 326 | // delete leaf, silently omit any error 327 | // 328 | //export TrieDelete 329 | func TrieDelete(p C.uintptr_t, key_c *C.uchar, key_sz C.int) { 330 | h := cgo.Handle(p) 331 | tr := h.Value().(*trie.ZkTrie) 332 | key := C.GoBytes(unsafe.Pointer(key_c), key_sz) 333 | tr.TryDelete(key) 334 | } 335 | 336 | // output prove, only the val part is output for callback 337 | // 338 | //export TrieProve 339 | func TrieProve(p C.uintptr_t, key_c *C.uchar, key_sz C.int, callback unsafe.Pointer, cb_param unsafe.Pointer) *C.char { 340 | h := cgo.Handle(p) 341 | tr := h.Value().(*trie.ZkTrie) 342 | key := C.GoBytes(unsafe.Pointer(key_c), key_sz) 343 | s_key, err := zkt.ToSecureKeyBytes(key) 344 | if err != nil { 345 | return C.CString(err.Error()) 346 | } 347 | 348 | err = tr.Prove(s_key.Bytes(), 0, func(n *trie.Node) error { 349 | 350 | dt := n.Value() 351 | 352 | C.bridge_prove_write( 353 | C.proveWriteF(callback), 354 | nil, //do not need to prove node key 355 | (*C.uchar)(&dt[0]), 356 | C.int(len(dt)), 357 | cb_param, 358 | ) 359 | 360 | return nil 361 | }) 362 | if err != nil { 363 | return C.CString(err.Error()) 364 | } 365 | 366 | tailingLine := trie.ProofMagicBytes() 367 | C.bridge_prove_write( 368 | C.proveWriteF(callback), 369 | nil, //do not need to prove node key 370 | (*C.uchar)(&tailingLine[0]), 371 | C.int(len(tailingLine)), 372 | cb_param, 373 | ) 374 | 375 | return nil 376 | } 377 | 378 | // obtain the hash 379 | // 380 | //export TrieRoot 381 | func TrieRoot(p C.uintptr_t) unsafe.Pointer { 382 | h := cgo.Handle(p) 383 | tr := h.Value().(*trie.ZkTrie) 384 | return C.CBytes(tr.Hash()) 385 | } 386 | 387 | func main() {} 388 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /trie/zk_trie_node.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "math/big" 7 | "reflect" 8 | "slices" 9 | "unsafe" 10 | 11 | zkt "github.com/scroll-tech/zktrie/types" 12 | ) 13 | 14 | // NodeType defines the type of node in the MT. 15 | type NodeType byte 16 | 17 | const ( 18 | // NodeTypeParent indicates the type of parent Node that has children. 19 | NodeTypeParent NodeType = 0 20 | // NodeTypeLeaf indicates the type of a leaf Node that contains a key & 21 | // value. 22 | NodeTypeLeaf NodeType = 1 23 | // NodeTypeEmpty indicates the type of an empty Node. 
24 | NodeTypeEmpty NodeType = 2 25 | 26 | // DBEntryTypeRoot indicates the type of a DB entry that indicates the 27 | // current Root of a MerkleTree 28 | DBEntryTypeRoot NodeType = 3 29 | 30 | NodeTypeLeaf_New NodeType = 4 31 | NodeTypeEmpty_New NodeType = 5 32 | // branch node for both child are terminal nodes 33 | NodeTypeBranch_0 NodeType = 6 34 | // branch node for left child is terminal node and right child is branch 35 | NodeTypeBranch_1 NodeType = 7 36 | // branch node for left child is branch node and right child is terminal 37 | NodeTypeBranch_2 NodeType = 8 38 | // branch node for both child are branch nodes 39 | NodeTypeBranch_3 NodeType = 9 40 | ) 41 | 42 | // DeduceUploadType deduce a new branch type from current branch when one of its child become non-terminal 43 | func (n NodeType) DeduceUpgradeType(goRight bool) NodeType { 44 | if goRight { 45 | switch n { 46 | case NodeTypeBranch_0: 47 | return NodeTypeBranch_1 48 | case NodeTypeBranch_1: 49 | return n 50 | case NodeTypeBranch_2, NodeTypeBranch_3: 51 | return NodeTypeBranch_3 52 | } 53 | } else { 54 | switch n { 55 | case NodeTypeBranch_0: 56 | return NodeTypeBranch_2 57 | case NodeTypeBranch_1, NodeTypeBranch_3: 58 | return NodeTypeBranch_3 59 | case NodeTypeBranch_2: 60 | return n 61 | } 62 | } 63 | 64 | panic(fmt.Errorf("invalid NodeType: %d", n)) 65 | } 66 | 67 | // DeduceDowngradeType deduce a new branch type from current branch when one of its child become terminal 68 | func (n NodeType) DeduceDowngradeType(atRight bool) NodeType { 69 | if atRight { 70 | switch n { 71 | case NodeTypeBranch_1: 72 | return NodeTypeBranch_0 73 | case NodeTypeBranch_3: 74 | return NodeTypeBranch_2 75 | case NodeTypeBranch_0, NodeTypeBranch_2: 76 | panic(fmt.Errorf("can not downgrade a node with terminal child (%d)", n)) 77 | } 78 | } else { 79 | switch n { 80 | case NodeTypeBranch_3: 81 | return NodeTypeBranch_1 82 | case NodeTypeBranch_2: 83 | return NodeTypeBranch_0 84 | case NodeTypeBranch_0, 
NodeTypeBranch_1: 85 | panic(fmt.Errorf("can not downgrade a node with terminal child (%d)", n)) 86 | } 87 | } 88 | panic(fmt.Errorf("invalid NodeType: %d", n)) 89 | } 90 | 91 | // Node is the struct that represents a node in the MT. The node should not be 92 | // modified after creation because the cached key won't be updated. 93 | type Node struct { 94 | // Type is the type of node in the tree. 95 | Type NodeType 96 | // ChildL is the node hash of the left child of a parent node. 97 | ChildL *zkt.Hash 98 | // ChildR is the node hash of the right child of a parent node. 99 | ChildR *zkt.Hash 100 | // NodeKey is the node's key stored in a leaf node. 101 | NodeKey *zkt.Hash 102 | // ValuePreimage can store at most 256 byte32 as fields (represnted by BIG-ENDIAN integer) 103 | // and the first 24 can be compressed (each bytes32 consider as 2 fields), in hashing the compressed 104 | // elemments would be calculated first 105 | ValuePreimage []zkt.Byte32 106 | // CompressedFlags use each bit for indicating the compressed flag for the first 24 fields 107 | CompressedFlags uint32 108 | // nodeHash is the cache of the hash of the node to avoid recalculating 109 | nodeHash *zkt.Hash 110 | // valueHash is the cache of the hash of valuePreimage to avoid recalculating, only valid for leaf node 111 | valueHash *zkt.Hash 112 | // KeyPreimage is the original key value that derives the NodeKey, kept here only for proof 113 | KeyPreimage *zkt.Byte32 114 | } 115 | 116 | // NewLeafNode creates a new leaf node. 117 | func NewLeafNode(k *zkt.Hash, valueFlags uint32, valuePreimage []zkt.Byte32) *Node { 118 | return &Node{Type: NodeTypeLeaf_New, NodeKey: k, CompressedFlags: valueFlags, ValuePreimage: valuePreimage} 119 | } 120 | 121 | // NewParentNode creates a new parent node. 122 | func NewParentNode(ntype NodeType, childL *zkt.Hash, childR *zkt.Hash) *Node { 123 | return &Node{Type: ntype, ChildL: childL, ChildR: childR} 124 | } 125 | 126 | // NewEmptyNode creates a new empty node. 
127 | func NewEmptyNode() *Node { 128 | return &Node{Type: NodeTypeEmpty_New} 129 | } 130 | 131 | // NewNodeFromBytes creates a new node by parsing the input []byte. 132 | func NewNodeFromBytes(b []byte) (*Node, error) { 133 | var n Node 134 | if err := n.SetBytes(b); err != nil { 135 | return nil, err 136 | } 137 | return &n, nil 138 | } 139 | 140 | // LeafHash computes the key of a leaf node given the hIndex and hValue of the 141 | // entry of the leaf. 142 | func LeafHash(k, v *zkt.Hash) (*zkt.Hash, error) { 143 | return zkt.HashElemsWithDomain(big.NewInt(int64(NodeTypeLeaf_New)), k.BigInt(), v.BigInt()) 144 | } 145 | 146 | func (n *Node) SetBytes(b []byte) error { 147 | if len(b) < 1 { 148 | return ErrNodeBytesBadSize 149 | } 150 | nType := NodeType(b[0]) 151 | b = b[1:] 152 | switch nType { 153 | case NodeTypeParent, NodeTypeBranch_0, 154 | NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: 155 | if len(b) != 2*zkt.HashByteLen { 156 | return ErrNodeBytesBadSize 157 | } 158 | 159 | childL := n.ChildL 160 | childR := n.ChildR 161 | 162 | if childL == nil { 163 | childL = zkt.NewHashFromBytes(b[:zkt.HashByteLen]) 164 | } else { 165 | childL.SetBytes(b[:zkt.HashByteLen]) 166 | } 167 | 168 | if childR == nil { 169 | childR = zkt.NewHashFromBytes(b[zkt.HashByteLen : zkt.HashByteLen*2]) 170 | } else { 171 | childR.SetBytes(b[zkt.HashByteLen : zkt.HashByteLen*2]) 172 | } 173 | 174 | *n = Node{ 175 | Type: nType, 176 | ChildL: childL, 177 | ChildR: childR, 178 | } 179 | case NodeTypeLeaf, NodeTypeLeaf_New: 180 | if len(b) < zkt.HashByteLen+4 { 181 | return ErrNodeBytesBadSize 182 | } 183 | nodeKey := zkt.NewHashFromBytes(b[0:zkt.HashByteLen]) 184 | mark := binary.LittleEndian.Uint32(b[zkt.HashByteLen : zkt.HashByteLen+4]) 185 | preimageLen := int(mark & 255) 186 | compressedFlags := mark >> 8 187 | valuePreimage := slices.Grow(n.ValuePreimage[0:], preimageLen) 188 | curPos := zkt.HashByteLen + 4 189 | if len(b) < curPos+preimageLen*32+1 { 190 | return 
ErrNodeBytesBadSize 191 | } 192 | for i := 0; i < preimageLen; i++ { 193 | var byte32 zkt.Byte32 194 | copy(byte32[:], b[i*32+curPos:(i+1)*32+curPos]) 195 | valuePreimage = append(valuePreimage, byte32) 196 | } 197 | curPos += preimageLen * 32 198 | preImageSize := int(b[curPos]) 199 | curPos += 1 200 | 201 | var keyPreimage *zkt.Byte32 202 | if preImageSize != 0 { 203 | if len(b) < curPos+preImageSize { 204 | return ErrNodeBytesBadSize 205 | } 206 | 207 | keyPreimage = n.KeyPreimage 208 | if keyPreimage == nil { 209 | keyPreimage = new(zkt.Byte32) 210 | } 211 | copy(keyPreimage[:], b[curPos:curPos+preImageSize]) 212 | } 213 | 214 | *n = Node{ 215 | Type: nType, 216 | NodeKey: nodeKey, 217 | CompressedFlags: compressedFlags, 218 | ValuePreimage: valuePreimage, 219 | KeyPreimage: keyPreimage, 220 | } 221 | case NodeTypeEmpty, NodeTypeEmpty_New: 222 | *n = Node{Type: nType} 223 | default: 224 | return ErrInvalidNodeFound 225 | } 226 | return nil 227 | } 228 | 229 | // IsTerminal returns if the node is 'terminated', i.e. empty or leaf node 230 | func (n *Node) IsTerminal() bool { 231 | switch n.Type { 232 | case NodeTypeEmpty_New, NodeTypeLeaf_New: 233 | return true 234 | case NodeTypeBranch_0, NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: 235 | return false 236 | case NodeTypeEmpty, NodeTypeLeaf, NodeTypeParent: 237 | panic("encounter deprecated node types") 238 | default: 239 | panic(fmt.Errorf("encounter unknown node types %d", n.Type)) 240 | } 241 | 242 | } 243 | 244 | // NodeHash computes the hash digest of the node by hashing the content in a 245 | // specific way for each type of node. This key is used as the hash of the 246 | // Merkle tree for each node. 247 | func (n *Node) NodeHash() (*zkt.Hash, error) { 248 | if n.nodeHash == nil { // Cache the key to avoid repeated hash computations. 249 | // NOTE: We are not using the type to calculate the hash! 
250 | switch n.Type { 251 | case NodeTypeBranch_0, 252 | NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: // H(ChildL || ChildR) 253 | var err error 254 | n.nodeHash, err = zkt.HashElemsWithDomain(big.NewInt(int64(n.Type)), 255 | n.ChildL.BigInt(), n.ChildR.BigInt()) 256 | if err != nil { 257 | return nil, err 258 | } 259 | case NodeTypeLeaf_New: 260 | var err error 261 | n.valueHash, err = zkt.HandlingElemsAndByte32(n.CompressedFlags, n.ValuePreimage) 262 | if err != nil { 263 | return nil, err 264 | } 265 | 266 | n.nodeHash, err = LeafHash(n.NodeKey, n.valueHash) 267 | if err != nil { 268 | return nil, err 269 | } 270 | 271 | case NodeTypeEmpty_New: // Zero 272 | n.nodeHash = &zkt.HashZero 273 | case NodeTypeEmpty, NodeTypeLeaf, NodeTypeParent: 274 | panic("encounter deprecated node types") 275 | default: 276 | n.nodeHash = &zkt.HashZero 277 | } 278 | } 279 | return n.nodeHash, nil 280 | } 281 | 282 | // ValueHash computes the hash digest of the value stored in the leaf node. For 283 | // other node types, it returns the zero hash. 
284 | func (n *Node) ValueHash() (*zkt.Hash, error) { 285 | if n.Type != NodeTypeLeaf_New { 286 | return &zkt.HashZero, nil 287 | } 288 | if _, err := n.NodeHash(); err != nil { 289 | return nil, err 290 | } 291 | return n.valueHash, nil 292 | } 293 | 294 | // Data returns the wrapped data inside LeafNode and cast them into bytes 295 | // for other node type it just return nil 296 | func (n *Node) Data() []byte { 297 | switch n.Type { 298 | case NodeTypeLeaf_New: 299 | var data []byte 300 | hdata := (*reflect.SliceHeader)(unsafe.Pointer(&data)) 301 | //TODO: uintptr(reflect.ValueOf(n.ValuePreimage).UnsafePointer()) should be more elegant but only available until go 1.18 302 | hdata.Data = uintptr(unsafe.Pointer(&n.ValuePreimage[0])) 303 | hdata.Len = 32 * len(n.ValuePreimage) 304 | hdata.Cap = hdata.Len 305 | return data 306 | default: 307 | return nil 308 | } 309 | } 310 | 311 | // CanonicalValue returns the byte form of a node required to be persisted, and strip unnecessary fields 312 | // from the encoding (current only KeyPreimage for Leaf node) to keep a minimum size for content being 313 | // stored in backend storage 314 | func (n *Node) CanonicalValue() []byte { 315 | switch n.Type { 316 | case NodeTypeBranch_0, NodeTypeBranch_1, NodeTypeBranch_2, NodeTypeBranch_3: // {Type || ChildL || ChildR} 317 | bytes := []byte{byte(n.Type)} 318 | bytes = append(bytes, n.ChildL.Bytes()...) 319 | bytes = append(bytes, n.ChildR.Bytes()...) 320 | return bytes 321 | case NodeTypeLeaf_New: // {Type || Data...} 322 | bytes := []byte{byte(n.Type)} 323 | bytes = append(bytes, n.NodeKey.Bytes()...) 324 | tmp := make([]byte, 4) 325 | compressedFlag := (n.CompressedFlags << 8) + uint32(len(n.ValuePreimage)) 326 | binary.LittleEndian.PutUint32(tmp, compressedFlag) 327 | bytes = append(bytes, tmp...) 328 | for _, elm := range n.ValuePreimage { 329 | bytes = append(bytes, elm[:]...) 
330 | } 331 | bytes = append(bytes, 0) 332 | return bytes 333 | case NodeTypeEmpty_New: // { Type } 334 | return []byte{byte(n.Type)} 335 | case NodeTypeEmpty, NodeTypeLeaf, NodeTypeParent: 336 | panic("encounter deprecated node types") 337 | default: 338 | return []byte{} 339 | } 340 | } 341 | 342 | // Value returns the encoded bytes of a node, include all information of it 343 | func (n *Node) Value() []byte { 344 | outBytes := n.CanonicalValue() 345 | switch n.Type { 346 | case NodeTypeLeaf_New: // {Type || Data...} 347 | if n.KeyPreimage != nil { 348 | outBytes[len(outBytes)-1] = byte(len(n.KeyPreimage)) 349 | outBytes = append(outBytes, n.KeyPreimage[:]...) 350 | } 351 | } 352 | 353 | return outBytes 354 | } 355 | 356 | // String outputs a string representation of a node (different for each type). 357 | func (n *Node) String() string { 358 | switch n.Type { 359 | // {Type || ChildL || ChildR} 360 | case NodeTypeBranch_0: 361 | return fmt.Sprintf("Parent L(t):%s R(t):%s", n.ChildL, n.ChildR) 362 | case NodeTypeBranch_1: 363 | return fmt.Sprintf("Parent L(t):%s R:%s", n.ChildL, n.ChildR) 364 | case NodeTypeBranch_2: 365 | return fmt.Sprintf("Parent L:%s R(t):%s", n.ChildL, n.ChildR) 366 | case NodeTypeBranch_3: 367 | return fmt.Sprintf("Parent L:%s R:%s", n.ChildL, n.ChildR) 368 | case NodeTypeLeaf_New: // {Type || Data...} 369 | return fmt.Sprintf("Leaf I:%v Items: %d, First:%v", n.NodeKey, len(n.ValuePreimage), n.ValuePreimage[0]) 370 | case NodeTypeEmpty_New: // {} 371 | return "Empty" 372 | case NodeTypeEmpty, NodeTypeLeaf, NodeTypeParent: 373 | return "deprecated Node" 374 | default: 375 | return "Invalid Node" 376 | } 377 | } 378 | 379 | // Copy creates a new Node instance from the given node 380 | func (n *Node) Copy() *Node { 381 | newNode, err := NewNodeFromBytes(n.Value()) 382 | if err != nil { 383 | panic("failed to copy trie node") 384 | } 385 | return newNode 386 | } 387 | 
-------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | mod constants; 2 | 3 | pub use constants::*; 4 | pub type SimpleHashSchemeFn = fn(&[u8; 32], &[u8; 32], &[u8; 32]) -> Option<[u8; 32]>; 5 | use std::sync::OnceLock; 6 | static HASHSCHEME: OnceLock = OnceLock::new(); 7 | 8 | #[cfg(not(feature = "rs_zktrie"))] 9 | pub mod go_lib; 10 | #[cfg(not(feature = "rs_zktrie"))] 11 | pub use go_lib::*; 12 | #[cfg(not(feature = "rs_zktrie"))] 13 | pub fn init_hash_scheme_simple(f: SimpleHashSchemeFn) { 14 | HASHSCHEME.set(f).unwrap_or_default(); 15 | go_lib::init_hash_scheme(c_hash_scheme_adapter); 16 | } 17 | 18 | #[cfg(feature = "rs_zktrie")] 19 | pub use rs_lib::*; 20 | #[cfg(feature = "rs_zktrie")] 21 | pub mod rs_lib; 22 | #[cfg(feature = "rs_zktrie")] 23 | pub fn init_hash_scheme_simple(f: SimpleHashSchemeFn) { 24 | HASHSCHEME.set(f).unwrap_or_default() 25 | } 26 | 27 | #[allow(dead_code)] 28 | extern "C" fn c_hash_scheme_adapter( 29 | a: *const u8, 30 | b: *const u8, 31 | domain: *const u8, 32 | out: *mut u8, 33 | ) -> *const i8 { 34 | use std::slice; 35 | let a: [u8; 32] = 36 | TryFrom::try_from(unsafe { slice::from_raw_parts(a, 32) }).expect("length specified"); 37 | let b: [u8; 32] = 38 | TryFrom::try_from(unsafe { slice::from_raw_parts(b, 32) }).expect("length specified"); 39 | let domain: [u8; 32] = 40 | TryFrom::try_from(unsafe { slice::from_raw_parts(domain, 32) }).expect("length specified"); 41 | let out = unsafe { slice::from_raw_parts_mut(out, 32) }; 42 | 43 | let h = HASHSCHEME 44 | .get() 45 | .expect("if it is called hash scheme must be initied")(&a, &b, &domain); 46 | 47 | static HASH_OUT_ERROR: &str = "hash scheme can not output"; 48 | if let Some(h) = h { 49 | out.copy_from_slice(&h); 50 | std::ptr::null() 51 | } else { 52 | HASH_OUT_ERROR.as_ptr().cast() 53 | } 54 | } 55 | 56 | #[cfg(test)] 57 | mod tests { 58 | 59 | use 
super::*; 60 | use halo2_proofs::halo2curves::bn256::Fr; 61 | use halo2_proofs::halo2curves::group::ff::PrimeField; 62 | use poseidon_circuit::Hashable; 63 | 64 | fn poseidon_hash_scheme(a: &[u8; 32], b: &[u8; 32], domain: &[u8; 32]) -> Option<[u8; 32]> { 65 | let fa = Fr::from_bytes(a); 66 | let fa = if fa.is_some().into() { 67 | fa.unwrap() 68 | } else { 69 | return None; 70 | }; 71 | let fb = Fr::from_bytes(b); 72 | let fb = if fb.is_some().into() { 73 | fb.unwrap() 74 | } else { 75 | return None; 76 | }; 77 | let fdomain = Fr::from_bytes(domain); 78 | let fdomain = if fdomain.is_some().into() { 79 | fdomain.unwrap() 80 | } else { 81 | return None; 82 | }; 83 | Some(Fr::hash_with_domain([fa, fb], fdomain).to_repr()) 84 | } 85 | 86 | #[cfg(not(feature = "rs_zktrie"))] 87 | #[link(name = "zktrie")] 88 | extern "C" { 89 | fn TestHashScheme(); 90 | } 91 | 92 | #[cfg(not(feature = "rs_zktrie"))] 93 | #[test] 94 | fn hash_works() { 95 | // check consistency between go poseidon and rust poseidon 96 | init_hash_scheme_simple(poseidon_hash_scheme); 97 | unsafe { 98 | TestHashScheme(); 99 | } 100 | } 101 | 102 | #[allow(dead_code)] 103 | static EXAMPLE : [&str;41] = [ 104 | "0x09218bcaf094949451aaea2273a4092c7116839ad69df7597df06c7bf741a9477f01020df75837d8a760bfb941f3465f63812b205ac7e1fff5d310a2a3295e60c8", 105 | "0x0913e957fbc8585b40175129d3547a76b9fc3a1c3b16a6ca4de468879bb08fcbb6104a71f54260a0430906c4a0c3cc5eb459dd132b637c944ea92b769a98dba762", 106 | "0x092c2eae4f5273c398709da3e317c86a3a817008c98269bf2766405259c488306628c0c92eb1f16fc59b8b99e0a8abee3f88afb477c4d36be3571d537b076e0f83", 107 | "0x0800100f66e758c81427817699eeed67308bc9a7ee8054f2cbd463b7bf252610af233b07e4b000250359a56ef55485036e6d4dbca7c71bf82812790ac3f4a5238e", 108 | "0x08088158f4dfd26b06688c646a453c1b52710139a064b0394b47a0693c2bee46a4159d39c4d2776406bca63dfba405861d669f6220a087ad4b204e1cca52c7be5f", 109 | 
"0x062b2d9de4b02c2bab78264918866524e44e6efdc24bf0be2d4a8aa6f9b232a7781cb2c64090d483dbe3795eea941f808f7eda30de68190976a36f856f2a824bdd", 110 | "0x040a30b5d71d70991519167c5314323d2d69b02b7c501070ec7f34f4f24d89b5860508000000000000000000000000000000000000000000000000119b000000000000000100000000000000000000000000000000000000000000000000000000000000001a99ce3a54bcc9f4d7f61c67286f0ffc6a5ddab4a94c1f6fc6741a5ef196145b16fc66d15010e6213d2a009f57ed8e847717ea0b83eeb37cd322e9ad1b018a3e0d85b09a93d5ed99a87d27dcf6d50e4459d16bb694e70f89eefcb745ea1c85e7200c64e6f8d51bb1ae0e4ad62b9a1b996e1b2675d3000000000000000000000000", 111 | "0x09218bcaf094949451aaea2273a4092c7116839ad69df7597df06c7bf741a9477f01020df75837d8a760bfb941f3465f63812b205ac7e1fff5d310a2a3295e60c8", 112 | "0x0913e957fbc8585b40175129d3547a76b9fc3a1c3b16a6ca4de468879bb08fcbb6104a71f54260a0430906c4a0c3cc5eb459dd132b637c944ea92b769a98dba762", 113 | "0x092c2eae4f5273c398709da3e317c86a3a817008c98269bf2766405259c488306628c0c92eb1f16fc59b8b99e0a8abee3f88afb477c4d36be3571d537b076e0f83", 114 | "0x0800100f66e758c81427817699eeed67308bc9a7ee8054f2cbd463b7bf252610af233b07e4b000250359a56ef55485036e6d4dbca7c71bf82812790ac3f4a5238e", 115 | "0x08088158f4dfd26b06688c646a453c1b52710139a064b0394b47a0693c2bee46a4159d39c4d2776406bca63dfba405861d669f6220a087ad4b204e1cca52c7be5f", 116 | "0x062b2d9de4b02c2bab78264918866524e44e6efdc24bf0be2d4a8aa6f9b232a7781cb2c64090d483dbe3795eea941f808f7eda30de68190976a36f856f2a824bdd", 117 | "0x041822829dca763241624d1f8dd4cf59018fc5f69931d579f8e8a4c3addd6633e605080000000000000000000000000000000000000000000000000000000000000000001101ffffffffffffffffffffffffffffffffffffffffffd5a5fa65e20465da88bf0000000000000000000000000000000000000000000000000000000000000000c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a4702098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b64864201c5a77d9fa7ef466951b2f01f724bca3a5820b63000000000000000000000000", 118 | 
"0x092df5ac113a2c9174aea818559d63df596efdd925bcc14028e901ba605dc030e101ebd1fa8391b5fa5b805444d74896d14cfac9519260e94ab9ef25ee4461f737", 119 | "0x070000000000000000000000000000000000000000000000000000000000000000104736bbf00e9ab6f74b9e366c28b4f21c4a273cbd1f7e3dff3d68d4dbfe6d76", 120 | "0x060a9837791a40c9befa2ebdbbe99fdb8d8d7a7bb9fe3a4c14f1581a76809bf2b21323d7866288f9d670672215af41d0d610303b7e5f6ba97b5e54080960974580", 121 | "0x041aed9d52b6e3489c0ea97983a6dc4fbad57507090547dc83b8830c2ddb88577701010000000000000000000000001c5a77d9fa7ef466951b2f01f724bca3a5820b630012200000000000000000000000000000000000000000000000000000000000000005", 122 | "0x09218bcaf094949451aaea2273a4092c7116839ad69df7597df06c7bf741a9477f01020df75837d8a760bfb941f3465f63812b205ac7e1fff5d310a2a3295e60c8", 123 | "0x0913e957fbc8585b40175129d3547a76b9fc3a1c3b16a6ca4de468879bb08fcbb6104a71f54260a0430906c4a0c3cc5eb459dd132b637c944ea92b769a98dba762", 124 | "0x092c2eae4f5273c398709da3e317c86a3a817008c98269bf2766405259c488306628c0c92eb1f16fc59b8b99e0a8abee3f88afb477c4d36be3571d537b076e0f83", 125 | "0x0800100f66e758c81427817699eeed67308bc9a7ee8054f2cbd463b7bf252610af233b07e4b000250359a56ef55485036e6d4dbca7c71bf82812790ac3f4a5238e", 126 | "0x04113060bdeae1240b8b2f272e35848ac6b0c401bdc3a9ec20186da2a6a9d4607e05080000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000152d02c7e14af60000000000000000000000000000000000000000000000000000000000000000000000c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a4702098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b6486420c0c4c8baea3f6acb49b6e1fb9e2adeceeacb0ca2000000000000000000000000", 127 | "0x09218bcaf094949451aaea2273a4092c7116839ad69df7597df06c7bf741a9477f01020df75837d8a760bfb941f3465f63812b205ac7e1fff5d310a2a3295e60c8", 128 | "0x0908e49a63f6ecd17ace446bd1e684b6cdd29f31faae528a9f058aefa76551068228eeef32a81cf40e295ad9c1de7e53a5180e6f1727521a209e0e2913250941fe", 129 | 
"0x082e4e1a6f0a26fe354020a569325d45d9b63e11f769a79620da6c84c053d88733252f02bd2a45416d5076e363e6172f941774eb184ce75ba0803264362958e2ef", 130 | "0x08020cc627de460d025af928a8a847b8d7475ff44bcaadce1667cfab122c8f3ea6301dc3e787d41a3db0710353073f18eaebab31ac37d69e25983caf72f6c08178", 131 | "0x04139a6815e4d1fb05c969e6a8036aa5cc06b88751d713326d681bd90448ea64c905080000000000000000000000000000000000000000000000000874000000000000000000000000000000000000000000000000000000000000000000000000000000002c3c54d9c8b2d411ccd6458eaea5c644392b097de2ee416f5c923f3b01c7b8b80fabb5b0f58ec2922e2969f4dadb6d1395b49ecd40feff93e01212ae848355d410e77cae1c507f967948c6cd114e74ed65f662e365c7d6993e97f78ce898252800", 132 | "0x09218bcaf094949451aaea2273a4092c7116839ad69df7597df06c7bf741a9477f01020df75837d8a760bfb941f3465f63812b205ac7e1fff5d310a2a3295e60c8", 133 | "0x0908e49a63f6ecd17ace446bd1e684b6cdd29f31faae528a9f058aefa76551068228eeef32a81cf40e295ad9c1de7e53a5180e6f1727521a209e0e2913250941fe", 134 | "0x082e4e1a6f0a26fe354020a569325d45d9b63e11f769a79620da6c84c053d88733252f02bd2a45416d5076e363e6172f941774eb184ce75ba0803264362958e2ef", 135 | "0x08020cc627de460d025af928a8a847b8d7475ff44bcaadce1667cfab122c8f3ea6301dc3e787d41a3db0710353073f18eaebab31ac37d69e25983caf72f6c08178", 136 | "0x0700000000000000000000000000000000000000000000000000000000000000000d652d6e2cc697970d24bfec9c84b720481a080eeb3a039277d5dfa90c634a02", 137 | "0x060b262fa2cc2bcdf4083a6b4b45956ebcf85003d697780351a24398b7df39985a096c33b369382285822d8f0acf8097ca6f095334750a42f869e513c8ec3779a7", 138 | 
"0x04287b801ba8950befe82147f88e71eff6b85eb921845d754c9c2a165a4ec86791050800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000a5b65ae2577410000000000000000000000000000000000000000000000000000000000000000c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a4702098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b64864205300000000000000000000000000000000000005000000000000000000000000", 139 | "0x092e757f7cfb7c618a89bef428d6f043efb7913959793a525d3e6dc2265aa2e0362c9e569b67ba72d58e6f56454481607aee49523e3da63072e2cb4e0b37453e8a", 140 | "0x091868286870969b61281e49af8860d1bc74b558a4014da7433cb7e99a88aa56bc2d58daf89ed4b660018c081b11785924bd129ce58535350bd66c23eddf591e2b", 141 | "0x0900d86fc3cea9f88796671391157d8433f92be74473b01876ef9b6a75632c225d159af6801572801dfd6e17b00de85fcf0dae392c520440b763ecfc3936970af5", 142 | "0x0911b101680f5f11b4cccdcde4115c3f8e8af523fa76dd52de98c468cc0502dd642fd7d2a38e36d5a616485e21c93edb5798618e0e0e2003b979d05a94b29b2b29", 143 | "0x070000000000000000000000000000000000000000000000000000000000000000240aaaaee47745183d4820fe7384efe4a3fb93461aecea38b0a7d7bee64784a5", 144 | "0x05", 145 | ]; 146 | 147 | #[test] 148 | fn node_parse() { 149 | init_hash_scheme_simple(poseidon_hash_scheme); 150 | 151 | let nd = ZkTrieNode::parse(&hex::decode("04139a6815e4d1fb05c969e6a8036aa5cc06b88751d713326d681bd90448ea64c905080000000000000000000000000000000000000000000000000874000000000000000000000000000000000000000000000000000000000000000000000000000000002c3c54d9c8b2d411ccd6458eaea5c644392b097de2ee416f5c923f3b01c7b8b80fabb5b0f58ec2922e2969f4dadb6d1395b49ecd40feff93e01212ae848355d410e77cae1c507f967948c6cd114e74ed65f662e365c7d6993e97f78ce898252800").unwrap()); 152 | let nd = nd.unwrap(); 153 | assert_eq!( 154 | hex::encode(nd.node_hash()), 155 | "301dc3e787d41a3db0710353073f18eaebab31ac37d69e25983caf72f6c08178" 156 | ); 157 | let nd = 
ZkTrieNode::parse(&hex::decode("041822829dca763241624d1f8dd4cf59018fc5f69931d579f8e8a4c3addd6633e605080000000000000000000000000000000000000000000000000000000000000000003901ffffffffffffffffffffffffffffffffffffffffffc078f7396622d90018d50000000000000000000000000000000000000000000000000000000000000000c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a4702098f5fb9e239eab3ceac3f27b81e481dc3124d55ffed523a839ee8446b64864201c5a77d9fa7ef466951b2f01f724bca3a5820b63000000000000000000000000").unwrap()); 158 | let nd = nd.unwrap(); 159 | assert_eq!( 160 | hex::encode(nd.node_hash()), 161 | "18a38101a2886bca1262d02a7355d693b7937833a0eb729a5612cdb9a9817fc2" 162 | ); 163 | } 164 | 165 | #[test] 166 | fn trie_works() { 167 | use std::rc::Rc; 168 | 169 | init_hash_scheme_simple(poseidon_hash_scheme); 170 | let mut db = ZkMemoryDb::new(); 171 | 172 | for bts in EXAMPLE { 173 | let buf = hex::decode(bts.get(2..).unwrap()).unwrap(); 174 | db.add_node_data(&buf).unwrap(); 175 | } 176 | let mut db = Rc::new(db); 177 | 178 | let root = hex::decode("194cfd0c3cce58ac79c5bab34b149927e0cd9280c6d61870bfb621d45533ddbc") 179 | .unwrap(); 180 | let root: Hash = root.as_slice().try_into().unwrap(); 181 | 182 | let mut trie = db.new_trie(&root).unwrap(); 183 | 184 | if trie.is_trie_dirty() { 185 | trie.prepare_root(); 186 | } 187 | 188 | assert_eq!(trie.root(), root); 189 | 190 | let acc_buf = hex::decode("1C5A77d9FA7eF466951B2F01F724BCa3A5820b63").unwrap(); 191 | 192 | let acc_data = trie.get_account(&acc_buf).unwrap(); 193 | 194 | let mut nonce_code: StoreData = 195 | hex::decode("0000000000000000000000000000000000000000000000000000000000000011") 196 | .unwrap() 197 | .as_slice() 198 | .try_into() 199 | .unwrap(); 200 | let balance: StoreData = 201 | hex::decode("01ffffffffffffffffffffffffffffffffffffffffffd5a5fa65e20465da88bf") 202 | .unwrap() 203 | .as_slice() 204 | .try_into() 205 | .unwrap(); 206 | let code_hash: StoreData = 207 | 
hex::decode("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470") 208 | .unwrap() 209 | .as_slice() 210 | .try_into() 211 | .unwrap(); 212 | assert_eq!(acc_data[0], nonce_code); 213 | assert_eq!(acc_data[1], balance); 214 | assert_eq!(acc_data[3], code_hash); 215 | 216 | nonce_code[31] += 1; 217 | 218 | let newacc: AccountData = [nonce_code, balance, [0; FIELDSIZE], code_hash, acc_data[4]]; 219 | trie.update_account(&acc_buf, &newacc).unwrap(); 220 | 221 | let acc_data = trie.get_account(&acc_buf).unwrap(); 222 | assert_eq!(acc_data[0], nonce_code); 223 | assert_eq!(acc_data[1], balance); 224 | assert_eq!(acc_data[3], code_hash); 225 | 226 | let mut root = 227 | hex::decode("9a88bda22f50dc0fda6c355fd93c025df7f7ce6e3d0b979942ebd981c1c6c71c") 228 | .unwrap(); 229 | root.reverse(); 230 | let root: Hash = root.as_slice().try_into().unwrap(); 231 | if trie.is_trie_dirty() { 232 | trie.prepare_root(); 233 | } 234 | assert_eq!(trie.root(), root); 235 | 236 | let newacc: AccountData = [ 237 | newacc[0], 238 | hex::decode("01ffffffffffffffffffffffffffffffffffffffffffd5a5fa65b10989405cd7") 239 | .unwrap() 240 | .as_slice() 241 | .try_into() 242 | .unwrap(), 243 | newacc[2], 244 | newacc[3], 245 | newacc[4], 246 | ]; 247 | trie.update_account(&acc_buf, &newacc).unwrap(); 248 | let mut root = 249 | hex::decode("7f787ee24805a9e5f69dc3a91ce68ef86d9358ce9c35729bd68660ccf6f9d909") 250 | .unwrap(); 251 | root.reverse(); 252 | let root: Hash = root.as_slice().try_into().unwrap(); 253 | if trie.is_trie_dirty() { 254 | trie.prepare_root(); 255 | } 256 | assert_eq!(trie.root(), root); 257 | 258 | assert!(db.new_ref_trie(&root).is_none()); 259 | 260 | // the zktrie can be created only if the corresponding root node has been added 261 | trie.commit().unwrap(); 262 | 263 | let trie_db = trie.updated_db(); 264 | Rc::get_mut(&mut db).expect("no reference").update(trie_db); 265 | let trie = db.new_ref_trie(&root).unwrap(); 266 | 267 | let proof = trie.prove(&acc_buf).unwrap(); 
268 | 269 | assert_eq!(proof.len(), 8); 270 | assert_eq!(proof[7], hex::decode("5448495320495320534f4d45204d4147494320425954455320464f5220534d54206d3172525867503278704449").unwrap()); 271 | assert_eq!(proof[3], hex::decode("0810b051b9facdd51b7fd1a1cf8e9a62facef17c80c7be0db1f15f3cda95982e34233b07e4b000250359a56ef55485036e6d4dbca7c71bf82812790ac3f4a5238e").unwrap()); 272 | 273 | let node = ZkTrieNode::parse(&proof[6]).unwrap(); 274 | assert_eq!( 275 | node.node_hash().as_slice(), 276 | hex::decode("272f093df377b234e179b70dc1a04a1543072be3c7d3a47f6e59004c84639907") 277 | .unwrap() 278 | ); 279 | assert_eq!( 280 | node.value_hash().unwrap().as_slice(), 281 | hex::decode("06c7c55f4d38fa2c6f6e0e655038ae7e1b3bb9dfa8954bdec0f9708e6e6b7d72") 282 | .unwrap() 283 | ); 284 | assert!(node.is_tip()); 285 | assert_eq!( 286 | Vec::from(node.as_account().unwrap()[1]), 287 | hex::decode("01ffffffffffffffffffffffffffffffffffffffffffd5a5fa65b10989405cd7") 288 | .unwrap(), 289 | ); 290 | 291 | let mut trie = db.new_trie(&root).unwrap(); 292 | 293 | trie.delete(&acc_buf); 294 | assert!(trie.get_account(&acc_buf).is_none()); 295 | 296 | trie.update_account(&acc_buf, &newacc).unwrap(); 297 | if trie.is_trie_dirty() { 298 | trie.prepare_root(); 299 | } 300 | assert_eq!(trie.root(), root); 301 | } 302 | } 303 | -------------------------------------------------------------------------------- /trie/zk_trie_impl_test.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "bytes" 5 | "math/big" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | 10 | zkt "github.com/scroll-tech/zktrie/types" 11 | ) 12 | 13 | // we do not need zktrie impl anymore, only made a wrapper for adapting testing 14 | type zkTrieImplTestWrapper struct { 15 | *ZkTrieImpl 16 | } 17 | 18 | func newZkTrieImpl(storage ZktrieDatabase, maxLevels int) (*zkTrieImplTestWrapper, error) { 19 | return newZkTrieImplWithRoot(storage, &zkt.HashZero, 
maxLevels) 20 | } 21 | 22 | // NewZkTrieImplWithRoot loads a new ZkTrieImpl. If in the storage already exists one 23 | // will open that one, if not, will create a new one. 24 | func newZkTrieImplWithRoot(storage ZktrieDatabase, root *zkt.Hash, maxLevels int) (*zkTrieImplTestWrapper, error) { 25 | impl, err := NewZkTrieImplWithRoot(storage, root, maxLevels) 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | return &zkTrieImplTestWrapper{impl}, nil 31 | } 32 | 33 | func (mt *zkTrieImplTestWrapper) AddWord(kPreimage, vPreimage *zkt.Byte32) error { 34 | 35 | if v, _ := mt.TryGet(kPreimage[:]); v != nil { 36 | return ErrEntryIndexAlreadyExists 37 | } 38 | 39 | return mt.ZkTrieImpl.TryUpdate(zkt.NewHashFromBytes(kPreimage[:]), 1, []zkt.Byte32{*vPreimage}) 40 | } 41 | 42 | func (mt *zkTrieImplTestWrapper) GetLeafNodeByWord(kPreimage *zkt.Byte32) (*Node, error) { 43 | return mt.ZkTrieImpl.GetLeafNode(zkt.NewHashFromBytes(kPreimage[:])) 44 | } 45 | 46 | func (mt *zkTrieImplTestWrapper) UpdateWord(kPreimage, vPreimage *zkt.Byte32) error { 47 | return mt.ZkTrieImpl.TryUpdate(zkt.NewHashFromBytes(kPreimage[:]), 1, []zkt.Byte32{*vPreimage}) 48 | } 49 | 50 | func (mt *zkTrieImplTestWrapper) DeleteWord(kPreimage *zkt.Byte32) error { 51 | return mt.ZkTrieImpl.TryDelete(zkt.NewHashFromBytes(kPreimage[:])) 52 | } 53 | 54 | func (mt *zkTrieImplTestWrapper) TryGet(key []byte) ([]byte, error) { 55 | return mt.ZkTrieImpl.TryGet(zkt.NewHashFromBytes(key)) 56 | } 57 | 58 | func newTestingMerkle(t *testing.T, numLevels int) *zkTrieImplTestWrapper { 59 | mt, err := newZkTrieImpl(NewZkTrieMemoryDb(), numLevels) 60 | if err != nil { 61 | t.Fatal(err) 62 | return nil 63 | } 64 | mt.Debug = true 65 | assert.Equal(t, numLevels, mt.MaxLevels()) 66 | return mt 67 | } 68 | 69 | func TestMerkleTree_Init(t *testing.T) { 70 | maxLevels := 248 71 | db := NewZkTrieMemoryDb() 72 | 73 | t.Run("Test NewZkTrieImpl", func(t *testing.T) { 74 | mt, err := NewZkTrieImpl(db, maxLevels) 75 | 
assert.NoError(t, err) 76 | mtRoot, err := mt.Root() 77 | assert.NoError(t, err) 78 | assert.Equal(t, zkt.HashZero.Bytes(), mtRoot.Bytes()) 79 | }) 80 | 81 | t.Run("Test NewZkTrieImplWithRoot with zero hash root", func(t *testing.T) { 82 | mt, err := NewZkTrieImplWithRoot(db, &zkt.HashZero, maxLevels) 83 | assert.NoError(t, err) 84 | mtRoot, err := mt.Root() 85 | assert.NoError(t, err) 86 | assert.Equal(t, zkt.HashZero.Bytes(), mtRoot.Bytes()) 87 | }) 88 | 89 | t.Run("Test NewZkTrieImplWithRoot with non-zero hash root and node exists", func(t *testing.T) { 90 | mt1, err := NewZkTrieImplWithRoot(db, &zkt.HashZero, maxLevels) 91 | assert.NoError(t, err) 92 | mt1Root, err := mt1.Root() 93 | assert.NoError(t, err) 94 | assert.Equal(t, zkt.HashZero.Bytes(), mt1Root.Bytes()) 95 | err = mt1.TryUpdate(zkt.NewHashFromBytes([]byte{1}), 1, []zkt.Byte32{{byte(1)}}) 96 | assert.NoError(t, err) 97 | mt1Root, err = mt1.Root() 98 | assert.NoError(t, err) 99 | assert.Equal(t, "0539c6b1cac741eb1e98b2c271733d1e6f0fad557228f6b039d894b0a627c8d9", mt1Root.Hex()) 100 | assert.NoError(t, mt1.Commit()) 101 | 102 | mt2, err := NewZkTrieImplWithRoot(db, mt1Root, maxLevels) 103 | assert.NoError(t, err) 104 | assert.Equal(t, maxLevels, mt2.maxLevels) 105 | mt2Root, err := mt2.Root() 106 | assert.NoError(t, err) 107 | assert.Equal(t, "0539c6b1cac741eb1e98b2c271733d1e6f0fad557228f6b039d894b0a627c8d9", mt2Root.Hex()) 108 | }) 109 | 110 | t.Run("Test NewZkTrieImplWithRoot with non-zero hash root and node does not exist", func(t *testing.T) { 111 | root := zkt.NewHashFromBytes([]byte{1, 2, 3, 4, 5}) 112 | 113 | mt, err := NewZkTrieImplWithRoot(db, root, maxLevels) 114 | assert.Error(t, err) 115 | assert.Nil(t, mt) 116 | }) 117 | } 118 | 119 | func TestMerkleTree_AddUpdateGetWord(t *testing.T) { 120 | mt := newTestingMerkle(t, 10) 121 | 122 | testData := []struct { 123 | key byte 124 | initialVal byte 125 | updatedVal byte 126 | }{ 127 | {1, 2, 7}, 128 | {3, 4, 8}, 129 | {5, 6, 9}, 130 | } 131 | 132 
| for _, td := range testData { 133 | err := mt.AddWord(zkt.NewByte32FromBytes([]byte{td.key}), &zkt.Byte32{td.initialVal}) 134 | assert.NoError(t, err) 135 | 136 | node, err := mt.GetLeafNodeByWord(zkt.NewByte32FromBytes([]byte{td.key})) 137 | assert.NoError(t, err) 138 | assert.Equal(t, 1, len(node.ValuePreimage)) 139 | assert.Equal(t, (&zkt.Byte32{td.initialVal})[:], node.ValuePreimage[0][:]) 140 | } 141 | 142 | err := mt.AddWord(zkt.NewByte32FromBytes([]byte{5}), &zkt.Byte32{7}) 143 | assert.Equal(t, ErrEntryIndexAlreadyExists, err) 144 | 145 | for _, td := range testData { 146 | err := mt.UpdateWord(zkt.NewByte32FromBytes([]byte{td.key}), &zkt.Byte32{td.updatedVal}) 147 | assert.NoError(t, err) 148 | 149 | node, err := mt.GetLeafNodeByWord(zkt.NewByte32FromBytes([]byte{td.key})) 150 | assert.NoError(t, err) 151 | assert.Equal(t, 1, len(node.ValuePreimage)) 152 | assert.Equal(t, (&zkt.Byte32{td.updatedVal})[:], node.ValuePreimage[0][:]) 153 | } 154 | 155 | _, err = mt.GetLeafNodeByWord(&zkt.Byte32{100}) 156 | assert.Equal(t, ErrKeyNotFound, err) 157 | } 158 | 159 | func TestMerkleTree_Deletion(t *testing.T) { 160 | t.Run("Check root consistency", func(t *testing.T) { 161 | var err error 162 | mt := newTestingMerkle(t, 10) 163 | hashes := make([]*zkt.Hash, 7) 164 | hashes[0], err = mt.Root() 165 | assert.NoError(t, err) 166 | 167 | for i := 0; i < 6; i++ { 168 | err := mt.AddWord(zkt.NewByte32FromBytes([]byte{byte(i)}), &zkt.Byte32{byte(i)}) 169 | assert.NoError(t, err) 170 | hashes[i+1], err = mt.Root() 171 | assert.NoError(t, err) 172 | } 173 | 174 | for i := 5; i >= 0; i-- { 175 | err := mt.DeleteWord(zkt.NewByte32FromBytes([]byte{byte(i)})) 176 | assert.NoError(t, err) 177 | root, err := mt.Root() 178 | assert.NoError(t, err) 179 | assert.Equal(t, hashes[i], root, i) 180 | } 181 | }) 182 | 183 | t.Run("Check depth", func(t *testing.T) { 184 | mt := newTestingMerkle(t, 10) 185 | key1 := zkt.NewByte32FromBytes([]byte{67}) //0b1000011 186 | err := 
mt.AddWord(key1, &zkt.Byte32{67}) 187 | assert.NoError(t, err) 188 | rootPhase1, err := mt.Root() 189 | assert.NoError(t, err) 190 | key2 := zkt.NewByte32FromBytes([]byte{131}) //0b10000011 191 | err = mt.AddWord(key2, &zkt.Byte32{131}) 192 | assert.NoError(t, err) 193 | rootPhase2, err := mt.Root() 194 | assert.NoError(t, err) 195 | 196 | assertKeyDepth := func(key *zkt.Byte32, expectedDep int) { 197 | levelCnt := 0 198 | err := mt.Prove(zkt.NewHashFromBytes(key[:]), 0, 199 | func(*Node) error { 200 | levelCnt++ 201 | return nil 202 | }, 203 | ) 204 | assert.NoError(t, err) 205 | assert.Equal(t, expectedDep, levelCnt) 206 | } 207 | 208 | assertKeyDepth(key1, 8) 209 | assertKeyDepth(key2, 8) 210 | 211 | err = mt.DeleteWord(key2) 212 | assert.NoError(t, err) 213 | 214 | assertKeyDepth(key1, 1) 215 | curRoot, err := mt.Root() 216 | assert.NoError(t, err) 217 | assert.Equal(t, rootPhase1, curRoot) 218 | 219 | err = mt.AddWord(key2, &zkt.Byte32{131}) 220 | assert.NoError(t, err) 221 | curRoot, err = mt.Root() 222 | assert.NoError(t, err) 223 | assert.Equal(t, rootPhase2, curRoot) 224 | assertKeyDepth(key1, 8) 225 | 226 | // delete node with parent sibling (fail before a410f14) 227 | key3 := zkt.NewByte32FromBytes([]byte{19}) //0b10011 228 | err = mt.AddWord(key3, &zkt.Byte32{19}) 229 | assert.NoError(t, err) 230 | 231 | err = mt.DeleteWord(key3) 232 | assert.NoError(t, err) 233 | assertKeyDepth(key1, 8) 234 | curRoot, err = mt.Root() 235 | assert.NoError(t, err) 236 | assert.Equal(t, rootPhase2, curRoot) 237 | 238 | key4 := zkt.NewByte32FromBytes([]byte{4}) //0b100, so it is 2 level node (fail before d1c735) 239 | err = mt.AddWord(key4, &zkt.Byte32{4}) 240 | assert.NoError(t, err) 241 | 242 | assertKeyDepth(key4, 2) 243 | 244 | err = mt.DeleteWord(key4) 245 | assert.NoError(t, err) 246 | curRoot, err = mt.Root() 247 | assert.NoError(t, err) 248 | assert.Equal(t, rootPhase2, curRoot) 249 | }) 250 | } 251 | 252 | func TestZkTrieImpl_Add(t *testing.T) { 253 | k1 := 
zkt.NewByte32FromBytes([]byte{1}) 254 | k2 := zkt.NewByte32FromBytes([]byte{2}) 255 | k3 := zkt.NewByte32FromBytes([]byte{3}) 256 | 257 | kvMap := map[*zkt.Byte32]*zkt.Byte32{ 258 | k1: zkt.NewByte32FromBytes([]byte{1}), 259 | k2: zkt.NewByte32FromBytes([]byte{2}), 260 | k3: zkt.NewByte32FromBytes([]byte{3}), 261 | } 262 | 263 | t.Run("Add 1 and 2 in different orders", func(t *testing.T) { 264 | orders := [][]*zkt.Byte32{ 265 | {k1, k2}, 266 | {k2, k1}, 267 | } 268 | 269 | roots := make([]*zkt.Hash, len(orders)) 270 | for i, order := range orders { 271 | mt := newTestingMerkle(t, 10) 272 | for _, key := range order { 273 | value := kvMap[key] 274 | err := mt.AddWord(key, value) 275 | assert.NoError(t, err) 276 | } 277 | var err error 278 | roots[i], err = mt.Root() 279 | assert.NoError(t, err) 280 | } 281 | 282 | assert.Equal(t, "225fe589e8cbdfe424a032e6e2fd1132762b20794cff61f0c70e8f757b6a0ed7", roots[0].Hex()) 283 | assert.Equal(t, roots[0], roots[1]) 284 | }) 285 | 286 | t.Run("Add 1, 2, 3 in different orders", func(t *testing.T) { 287 | orders := [][]*zkt.Byte32{ 288 | {k1, k2, k3}, 289 | {k1, k3, k2}, 290 | {k2, k1, k3}, 291 | {k2, k3, k1}, 292 | {k3, k1, k2}, 293 | {k3, k2, k1}, 294 | } 295 | 296 | roots := make([]*zkt.Hash, len(orders)) 297 | for i, order := range orders { 298 | mt := newTestingMerkle(t, 10) 299 | for _, key := range order { 300 | value := kvMap[key] 301 | err := mt.AddWord(key, value) 302 | assert.NoError(t, err) 303 | } 304 | var err error 305 | roots[i], err = mt.Root() 306 | assert.NoError(t, err) 307 | } 308 | 309 | for i := 1; i < len(roots); i++ { 310 | assert.Equal(t, "25aa478a6c8c3a7cab40b0c3a37f8ed6815ee575228f0ba8e77d1145191f9a34", roots[0].Hex()) 311 | assert.Equal(t, roots[0], roots[i]) 312 | } 313 | }) 314 | 315 | t.Run("Add twice", func(t *testing.T) { 316 | keys := []*zkt.Byte32{k1, k2, k3} 317 | 318 | mt := newTestingMerkle(t, 10) 319 | for _, key := range keys { 320 | err := mt.AddWord(key, kvMap[key]) 321 | 
assert.NoError(t, err) 322 | 323 | err = mt.AddWord(key, kvMap[key]) 324 | assert.Equal(t, ErrEntryIndexAlreadyExists, err) 325 | } 326 | }) 327 | } 328 | 329 | func TestZkTrieImpl_Update(t *testing.T) { 330 | k1 := zkt.NewByte32FromBytes([]byte{1}) 331 | k2 := zkt.NewByte32FromBytes([]byte{2}) 332 | k3 := zkt.NewByte32FromBytes([]byte{3}) 333 | 334 | t.Run("Update 1", func(t *testing.T) { 335 | mt1 := newTestingMerkle(t, 10) 336 | err := mt1.AddWord(k1, zkt.NewByte32FromBytes([]byte{1})) 337 | assert.NoError(t, err) 338 | root1, err := mt1.Root() 339 | assert.NoError(t, err) 340 | 341 | mt2 := newTestingMerkle(t, 10) 342 | err = mt2.AddWord(k1, zkt.NewByte32FromBytes([]byte{2})) 343 | assert.NoError(t, err) 344 | err = mt2.UpdateWord(k1, zkt.NewByte32FromBytes([]byte{1})) 345 | assert.NoError(t, err) 346 | root2, err := mt2.Root() 347 | assert.NoError(t, err) 348 | 349 | assert.Equal(t, root1, root2) 350 | }) 351 | 352 | t.Run("Update 2", func(t *testing.T) { 353 | mt1 := newTestingMerkle(t, 10) 354 | err := mt1.AddWord(k1, zkt.NewByte32FromBytes([]byte{1})) 355 | assert.NoError(t, err) 356 | err = mt1.AddWord(k2, zkt.NewByte32FromBytes([]byte{2})) 357 | assert.NoError(t, err) 358 | root1, err := mt1.Root() 359 | assert.NoError(t, err) 360 | 361 | mt2 := newTestingMerkle(t, 10) 362 | err = mt2.AddWord(k1, zkt.NewByte32FromBytes([]byte{1})) 363 | assert.NoError(t, err) 364 | err = mt2.AddWord(k2, zkt.NewByte32FromBytes([]byte{3})) 365 | assert.NoError(t, err) 366 | err = mt2.UpdateWord(k2, zkt.NewByte32FromBytes([]byte{2})) 367 | assert.NoError(t, err) 368 | root2, err := mt2.Root() 369 | assert.NoError(t, err) 370 | 371 | assert.Equal(t, root1, root2) 372 | }) 373 | 374 | t.Run("Update 1, 2, 3", func(t *testing.T) { 375 | mt1 := newTestingMerkle(t, 10) 376 | mt2 := newTestingMerkle(t, 10) 377 | keys := []*zkt.Byte32{k1, k2, k3} 378 | for i, key := range keys { 379 | err := mt1.AddWord(key, zkt.NewByte32FromBytes([]byte{byte(i)})) 380 | assert.NoError(t, err) 381 | 
} 382 | for i, key := range keys { 383 | err := mt2.AddWord(key, zkt.NewByte32FromBytes([]byte{byte(i + 3)})) 384 | assert.NoError(t, err) 385 | } 386 | for i, key := range keys { 387 | err := mt1.UpdateWord(key, zkt.NewByte32FromBytes([]byte{byte(i + 6)})) 388 | assert.NoError(t, err) 389 | err = mt2.UpdateWord(key, zkt.NewByte32FromBytes([]byte{byte(i + 6)})) 390 | assert.NoError(t, err) 391 | } 392 | 393 | root1, err := mt1.Root() 394 | assert.NoError(t, err) 395 | root2, err := mt2.Root() 396 | assert.NoError(t, err) 397 | 398 | assert.Equal(t, root1, root2) 399 | }) 400 | 401 | t.Run("Update same value", func(t *testing.T) { 402 | mt := newTestingMerkle(t, 10) 403 | keys := []*zkt.Byte32{k1, k2, k3} 404 | for _, key := range keys { 405 | err := mt.AddWord(key, zkt.NewByte32FromBytes([]byte{1})) 406 | assert.NoError(t, err) 407 | err = mt.UpdateWord(key, zkt.NewByte32FromBytes([]byte{1})) 408 | assert.NoError(t, err) 409 | node, err := mt.GetLeafNodeByWord(key) 410 | assert.NoError(t, err) 411 | assert.Equal(t, 1, len(node.ValuePreimage)) 412 | assert.Equal(t, zkt.NewByte32FromBytes([]byte{1}).Bytes(), node.ValuePreimage[0][:]) 413 | } 414 | }) 415 | 416 | t.Run("Update non-existent word", func(t *testing.T) { 417 | mt := newTestingMerkle(t, 10) 418 | err := mt.UpdateWord(k1, zkt.NewByte32FromBytes([]byte{1})) 419 | assert.NoError(t, err) 420 | node, err := mt.GetLeafNodeByWord(k1) 421 | assert.NoError(t, err) 422 | assert.Equal(t, 1, len(node.ValuePreimage)) 423 | assert.Equal(t, zkt.NewByte32FromBytes([]byte{1}).Bytes(), node.ValuePreimage[0][:]) 424 | }) 425 | } 426 | 427 | func TestZkTrieImpl_Delete(t *testing.T) { 428 | k1 := zkt.NewByte32FromBytes([]byte{1}) 429 | k2 := zkt.NewByte32FromBytes([]byte{2}) 430 | k3 := zkt.NewByte32FromBytes([]byte{3}) 431 | k4 := zkt.NewByte32FromBytes([]byte{4}) 432 | 433 | t.Run("Test deletion leads to empty tree", func(t *testing.T) { 434 | emptyMT := newTestingMerkle(t, 10) 435 | emptyMTRoot, err := emptyMT.Root() 436 | 
assert.NoError(t, err) 437 | 438 | mt1 := newTestingMerkle(t, 10) 439 | err = mt1.AddWord(k1, zkt.NewByte32FromBytes([]byte{1})) 440 | assert.NoError(t, err) 441 | err = mt1.DeleteWord(k1) 442 | assert.NoError(t, err) 443 | mt1Root, err := mt1.Root() 444 | assert.NoError(t, err) 445 | assert.Equal(t, zkt.HashZero, *mt1Root) 446 | assert.Equal(t, emptyMTRoot, mt1Root) 447 | 448 | keys := []*zkt.Byte32{k1, k2, k3, k4} 449 | mt2 := newTestingMerkle(t, 10) 450 | for _, key := range keys { 451 | err := mt2.AddWord(key, zkt.NewByte32FromBytes([]byte{1})) 452 | assert.NoError(t, err) 453 | } 454 | for _, key := range keys { 455 | err := mt2.DeleteWord(key) 456 | assert.NoError(t, err) 457 | } 458 | mt2Root, err := mt2.Root() 459 | assert.NoError(t, err) 460 | assert.Equal(t, zkt.HashZero, *mt2Root) 461 | assert.Equal(t, emptyMTRoot, mt2Root) 462 | 463 | mt3 := newTestingMerkle(t, 10) 464 | for _, key := range keys { 465 | err := mt3.AddWord(key, zkt.NewByte32FromBytes([]byte{1})) 466 | assert.NoError(t, err) 467 | } 468 | for i := len(keys) - 1; i >= 0; i-- { 469 | err := mt3.DeleteWord(keys[i]) 470 | assert.NoError(t, err) 471 | } 472 | mt3Root, err := mt3.Root() 473 | assert.NoError(t, err) 474 | assert.Equal(t, zkt.HashZero, *mt3Root) 475 | assert.Equal(t, emptyMTRoot, mt3Root) 476 | }) 477 | 478 | t.Run("Test equivalent trees after deletion", func(t *testing.T) { 479 | keys := []*zkt.Byte32{k1, k2, k3, k4} 480 | 481 | mt1 := newTestingMerkle(t, 10) 482 | for i, key := range keys { 483 | err := mt1.AddWord(key, zkt.NewByte32FromBytes([]byte{byte(i + 1)})) 484 | assert.NoError(t, err) 485 | } 486 | err := mt1.DeleteWord(k1) 487 | assert.NoError(t, err) 488 | err = mt1.DeleteWord(k2) 489 | assert.NoError(t, err) 490 | 491 | mt2 := newTestingMerkle(t, 10) 492 | err = mt2.AddWord(k3, zkt.NewByte32FromBytes([]byte{byte(3)})) 493 | assert.NoError(t, err) 494 | err = mt2.AddWord(k4, zkt.NewByte32FromBytes([]byte{byte(4)})) 495 | assert.NoError(t, err) 496 | 497 | mt1Root, err 
:= mt1.Root() 498 | assert.NoError(t, err) 499 | mt2Root, err := mt2.Root() 500 | assert.NoError(t, err) 501 | 502 | assert.Equal(t, mt1Root, mt2Root) 503 | 504 | mt3 := newTestingMerkle(t, 10) 505 | for i, key := range keys { 506 | err := mt3.AddWord(key, zkt.NewByte32FromBytes([]byte{byte(i + 1)})) 507 | assert.NoError(t, err) 508 | } 509 | err = mt3.DeleteWord(k1) 510 | assert.NoError(t, err) 511 | err = mt3.DeleteWord(k3) 512 | assert.NoError(t, err) 513 | mt4 := newTestingMerkle(t, 10) 514 | err = mt4.AddWord(k2, zkt.NewByte32FromBytes([]byte{2})) 515 | assert.NoError(t, err) 516 | err = mt4.AddWord(k4, zkt.NewByte32FromBytes([]byte{4})) 517 | assert.NoError(t, err) 518 | 519 | mt3Root, err := mt3.Root() 520 | assert.NoError(t, err) 521 | mt4Root, err := mt4.Root() 522 | assert.NoError(t, err) 523 | 524 | assert.Equal(t, mt3Root, mt4Root) 525 | }) 526 | 527 | t.Run("Test repeat deletion", func(t *testing.T) { 528 | mt := newTestingMerkle(t, 10) 529 | err := mt.AddWord(k1, zkt.NewByte32FromBytes([]byte{1})) 530 | assert.NoError(t, err) 531 | err = mt.DeleteWord(k1) 532 | assert.NoError(t, err) 533 | err = mt.DeleteWord(k1) 534 | assert.Equal(t, ErrKeyNotFound, err) 535 | }) 536 | 537 | t.Run("Test deletion of non-existent node", func(t *testing.T) { 538 | mt := newTestingMerkle(t, 10) 539 | err := mt.DeleteWord(k1) 540 | assert.Equal(t, ErrKeyNotFound, err) 541 | }) 542 | } 543 | 544 | func TestMerkleTree_BuildAndVerifyZkTrieProof(t *testing.T) { 545 | zkTrie := newTestingMerkle(t, 10) 546 | 547 | testData := []struct { 548 | key *big.Int 549 | value byte 550 | }{ 551 | {big.NewInt(1), 2}, 552 | {big.NewInt(3), 4}, 553 | {big.NewInt(5), 6}, 554 | {big.NewInt(7), 8}, 555 | {big.NewInt(9), 10}, 556 | } 557 | 558 | nonExistentKey := big.NewInt(11) 559 | 560 | getNode := func(hash *zkt.Hash) (*Node, error) { 561 | node, err := zkTrie.GetNode(hash) 562 | if err != nil { 563 | return nil, err 564 | } 565 | return node, nil 566 | } 567 | 568 | for _, td := range 
testData { 569 | err := zkTrie.AddWord(zkt.NewByte32FromBytes([]byte{byte(td.key.Int64())}), &zkt.Byte32{td.value}) 570 | assert.NoError(t, err) 571 | } 572 | 573 | t.Run("Test with existent key", func(t *testing.T) { 574 | for _, td := range testData { 575 | 576 | node, err := zkTrie.GetLeafNodeByWord(zkt.NewByte32FromBytes([]byte{byte(td.key.Int64())})) 577 | assert.NoError(t, err) 578 | assert.Equal(t, 1, len(node.ValuePreimage)) 579 | assert.Equal(t, (&zkt.Byte32{td.value})[:], node.ValuePreimage[0][:]) 580 | assert.NoError(t, zkTrie.Commit()) 581 | 582 | proof, node, err := BuildZkTrieProof(zkTrie.rootKey, td.key, 10, getNode) 583 | assert.NoError(t, err) 584 | 585 | valid := VerifyProofZkTrie(zkTrie.rootKey, proof, node) 586 | assert.True(t, valid) 587 | } 588 | }) 589 | 590 | t.Run("Test with non-existent key", func(t *testing.T) { 591 | proof, node, err := BuildZkTrieProof(zkTrie.rootKey, nonExistentKey, 10, getNode) 592 | assert.NoError(t, err) 593 | assert.False(t, proof.Existence) 594 | valid := VerifyProofZkTrie(zkTrie.rootKey, proof, node) 595 | assert.True(t, valid) 596 | nodeAnother, err := zkTrie.GetLeafNodeByWord(zkt.NewByte32FromBytes([]byte{byte(big.NewInt(1).Int64())})) 597 | assert.NoError(t, err) 598 | valid = VerifyProofZkTrie(zkTrie.rootKey, proof, nodeAnother) 599 | assert.False(t, valid) 600 | 601 | hash, err := proof.Verify(node.nodeHash) 602 | assert.NoError(t, err) 603 | assert.Equal(t, hash[:], zkTrie.rootKey[:]) 604 | }) 605 | } 606 | 607 | func TestMerkleTree_GraphViz(t *testing.T) { 608 | mt := newTestingMerkle(t, 10) 609 | 610 | var buffer bytes.Buffer 611 | err := mt.GraphViz(&buffer, nil) 612 | assert.NoError(t, err) 613 | assert.Equal(t, "--------\nGraphViz of the ZkTrieImpl with RootHash 0\ndigraph hierarchy {\nnode [fontname=Monospace,fontsize=10,shape=box]\n}\nEnd of GraphViz of the ZkTrieImpl with RootHash 0\n--------\n", buffer.String()) 614 | buffer.Reset() 615 | 616 | key1 := zkt.NewByte32FromBytes([]byte{1}) //0b1 617 | 
err = mt.AddWord(key1, &zkt.Byte32{1}) 618 | assert.NoError(t, err) 619 | key2 := zkt.NewByte32FromBytes([]byte{3}) //0b11 620 | err = mt.AddWord(key2, &zkt.Byte32{3}) 621 | assert.NoError(t, err) 622 | 623 | err = mt.GraphViz(&buffer, nil) 624 | assert.NoError(t, err) 625 | assert.Equal(t, "--------\nGraphViz of the ZkTrieImpl with RootHash 18814328259272153650095812929528579893472885385393031263032639585810677019057\ndigraph hierarchy {\nnode [fontname=Monospace,fontsize=10,shape=box]\n\"18814328...\" -> {\"empty0\" \"36062889...\"}\n\"empty0\" [style=dashed,label=0];\n\"36062889...\" -> {\"23636458...\" \"20814118...\"}\n\"23636458...\" [style=filled];\n\"20814118...\" [style=filled];\n}\nEnd of GraphViz of the ZkTrieImpl with RootHash 18814328259272153650095812929528579893472885385393031263032639585810677019057\n--------\n", buffer.String()) 626 | buffer.Reset() 627 | } 628 | -------------------------------------------------------------------------------- /rs_zktrie/src/types.rs: -------------------------------------------------------------------------------- 1 | use crate::raw::ImplError; 2 | use num; 3 | use num_derive::FromPrimitive; 4 | use std::fmt::Debug; 5 | 6 | const HASH_BYTE_LEN: usize = 32; 7 | 8 | pub trait Hashable: Clone + Debug + Default + PartialEq { 9 | fn hash_elems_with_domain(domain: u64, lbytes: &Self, rbytes: &Self) 10 | -> Result; 11 | fn hash_zero() -> Self; 12 | fn check_in_field(hash: &Self) -> bool; 13 | fn test_bit(key: &Self, pos: usize) -> bool; 14 | fn from_bytes(bytes: &[u8]) -> Result; 15 | fn to_bytes(&self) -> Vec; 16 | } 17 | 18 | pub trait TrieHashScheme { 19 | type Hash: Hashable; 20 | fn handling_elems_and_bytes32( 21 | flags: u32, 22 | bytes: &[[u8; HASH_BYTE_LEN]], 23 | ) -> Result; 24 | /// hash any byes not longer than HASH_BYTE_LEN 25 | fn hash_bytes(bytes: &[u8]) -> Result; 26 | } 27 | 28 | #[derive(Copy, Clone, Debug, FromPrimitive, Display, PartialEq)] 29 | pub enum NodeType { 30 | // NodeTypeParent indicates the 
type of parent Node that has children. 31 | NodeTypeParent = 0, 32 | // NodeTypeLeaf indicates the type of a leaf Node that contains a key & 33 | // value. 34 | NodeTypeLeaf = 1, 35 | // NodeTypeEmpty indicates the type of an empty Node. 36 | NodeTypeEmpty = 2, 37 | 38 | // DBEntryTypeRoot indicates the type of a DB entry that indicates the 39 | // current Root of a MerkleTree 40 | DBEntryTypeRoot = 3, 41 | 42 | NodeTypeLeafNew = 4, 43 | NodeTypeEmptyNew = 5, 44 | // branch node for both child are terminal nodes 45 | NodeTypeBranch0 = 6, 46 | // branch node for left child is terminal node and right child is branch 47 | NodeTypeBranch1 = 7, 48 | // branch node for left child is branch node and right child is terminal 49 | NodeTypeBranch2 = 8, 50 | // branch node for both child are branch nodes 51 | NodeTypeBranch3 = 9, 52 | // any invalid situation 53 | NodeTypeInvalid = 10, 54 | } 55 | 56 | use strum_macros::Display; 57 | use NodeType::*; 58 | 59 | impl NodeType { 60 | /// deduce a new branch type from current branch when one of its child become non trivial 61 | pub fn deduce_upgrade_type(&self, is_right: bool) -> Self { 62 | if is_right { 63 | match self { 64 | NodeTypeBranch0 => NodeTypeBranch1, 65 | NodeTypeBranch1 => *self, 66 | NodeTypeBranch2 => NodeTypeBranch3, 67 | NodeTypeBranch3 => NodeTypeBranch3, 68 | _ => unreachable!(), 69 | } 70 | } else { 71 | match self { 72 | NodeTypeBranch0 => NodeTypeBranch2, 73 | NodeTypeBranch1 => NodeTypeBranch3, 74 | NodeTypeBranch3 => NodeTypeBranch3, 75 | NodeTypeBranch2 => *self, 76 | _ => unreachable!(), 77 | } 78 | } 79 | } 80 | 81 | /// deduce a new branch type from current branch when one of its child become terminal 82 | pub fn deduce_downgrade_type(&self, is_right: bool) -> Self { 83 | if is_right { 84 | match self { 85 | NodeTypeBranch1 => NodeTypeBranch0, 86 | NodeTypeBranch3 => NodeTypeBranch2, 87 | _ => { 88 | panic!("can not downgrade a node with terminal child {}", self); 89 | } 90 | } 91 | } else { 92 | match 
self { 93 | NodeTypeBranch3 => NodeTypeBranch1, 94 | NodeTypeBranch2 => NodeTypeBranch0, 95 | _ => { 96 | panic!("can not downgrade a node with terminal child {}", self); 97 | } 98 | } 99 | } 100 | } 101 | } 102 | 103 | // Node is the struct that represents a node in the MT. The node should not be 104 | // modified after creation because the cached key won't be updated. 105 | #[derive(Clone, Debug)] 106 | pub struct Node { 107 | // node_type is the type of node in the tree. 108 | pub node_type: NodeType, 109 | // child_left is the node hash of the left child of a parent node. 110 | pub child_left: Option, 111 | // child_right is the node hash of the right child of a parent node. 112 | pub child_right: Option, 113 | // node_key is the node's key stored in a leaf node. 114 | pub node_key: H, 115 | // value_preimage can store at most 256 byte32 as fields (represented by BIG-ENDIAN integers) 116 | // and the first 24 can be compressed (each byte32 is treated as 2 fields); when hashing, the compressed 117 | // elements are hashed first 118 | pub value_preimage: Vec<[u8; 32]>, 119 | // use each bit for indicating the compressed flag for the first 24 fields 120 | compress_flags: u32, 121 | // node_hash is the cache of the hash of the node to avoid recalculating 122 | node_hash: Option, 123 | // value_hash is the cache of the hash of value_preimage to avoid recalculating, only valid for leaf node 124 | value_hash: Option, 125 | // key_preimage is the original key value that derives the node_key, kept here only for proof 126 | key_preimage: Option<[u8; 32]>, 127 | } 128 | 129 | const HASH_DOMAIN_ELEMS_BASE: usize = 256; 130 | const HASH_DOMAIN_BYTE32: usize = 2 * HASH_DOMAIN_ELEMS_BASE; 131 | 132 | impl TrieHashScheme for Node { 133 | type Hash = H; 134 | 135 | fn handling_elems_and_bytes32(flags: u32, bytes: &[[u8; 32]]) -> Result { 136 | assert!(!bytes.is_empty(), "handling_elems_and_bytes32 requires at least one element"); // a single element is valid and returned as-is below 137 | let mut tmp = vec![]; 138 | for (i, byte) in bytes.iter().enumerate() { 139 | if i < u32::BITS as usize && flags & (1 << i) != 0 { // `flags` has only 32 bits; higher indices are never compressed (and must not overflow the shift) 140 |
tmp.push(Self::hash_bytes(byte.as_slice())?); 141 | } else { 142 | tmp.push(H::from_bytes(byte)?); 143 | } 144 | } 145 | assert_eq!(tmp.len(), bytes.len()); 146 | 147 | let domain = bytes.len() * HASH_DOMAIN_ELEMS_BASE; 148 | while tmp.len() > 1 { 149 | let mut out = Vec::new(); 150 | for pair in tmp.chunks(2) { 151 | out.push(if pair.len() == 2 { 152 | H::hash_elems_with_domain(domain as u64, &pair[0], &pair[1])? 153 | } else { 154 | pair[0].clone() 155 | }); 156 | } 157 | tmp = out; 158 | } 159 | 160 | Ok(tmp.pop().unwrap()) 161 | } 162 | 163 | fn hash_bytes(v: &[u8]) -> Result { 164 | assert!(v.len() <= HASH_BYTE_LEN); 165 | const HALF_BYTE: usize = HASH_BYTE_LEN / 2; 166 | let mut v_lo = [0u8; HASH_BYTE_LEN]; 167 | let mut v_hi = [0u8; HASH_BYTE_LEN]; 168 | let lo_len = if v.len() > HALF_BYTE { 169 | HALF_BYTE 170 | } else { 171 | v.len() 172 | }; 173 | v_lo[HALF_BYTE..HALF_BYTE + lo_len].copy_from_slice(&v[..lo_len]); 174 | if v.len() > HALF_BYTE { 175 | v_hi[HALF_BYTE..v.len()].copy_from_slice(&v[HALF_BYTE..v.len()]); 176 | } 177 | H::hash_elems_with_domain( 178 | HASH_DOMAIN_BYTE32 as u64, 179 | &H::from_bytes(&v_lo)?, 180 | &H::from_bytes(&v_hi)?, 181 | ) 182 | } 183 | } 184 | 185 | impl Node { 186 | /// create a new leaf node 187 | pub fn new_leaf_node(node_key: H, value_flags: u32, value_preimage: Vec<[u8; 32]>) -> Self { 188 | Node { 189 | node_type: NodeType::NodeTypeLeafNew, 190 | node_key, 191 | compress_flags: value_flags, 192 | value_preimage, 193 | child_left: None, 194 | child_right: None, 195 | node_hash: None, 196 | value_hash: None, 197 | key_preimage: None, 198 | } 199 | } 200 | 201 | /// creates a new parent node. 
202 | pub fn new_parent_node(node_type: NodeType, child_left: H, child_right: H) -> Self { 203 | Node { 204 | node_type, 205 | node_key: H::default(), 206 | compress_flags: 0, 207 | value_preimage: vec![], 208 | child_left: Some(child_left), 209 | child_right: Some(child_right), 210 | node_hash: None, 211 | value_hash: None, 212 | key_preimage: None, 213 | } 214 | } 215 | 216 | /// creates a new empty node. 217 | pub fn new_empty_node() -> Self { 218 | Node { 219 | node_type: NodeType::NodeTypeEmptyNew, 220 | node_key: H::default(), 221 | compress_flags: 0, 222 | value_preimage: vec![], 223 | child_left: None, 224 | child_right: None, 225 | node_hash: None, 226 | value_hash: None, 227 | key_preimage: None, 228 | } 229 | } 230 | 231 | // new_node_from_bytes creates a new node by parsing the input []byte. 232 | pub fn new_node_from_bytes(b: &[u8]) -> Result, ImplError> { 233 | if b.is_empty() { 234 | Err(ImplError::ErrNodeBytesBadSize) 235 | } else { 236 | let mut node = Node::new_empty_node(); 237 | node.node_type = num::FromPrimitive::from_u32(b[0] as u32).unwrap_or(NodeTypeInvalid); 238 | let b = &b[1..]; 239 | match node.node_type { 240 | NodeTypeParent | NodeTypeBranch0 | NodeTypeBranch1 | NodeTypeBranch2 241 | | NodeTypeBranch3 => { 242 | if b.len() != 2 * HASH_BYTE_LEN { 243 | Err(ImplError::ErrNodeBytesBadSize) 244 | } else { 245 | node.child_left = Some(H::from_bytes(&b[..HASH_BYTE_LEN])?); 246 | node.child_right = 247 | Some(H::from_bytes(&b[HASH_BYTE_LEN..HASH_BYTE_LEN * 2])?); 248 | Ok(node) 249 | } 250 | } 251 | NodeTypeLeaf | NodeTypeLeafNew => { 252 | if b.len() < HASH_BYTE_LEN + 4 { 253 | Err(ImplError::ErrNodeBytesBadSize) 254 | } else { 255 | node.node_key = H::from_bytes(&b[..HASH_BYTE_LEN])?; 256 | let mark = u32::from_le_bytes( 257 | b[HASH_BYTE_LEN..HASH_BYTE_LEN + 4].try_into().unwrap(), 258 | ); 259 | let preimage_len = (mark & 255) as usize; 260 | node.compress_flags = mark >> 8; 261 | let mut cur_pos = HASH_BYTE_LEN + 4; 262 | if b.len() < 
cur_pos + preimage_len * 32 + 1 { 263 | Err(ImplError::ErrNodeBytesBadSize) 264 | } else { 265 | for i in 0..preimage_len { 266 | let a = &b[i * 32 + cur_pos..(i + 1) * 32 + cur_pos]; 267 | node.value_preimage.push(a.try_into().unwrap()); 268 | } 269 | cur_pos += preimage_len * 32; 270 | let preimage_size = b[cur_pos] as usize; 271 | cur_pos += 1; 272 | if preimage_size != 0 { 273 | if b.len() < cur_pos + preimage_size || preimage_size != 32 { 274 | Err(ImplError::ErrNodeBytesBadSize) 275 | } else { 276 | let a = &b[cur_pos..cur_pos + preimage_size]; 277 | node.key_preimage = Some(a.try_into().unwrap()); 278 | Ok(node) 279 | } 280 | } else { 281 | Ok(node) 282 | } 283 | } 284 | } 285 | } 286 | NodeTypeEmpty | NodeTypeEmptyNew => Ok(node), 287 | _ => Err(ImplError::ErrInvalidNodeFound), 288 | } 289 | } 290 | } 291 | /// is_terminal returns if the node is 'terminated', i.e. empty or leaf node 292 | pub fn is_terminal(&self) -> bool { 293 | match self.node_type { 294 | NodeTypeEmptyNew | NodeTypeLeafNew => true, 295 | NodeTypeBranch0 | NodeTypeBranch1 | NodeTypeBranch2 | NodeTypeBranch3 => false, 296 | NodeTypeEmpty | NodeTypeLeaf | NodeTypeParent => { 297 | panic!("encounter deprecated node types") 298 | } 299 | _ => panic!("encounter unknown node types {:?}", self.node_type), 300 | } 301 | } 302 | 303 | /// NodeHash computes the hash digest of the node by hashing the content in a 304 | /// specific way for each type of node. This key is used as the hash of the 305 | /// Merkle tree for each node. 306 | pub fn calc_node_hash(mut self) -> Result { 307 | let zero_temp = H::hash_zero(); 308 | if self.node_hash.is_none() { 309 | // Cache the key to avoid repeated hash computations. 310 | // NOTE: We are not using the type to calculate the hash! 
311 | match self.node_type { 312 | NodeTypeBranch0 | NodeTypeBranch1 | NodeTypeBranch2 | NodeTypeBranch3 => { 313 | // H(ChildL || ChildR) 314 | self.node_hash = Some(H::hash_elems_with_domain( 315 | self.node_type as u64, 316 | self.child_left.as_ref().unwrap_or(&zero_temp), 317 | self.child_right.as_ref().unwrap_or(&zero_temp), 318 | )?); 319 | } 320 | NodeTypeLeafNew => { 321 | let value_hash = Self::handling_elems_and_bytes32( 322 | self.compress_flags, 323 | &self.value_preimage, 324 | )?; 325 | self.node_hash = Some(H::hash_elems_with_domain( 326 | self.node_type as u64, 327 | &self.node_key, 328 | &value_hash, 329 | )?); 330 | self.value_hash = Some(value_hash); 331 | } 332 | NodeTypeEmptyNew => { 333 | // Zero 334 | self.node_hash = Some(H::hash_zero()); 335 | } 336 | NodeTypeEmpty | NodeTypeLeaf | NodeTypeParent => { 337 | panic!("encounter deprecated node types") 338 | } 339 | _ => return Err(ImplError::ErrInvalidField), 340 | } 341 | } 342 | Ok(self) 343 | } 344 | 345 | /// Set NodeHash, if node hash has been calculated, 346 | /// we would compare and complain for unmatch 347 | pub fn set_node_hash(&mut self, hash: H) { 348 | if let Some(existed_hash) = &self.node_hash { 349 | assert_eq!( 350 | existed_hash, &hash, 351 | "the set hash must be equal to calculated one" 352 | ); 353 | } else { 354 | self.node_hash.replace(hash); 355 | } 356 | } 357 | 358 | /// Return the nodehash, in case it is not calculated, we get None 359 | pub fn node_hash(&self) -> Option { 360 | self.node_hash.clone() 361 | } 362 | 363 | /// ValueHash computes the hash digest of the value stored in the leaf node. For 364 | /// other node types, it returns the zero hash. 
in case it is not calculated, 365 | /// we get None 366 | pub fn value_hash(&self) -> Option { 367 | if self.node_hash.is_some() { 368 | match self.node_type { 369 | NodeTypeLeafNew => self.value_hash.clone(), 370 | _ => Some(H::hash_zero()), 371 | } 372 | } else { 373 | None 374 | } 375 | } 376 | 377 | /// Data returns the wrapped data inside LeafNode and cast them into bytes 378 | /// for other node type it just return None 379 | pub fn data(&self) -> Option> { 380 | match self.node_type { 381 | NodeTypeLeafNew => { 382 | let bytes = self 383 | .value_preimage 384 | .as_slice() 385 | .iter() 386 | .flat_map(|bt| bt.as_slice()) 387 | .copied(); 388 | 389 | Some(bytes.collect::>()) 390 | } 391 | _ => None, 392 | } 393 | } 394 | 395 | // Value returns the encoded bytes of a node, include all information of it 396 | pub fn value(&self) -> Vec { 397 | let mut out_bytes = self.canonical_value(); 398 | let len = out_bytes.len(); 399 | if self.node_type == NodeTypeLeafNew { 400 | if let Some(key_preimage) = &self.key_preimage { 401 | out_bytes[len - 1] = key_preimage.len() as u8; 402 | out_bytes.extend(key_preimage) 403 | } 404 | } 405 | out_bytes 406 | } 407 | 408 | /// CanonicalValue returns the byte form of a node required to be persisted, and strip unnecessary fields 409 | /// from the encoding (current only KeyPreimage for Leaf node) to keep a minimum size for content being 410 | /// stored in backend storage 411 | pub fn canonical_value(&self) -> Vec { 412 | match self.node_type { 413 | NodeTypeBranch0 | NodeTypeBranch1 | NodeTypeBranch2 | NodeTypeBranch3 => { 414 | let mut b = vec![self.node_type as u8]; 415 | b.append(&mut self.child_left.as_ref().unwrap().to_bytes()); 416 | b.append(&mut self.child_right.as_ref().unwrap().to_bytes()); 417 | b 418 | } 419 | NodeTypeLeafNew => { 420 | let mut b = vec![self.node_type as u8]; 421 | b.append(&mut self.node_key.to_bytes()); 422 | let mark = (self.compress_flags << 8) + self.value_preimage.len() as u32; 423 | 
b.append(&mut u32::to_le_bytes(mark).to_vec());
                for i in 0..self.value_preimage.len() {
                    b.append(&mut self.value_preimage[i].to_vec());
                }
                b.push(0);
                b
            }
            NodeTypeEmptyNew => {
                vec![self.node_type as u8]
            }
            NodeTypeEmpty | NodeTypeLeaf | NodeTypeParent => {
                panic!("encounter deprecated node types")
            }
            _ => {
                vec![]
            }
        }
    }

    /// String outputs a string representation of a node (different for each type).
    ///
    /// * Branch nodes (`NodeTypeBranch0..=3`) print both child hashes; `(t)` tags the
    ///   child slot(s) singled out by the branch variant (presumably "terminal"
    ///   children — confirm against the node-type definitions).
    /// * `NodeTypeLeafNew` prints the node key, the number of value-preimage items and
    ///   the first item. A leaf whose `value_preimage` is empty prints `Items: 0` and
    ///   omits the `First:` part instead of panicking on an out-of-bounds index.
    /// * `NodeTypeEmptyNew` prints `Empty`, the deprecated node types print
    ///   `deprecated Node`, and any other type prints `Invalid Node`.
    // Deliberately an inherent method rather than a `Display` impl, hence the
    // `inherent_to_string` lint is allowed.
    #[allow(clippy::inherent_to_string)]
    pub fn to_string(&self) -> String {
        match self.node_type {
            // {Type || ChildL || ChildR}
            NodeTypeBranch0 => format!(
                "Parent L(t):{:?} R(t):{:?}",
                self.child_left, self.child_right
            ),
            NodeTypeBranch1 => {
                format!("Parent L(t):{:?} R:{:?}", self.child_left, self.child_right)
            }
            NodeTypeBranch2 => {
                format!("Parent L:{:?} R(t):{:?}", self.child_left, self.child_right)
            }
            NodeTypeBranch3 => format!("Parent L:{:?} R:{:?}", self.child_left, self.child_right),
            // {Type || Data...}
            NodeTypeLeafNew => match self.value_preimage.first() {
                Some(first) => format!(
                    "Leaf I:{:?} Items: {}, First:{:?}",
                    self.node_key,
                    self.value_preimage.len(),
                    first
                ),
                // `value_preimage` is a plain Vec and may be empty; indexing `[0]`
                // here would panic while merely formatting the node.
                None => format!("Leaf I:{:?} Items: 0", self.node_key),
            },
            // {}
            NodeTypeEmptyNew => "Empty".to_string(),
            NodeTypeEmpty | NodeTypeLeaf | NodeTypeParent => "deprecated Node".to_string(),
            _ => "Invalid Node".to_string(),
        }
    }
}

/// Unit tests for `Node` construction, hashing, (de)serialization and rendering.
#[cfg(test)]
mod tests {
    use crate::hash::HashImpl as Hash;
    use crate::raw::ImplError;
    use crate::types::{Hashable, Node};
    use crate::types::{NodeType::*, HASH_BYTE_LEN};

    /// Constructors set the expected type/fields and hashing succeeds.
    #[test]
    fn test_new_node() {
        //NodeTypeEmptyNew
        let node1 = Node::<Hash>::new_empty_node().calc_node_hash().unwrap();
        assert_eq!(node1.node_type, NodeTypeEmptyNew);

        // The empty node hashes to the zero hash for both node and value hash.
        let h = node1.node_hash().unwrap();
        assert_eq!(h, Hash::hash_zero());
        let h = node1.value_hash().unwrap();
        assert_eq!(h, Hash::hash_zero());

        //NodeTypeLeafNew
        let k = Hash::from_bytes(&[47u8; 32]).unwrap();
        let vp = vec![[48u8; 32]];
        let node2 = Node::<Hash>::new_leaf_node(k, 1, vp.clone())
            .calc_node_hash()
            .unwrap();
        assert_eq!(node2.node_type, NodeTypeLeafNew);
        assert_eq!(node2.compress_flags, 1u32);
        assert_eq!(node2.value_preimage, vp);

        let h = node2.node_hash();
        assert!(h.is_some());
        let h = node2.value_hash();
        assert!(h.is_some());

        //New Parent Node
        let k = Hash::from_bytes(&[47u8; 32]).unwrap();
        let node3 = Node::<Hash>::new_parent_node(NodeTypeBranch3, k.clone(), k.clone())
            .calc_node_hash()
            .unwrap();
        assert_eq!(node3.node_type, NodeTypeBranch3);
        assert_eq!(node3.child_left.as_ref().unwrap(), &k);
        assert_eq!(node3.child_right.as_ref().unwrap(), &k);

        //New Parent Node with empty child
        let k = Hash::from_bytes(&[47u8; 32]).unwrap();
        let r = Hash::hash_zero();
        let node4 = Node::<Hash>::new_parent_node(NodeTypeBranch2, k.clone(), r.clone())
            .calc_node_hash()
            .unwrap();
        assert_eq!(node4.node_type, NodeTypeBranch2);
        assert_eq!(node4.child_left.as_ref().unwrap(), &k);
        assert_eq!(node4.child_right.as_ref().unwrap(), &r);

        let h = node4.node_hash();
        assert!(h.is_some());
        let h = node4.value_hash();
        assert!(h.is_some());
    }

    /// Round-trips nodes through `value()` / `new_node_from_bytes` and checks that
    /// truncated or unknown encodings are rejected with the expected error.
    #[test]
    fn test_new_node_from_bytes() {
        //Parent Node
        let k1 = Hash::from_bytes(&[47u8; 32]).unwrap();
        let k2 = Hash::from_bytes(&[48u8; 32]).unwrap();
        let node1 = Node::<Hash>::new_parent_node(NodeTypeBranch0, k1.clone(), k2.clone())
            .calc_node_hash()
            .unwrap();
        assert_eq!(node1.node_type, NodeTypeBranch0);
        assert_eq!(node1.child_left.as_ref().unwrap(), &k1);
        assert_eq!(node1.child_right.as_ref().unwrap(), &k2);

        let h = node1.node_hash();
        assert!(h.is_some());
        let h = node1.value_hash();
        assert!(h.is_some());

        //Leaf Node
        let k = Hash::from_bytes(&[47u8; 32]).unwrap();
        let vp = vec![[1u8; 32]];
        let mut node2 = Node::<Hash>::new_leaf_node(k, 1, vp.clone())
            .calc_node_hash()
            .unwrap();
        let h = node2.node_hash();
        assert!(h.is_some());
        let h = node2.value_hash();
        assert!(h.is_some());

        // With a key preimage set, `value()` must carry it through the round trip.
        node2.key_preimage = Some([48u8; 32]);
        let b = node2.value();
        let new_node = Node::<Hash>::new_node_from_bytes(&b);
        assert!(new_node.is_ok());
        let new_node = new_node.unwrap();
        assert_eq!(node2.node_type, new_node.node_type);
        assert_eq!(node2.node_key, new_node.node_key);
        assert_eq!(node2.value_preimage, new_node.value_preimage);
        assert_eq!(node2.key_preimage, new_node.key_preimage);

        //Empty Node
        let b = Node::<Hash>::new_empty_node().value();
        let new_node = Node::<Hash>::new_node_from_bytes(&b);
        assert!(new_node.is_ok());

        let node3 = new_node.unwrap().calc_node_hash().unwrap();
        let h = node3.node_hash().unwrap();
        assert_eq!(h, Hash::hash_zero());
        let h = node3.value_hash().unwrap();
        assert_eq!(h, Hash::hash_zero());

        //Bad Size
        let b = vec![];
        let node = Node::<Hash>::new_node_from_bytes(&b);
        assert!(node.is_err());
        assert_eq!(node.err().unwrap(), ImplError::ErrNodeBytesBadSize);

        let b = vec![0u8, 1u8, 2u8];
        let node = Node::<Hash>::new_node_from_bytes(&b);
        assert!(node.is_err());
        assert_eq!(node.err().unwrap(), ImplError::ErrNodeBytesBadSize);

        let b = vec![NodeTypeLeaf as u8; HASH_BYTE_LEN + 3];
        let node = Node::<Hash>::new_node_from_bytes(&b);
        assert!(node.is_err());
        assert_eq!(node.err().unwrap(), ImplError::ErrNodeBytesBadSize);

        // Leaf encoding with its last 32 bytes chopped off must be rejected.
        let k = Hash::from_bytes(&[47u8; 32]).unwrap();
        let vp = vec![[1u8; 32]];
        let valid_node = Node::<Hash>::new_leaf_node(k, 1, vp.clone());
        let b = valid_node.value();
        let node = Node::<Hash>::new_node_from_bytes(&b[0..b.len() - 32]);
        assert!(node.is_err());
        assert_eq!(node.err().unwrap(), ImplError::ErrNodeBytesBadSize);

        // Leaf encoding (with key preimage) missing only its final byte must be rejected.
        let k = Hash::from_bytes(&[47u8; 32]).unwrap();
        let vp = vec![[1u8; 32]];
        let mut valid_node = Node::<Hash>::new_leaf_node(k, 1, vp.clone());
        valid_node.key_preimage = Some([48u8; 32]);
        let b = valid_node.value();
        let node = Node::<Hash>::new_node_from_bytes(&b[0..b.len() - 1]);
        assert!(node.is_err());
        assert_eq!(node.err().unwrap(), ImplError::ErrNodeBytesBadSize);

        //Invalid type
        let b = vec![255u8];
        let node = Node::<Hash>::new_node_from_bytes(&b);
        assert!(node.is_err());
        assert_eq!(node.err().unwrap(), ImplError::ErrInvalidNodeFound);
    }

    /// Pins the exact byte layout produced by `canonical_value()`, `value()` and `data()`.
    #[test]
    fn test_node_value_and_data() {
        let a1 = [47u8; 32];
        let a2 = [48u8; 32];
        let a3 = [49u8; 32];
        // Little-endian u32 "mark" for a leaf with compress flags 1 and one value
        // preimage (presumably count in the low byte, flags above it — confirm).
        let mark = [1u8, 1u8, 0u8, 0u8];
        let k = Hash::from_bytes(&a1).unwrap();
        let vp = vec![a2];

        //Leaf Node
        // Expected canonical layout: type byte (4) || key || mark || preimages || 0.
        let mut node = Node::<Hash>::new_leaf_node(k.clone(), 1, vp.clone());
        let mut v = vec![4u8];
        v.append(&mut a1.to_vec());
        v.append(&mut mark.to_vec());
        v.append(&mut a2.to_vec());
        v.push(0);
        assert_eq!(node.canonical_value(), v);

        // `value()` replaces the trailing 0 with the key-preimage length (32) and bytes.
        v.remove(v.len() - 1);
        node.key_preimage = Some([49u8; 32]);
        v.push(32u8);
        v.append(&mut a3.to_vec());
        assert_eq!(node.value(), v);

        assert_eq!(node.data().unwrap(), a2.to_vec());

        //Parent Node
        // Expected canonical layout: type byte (9) || child_left || child_right.
        let node = Node::<Hash>::new_parent_node(NodeTypeBranch3, k.clone(), k.clone());
        v = vec![9u8];
        v.append(&mut a1.to_vec());
        v.append(&mut a1.to_vec());
        assert_eq!(node.canonical_value(), v);

        //empty Node
        // Expected canonical layout: just the type byte (5).
        let node = Node::<Hash>::new_empty_node();
        v = vec![5u8];
        assert_eq!(node.canonical_value(), v);
    }

    /// `to_string` renders each node kind and never panics, even for a leaf
    /// without value preimages.
    #[test]
    fn test_node_to_string() {
        let k = Hash::from_bytes(&[47u8; 32]).unwrap();

        let mut leaf = Node::<Hash>::new_leaf_node(k.clone(), 1, vec![[48u8; 32]]);
        let s = leaf.to_string();
        assert!(s.starts_with("Leaf I:"));
        assert!(s.contains("Items: 1, First:"));

        // Regression: an empty preimage list used to panic on `value_preimage[0]`.
        leaf.value_preimage.clear();
        let s = leaf.to_string();
        assert!(s.starts_with("Leaf I:"));
        assert!(s.ends_with("Items: 0"));

        assert_eq!(Node::<Hash>::new_empty_node().to_string(), "Empty");

        let parent = Node::<Hash>::new_parent_node(NodeTypeBranch3, k.clone(), k);
        assert!(parent.to_string().starts_with("Parent L:"));
    }
}
--------------------------------------------------------------------------------