diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index fc53950d42..f1def6a32f 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -36,6 +36,7 @@ type Reader interface { L1HandlerTxnHash(msgHash *common.Hash) (l1HandlerTxnHash *felt.Felt, err error) HeadState() (core.StateReader, StateCloser, error) + HeadTrie() (core.TrieReader, StateCloser, error) StateAtBlockHash(blockHash *felt.Felt) (core.StateReader, StateCloser, error) StateAtBlockNumber(blockNumber uint64) (core.StateReader, StateCloser, error) @@ -768,6 +769,17 @@ func (b *Blockchain) HeadState() (core.StateReader, StateCloser, error) { return core.NewState(txn), txn.Discard, nil } +func (b *Blockchain) HeadTrie() (core.TrieReader, StateCloser, error) { + // Note: I'm not sure I should open a new db txn since the TrieReader is a State + // so the same instance of the state we create in HeadState will do job. + txn, err := b.database.NewTransaction(false) + if err != nil { + return nil, nil, err + } + + return core.NewState(txn), txn.Discard, nil +} + // StateAtBlockNumber returns a StateReader that provides a stable view to the state at the given block number func (b *Blockchain) StateAtBlockNumber(blockNumber uint64) (core.StateReader, StateCloser, error) { b.listener.OnRead("StateAtBlockNumber") diff --git a/core/state.go b/core/state.go index 378ba65bec..bda6e2ffa8 100644 --- a/core/state.go +++ b/core/state.go @@ -44,6 +44,17 @@ type StateReader interface { Class(classHash *felt.Felt) (*DeclaredClass, error) } +// TrieReader used for storage proofs, can only be supported by current state implementation (for now, we plan to add db snapshots) +var _ TrieReader = (*State)(nil) + +//go:generate mockgen -destination=../mocks/mock_trie.go -package=mocks github.com/NethermindEth/juno/core TrieReader +type TrieReader interface { + ClassTrie() (*trie.Trie, func() error, error) + StorageTrie() (*trie.Trie, func() error, error) + StorageTrieForAddr(addr *felt.Felt) (*trie.Trie, error) + StateAndClassRoot() (*felt.Felt, *felt.Felt, error) +} + type State struct { *history txn db.Transaction @@ -129,6 +140,18 @@ func (s *State) storage() (*trie.Trie, func() error, error) { return s.globalTrie(db.StateTrie, trie.NewTriePedersen) } +func (s *State) StorageTrie() (*trie.Trie, func() error, error) { + return s.storage() +} + +func (s *State) ClassTrie() (*trie.Trie, func() error, error) { + return s.classesTrie() +} + +func (s *State) StorageTrieForAddr(addr *felt.Felt) (*trie.Trie, error) { + return storage(addr, s.txn) +} + func (s *State) classesTrie() (*trie.Trie, func() error, error) { return s.globalTrie(db.ClassesTrie, trie.NewTriePoseidon) } @@ -547,7 +570,7 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { err = s.performStateDeletions(blockNumber, update.StateDiff) if err != nil { - return fmt.Errorf("error performing state deletions: %v", err) + return fmt.Errorf("build reverse diff: %v", err) } stateTrie, storageCloser, err := s.storage() @@ -581,6 +604,7 @@ func (s *State) purgeNoClassContracts() error { // As noClassContracts are not in StateDiff.DeployedContracts we can only purge them if their storage no longer exists. // Updating contracts with reverse diff will eventually lead to the deletion of noClassContract's storage key from db. Thus, // we can use the lack of key's existence as reason for purging noClassContracts. + for addr := range noClassContracts { noClassC, err := NewContractUpdater(&addr, s.txn) if err != nil { @@ -743,3 +767,35 @@ func (s *State) performStateDeletions(blockNumber uint64, diff *StateDiff) error return nil } + +func (s *State) StateAndClassRoot() (*felt.Felt, *felt.Felt, error) { + var storageRoot, classesRoot *felt.Felt + + sStorage, closer, err := s.storage() + if err != nil { + return nil, nil, err + } + + if storageRoot, err = sStorage.Root(); err != nil { + return nil, nil, err + } + + if err = closer(); err != nil { + return nil, nil, err + } + + classes, closer, err := s.classesTrie() + if err != nil { + return nil, nil, err + } + + if classesRoot, err = classes.Root(); err != nil { + return nil, nil, err + } + + if err = closer(); err != nil { + return nil, nil, err + } + + return storageRoot, classesRoot, nil +} diff --git a/core/trie/key.go b/core/trie/key.go index 0d0ca7aa88..faae0d2b49 100644 --- a/core/trie/key.go +++ b/core/trie/key.go @@ -1,10 +1,10 @@ package trie import ( - "bytes" "encoding/hex" "errors" "fmt" + "io" "math/big" "github.com/NethermindEth/juno/core/felt" @@ -39,8 +39,8 @@ func (k *Key) unusedBytes() []byte { return k.bitset[:len(k.bitset)-int(k.bytesNeeded())] } -func (k *Key) WriteTo(buf *bytes.Buffer) (int64, error) { - if err := buf.WriteByte(k.len); err != nil { +func (k *Key) WriteTo(buf io.Writer) (int64, error) { + if _, err := buf.Write([]byte{k.len}); err != nil { return 0, err } diff --git a/core/trie/key_test.go b/core/trie/key_test.go index 3867678e6e..d0cf3a9c35 100644 --- a/core/trie/key_test.go +++ b/core/trie/key_test.go @@ -2,6 +2,7 @@ package trie_test import ( "bytes" + "errors" "testing" "github.com/NethermindEth/juno/core/felt" @@ -227,3 +228,39 @@ func TestMostSignificantBits(t *testing.T) { }) } } + +func TestKeyErrorHandling(t *testing.T) { + t.Run("passed too long key bytes panics", func(t *testing.T) { + defer func() { + r := recover() + require.NotNil(t, r) + require.Contains(t, r.(string), "bytes does not fit in bitset") + }() + tooLongKeyB := make([]byte, 33) + trie.NewKey(8, tooLongKeyB) + }) + t.Run("MostSignificantBits n greater than key length", func(t *testing.T) { + key := trie.NewKey(8, []byte{0x01}) + _, err := key.MostSignificantBits(9) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot get more bits than the key length") + }) + t.Run("MostSignificantBits equals key length return copy of key", func(t *testing.T) { + key := trie.NewKey(8, []byte{0x01}) + kCopy, err := key.MostSignificantBits(8) + require.NoError(t, err) + require.Equal(t, key, *kCopy) + }) + t.Run("WriteTo returns error", func(t *testing.T) { + key := trie.NewKey(8, []byte{0x01}) + wrote, err := key.WriteTo(&errorBuffer{}) + require.Error(t, err) + require.Equal(t, int64(0), wrote) + }) +} + +type errorBuffer struct{} + +func (*errorBuffer) Write([]byte) (int, error) { + return 0, errors.New("expected to fail") +} diff --git a/core/trie/node.go b/core/trie/node.go index 2ef176f92a..c738fe4272 100644 --- a/core/trie/node.go +++ b/core/trie/node.go @@ -1,9 +1,9 @@ package trie import ( - "bytes" "errors" "fmt" + "io" "github.com/NethermindEth/juno/core/felt" ) @@ -19,7 +19,7 @@ type Node struct { } // Hash calculates the hash of a [Node] -func (n *Node) Hash(path *Key, hashFunc hashFunc) *felt.Felt { +func (n *Node) Hash(path *Key, hashFunc HashFunc) *felt.Felt { if path.Len() == 0 { // we have to deference the Value, since the Node can released back // to the NodePool and be reused anytime @@ -34,12 +34,12 @@ func (n *Node) Hash(path *Key, hashFunc hashFunc) *felt.Felt { } // Hash calculates the hash of a [Node] -func (n *Node) HashFromParent(parentKey, nodeKey *Key, hashFunc hashFunc) *felt.Felt { +func (n *Node) HashFromParent(parentKey, nodeKey *Key, hashFunc HashFunc) *felt.Felt { path := path(nodeKey, parentKey) return n.Hash(&path, hashFunc) } -func (n *Node) WriteTo(buf *bytes.Buffer) (int64, error) { +func (n *Node) WriteTo(buf io.Writer) (int64, error) { if n.Value == nil { return 0, errors.New("cannot marshal node with nil value") } diff --git a/core/trie/node_test.go b/core/trie/node_test.go index ccb52b3eac..3ac71a9241 100644 --- a/core/trie/node_test.go +++ b/core/trie/node_test.go @@ -1,7 +1,9 @@ package trie_test import ( + "bytes" "encoding/hex" + "errors" "testing" "github.com/NethermindEth/juno/core/crypto" @@ -26,3 +28,33 @@ func TestNodeHash(t *testing.T) { assert.Equal(t, expected, node.Hash(&path, crypto.Pedersen), "TestTrieNode_Hash failed") } + +func TestNodeErrorHandling(t *testing.T) { + t.Run("WriteTo node value is nil", func(t *testing.T) { + node := trie.Node{} + var buffer bytes.Buffer + _, err := node.WriteTo(&buffer) + require.Error(t, err) + }) + t.Run("WriteTo returns error", func(t *testing.T) { + node := trie.Node{ + Value: new(felt.Felt).SetUint64(42), + Left: &trie.Key{}, + Right: &trie.Key{}, + } + + wrote, err := node.WriteTo(&errorBuffer{}) + require.Error(t, err) + require.Equal(t, int64(0), wrote) + }) + t.Run("UnmarshalBinary returns error", func(t *testing.T) { + node := trie.Node{} + + err := node.UnmarshalBinary([]byte{42}) + require.Equal(t, errors.New("size of input data is less than felt size"), err) + + bs := new(felt.Felt).Bytes() + err = node.UnmarshalBinary(append(bs[:], 0, 0, 42)) + require.Equal(t, errors.New("the node does not contain both left and right hash"), err) + }) +} diff --git a/core/trie/proof.go b/core/trie/proof.go index bc4b66d0d9..f0c4e0130f 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -16,7 +16,7 @@ func NewProofNodeSet() *ProofNodeSet { } type ProofNode interface { - Hash(hash hashFunc) *felt.Felt + Hash(hash HashFunc) *felt.Felt Len() uint8 String() string } @@ -26,7 +26,7 @@ type Binary struct { RightHash *felt.Felt } -func (b *Binary) Hash(hash hashFunc) *felt.Felt { +func (b *Binary) Hash(hash HashFunc) *felt.Felt { return hash(b.LeftHash, b.RightHash) } @@ -43,7 +43,7 @@ type Edge struct { Path *Key // path from parent to child } -func (e *Edge) Hash(hash hashFunc) *felt.Felt { +func (e *Edge) Hash(hash HashFunc) *felt.Felt { length := make([]byte, len(e.Path.bitset)) length[len(e.Path.bitset)-1] = e.Path.len pathFelt := e.Path.Felt() @@ -137,7 +137,7 @@ func (t *Trie) GetRangeProof(leftKey, rightKey *felt.Felt, proofSet *ProofNodeSe // - Any node's computed hash doesn't match its expected hash // - The path bits don't match the key bits // - The proof ends before processing all key bits -func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash hashFunc) (*felt.Felt, error) { +func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash HashFunc) (*felt.Felt, error) { key := FeltToKey(globalTrieHeight, keyFelt) expectedHash := root keyLen := key.Len() diff --git a/core/trie/proofset.go b/core/trie/proofset.go new file mode 100644 index 0000000000..cb20195b5b --- /dev/null +++ b/core/trie/proofset.go @@ -0,0 +1,49 @@ +package trie + +import ( + "sync" + + "github.com/NethermindEth/juno/core/felt" +) + +// ProofSet represents a set of trie nodes used in a Merkle proof verification process. +// Rather than relying on only either map of list, ProofSet provides both for the following reasons: +// - map allows for unique node insertion +// - list allows for ordered iteration over the proof nodes +// It also supports concurrent read and write operations. +type ProofSet struct { + nodeSet map[felt.Felt]ProofNode + nodeList []ProofNode + size int + lock sync.RWMutex +} + +func NewProofSet() *ProofSet { + return &ProofSet{ + nodeSet: make(map[felt.Felt]ProofNode), + } +} + +func (ps *ProofSet) Put(key felt.Felt, node ProofNode) { + ps.lock.Lock() + defer ps.lock.Unlock() + + ps.nodeSet[key] = node + ps.nodeList = append(ps.nodeList, node) + ps.size++ +} + +func (ps *ProofSet) Get(key felt.Felt) (ProofNode, bool) { + ps.lock.RLock() + defer ps.lock.RUnlock() + + node, ok := ps.nodeSet[key] + return node, ok +} + +func (ps *ProofSet) Size() int { + ps.lock.RLock() + defer ps.lock.RUnlock() + + return ps.size +} diff --git a/core/trie/trie.go b/core/trie/trie.go index c21168c505..1e4e760754 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -16,7 +16,7 @@ import ( const globalTrieHeight = 251 // TODO(weiihann): this is declared in core also, should be moved to a common place -type hashFunc func(*felt.Felt, *felt.Felt) *felt.Felt +type HashFunc func(*felt.Felt, *felt.Felt) *felt.Felt // Trie is a dense Merkle Patricia Trie (i.e., all internal nodes have two children). // @@ -40,7 +40,7 @@ type Trie struct { rootKey *Key maxKey *felt.Felt storage *Storage - hash hashFunc + hash HashFunc dirtyNodes []*Key rootKeyIsDirty bool @@ -56,7 +56,7 @@ func NewTriePoseidon(storage *Storage, height uint8) (*Trie, error) { return newTrie(storage, height, crypto.Poseidon) } -func newTrie(storage *Storage, height uint8, hash hashFunc) (*Trie, error) { +func newTrie(storage *Storage, height uint8, hash HashFunc) (*Trie, error) { if height > felt.Bits { return nil, fmt.Errorf("max trie height is %d, got: %d", felt.Bits, height) } @@ -96,12 +96,17 @@ func RunOnTempTriePoseidon(height uint8, do func(*Trie) error) error { return do(trie) } -// feltToKey Converts a key, given in felt, to a trie.Key which when followed on a [Trie], +// FeltToKey Converts a key, given in felt, to a trie.Key which when followed on a [Trie], // leads to the corresponding [Node] func (t *Trie) FeltToKey(k *felt.Felt) Key { return FeltToKey(t.height, k) } +// HashFunc returns the hash function used by the trie +func (t *Trie) HashFunc() HashFunc { + return t.hash +} + // path returns the path as mentioned in the [specification] for commitment calculations. // path is suffix of key that diverges from parentKey. For example, // for a key 0b1011 and parentKey 0b10, this function would return the path object of 0b0. diff --git a/mocks/mock_blockchain.go b/mocks/mock_blockchain.go index 05ec6b7f9c..4339c80adc 100644 --- a/mocks/mock_blockchain.go +++ b/mocks/mock_blockchain.go @@ -164,6 +164,22 @@ func (mr *MockReaderMockRecorder) HeadState() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadState", reflect.TypeOf((*MockReader)(nil).HeadState)) } +// HeadTrie mocks base method. +func (m *MockReader) HeadTrie() (core.TrieReader, func() error, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HeadTrie") + ret0, _ := ret[0].(core.TrieReader) + ret1, _ := ret[1].(func() error) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// HeadTrie indicates an expected call of HeadTrie. +func (mr *MockReaderMockRecorder) HeadTrie() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadTrie", reflect.TypeOf((*MockReader)(nil).HeadTrie)) +} + // HeadsHeader mocks base method. func (m *MockReader) HeadsHeader() (*core.Header, error) { m.ctrl.T.Helper() diff --git a/mocks/mock_trie.go b/mocks/mock_trie.go new file mode 100644 index 0000000000..570a055c4e --- /dev/null +++ b/mocks/mock_trie.go @@ -0,0 +1,104 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/NethermindEth/juno/core (interfaces: TrieReader) +// +// Generated by this command: +// +// mockgen -destination=../mocks/mock_trie.go -package=mocks github.com/NethermindEth/juno/core TrieReader +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + felt "github.com/NethermindEth/juno/core/felt" + trie "github.com/NethermindEth/juno/core/trie" + gomock "go.uber.org/mock/gomock" +) + +// MockTrieReader is a mock of TrieReader interface. +type MockTrieReader struct { + ctrl *gomock.Controller + recorder *MockTrieReaderMockRecorder +} + +// MockTrieReaderMockRecorder is the mock recorder for MockTrieReader. +type MockTrieReaderMockRecorder struct { + mock *MockTrieReader +} + +// NewMockTrieReader creates a new mock instance. +func NewMockTrieReader(ctrl *gomock.Controller) *MockTrieReader { + mock := &MockTrieReader{ctrl: ctrl} + mock.recorder = &MockTrieReaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTrieReader) EXPECT() *MockTrieReaderMockRecorder { + return m.recorder +} + +// ClassTrie mocks base method. +func (m *MockTrieReader) ClassTrie() (*trie.Trie, func() error, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClassTrie") + ret0, _ := ret[0].(*trie.Trie) + ret1, _ := ret[1].(func() error) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ClassTrie indicates an expected call of ClassTrie. +func (mr *MockTrieReaderMockRecorder) ClassTrie() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClassTrie", reflect.TypeOf((*MockTrieReader)(nil).ClassTrie)) +} + +// StateAndClassRoot mocks base method. +func (m *MockTrieReader) StateAndClassRoot() (*felt.Felt, *felt.Felt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StateAndClassRoot") + ret0, _ := ret[0].(*felt.Felt) + ret1, _ := ret[1].(*felt.Felt) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// StateAndClassRoot indicates an expected call of StateAndClassRoot. +func (mr *MockTrieReaderMockRecorder) StateAndClassRoot() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StateAndClassRoot", reflect.TypeOf((*MockTrieReader)(nil).StateAndClassRoot)) +} + +// StorageTrie mocks base method. +func (m *MockTrieReader) StorageTrie() (*trie.Trie, func() error, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StorageTrie") + ret0, _ := ret[0].(*trie.Trie) + ret1, _ := ret[1].(func() error) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// StorageTrie indicates an expected call of StorageTrie. +func (mr *MockTrieReaderMockRecorder) StorageTrie() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StorageTrie", reflect.TypeOf((*MockTrieReader)(nil).StorageTrie)) +} + +// StorageTrieForAddr mocks base method. +func (m *MockTrieReader) StorageTrieForAddr(arg0 *felt.Felt) (*trie.Trie, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StorageTrieForAddr", arg0) + ret0, _ := ret[0].(*trie.Trie) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// StorageTrieForAddr indicates an expected call of StorageTrieForAddr. +func (mr *MockTrieReaderMockRecorder) StorageTrieForAddr(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StorageTrieForAddr", reflect.TypeOf((*MockTrieReader)(nil).StorageTrieForAddr), arg0) +} diff --git a/rpc/contract.go b/rpc/contract.go index a33f84399e..ba7de93029 100644 --- a/rpc/contract.go +++ b/rpc/contract.go @@ -1,10 +1,7 @@ package rpc import ( - "errors" - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" ) @@ -30,33 +27,3 @@ func (h *Handler) Nonce(id BlockID, address felt.Felt) (*felt.Felt, *jsonrpc.Err return nonce, nil } - -// StorageAt gets the value of the storage at the given address and key. -// -// It follows the specification defined here: -// https://github.com/starkware-libs/starknet-specs/blob/a789ccc3432c57777beceaa53a34a7ae2f25fda0/api/starknet_api_openrpc.json#L110 -func (h *Handler) StorageAt(address, key felt.Felt, id BlockID) (*felt.Felt, *jsonrpc.Error) { - stateReader, stateCloser, rpcErr := h.stateByBlockID(&id) - if rpcErr != nil { - return nil, rpcErr - } - defer h.callAndLogErr(stateCloser, "Error closing state reader in getStorageAt") - - // This checks if the contract exists because if a key doesn't exist in contract storage, - // the returned value is always zero and error is nil. - _, err := stateReader.ContractClassHash(&address) - if err != nil { - if errors.Is(err, db.ErrKeyNotFound) { - return nil, ErrContractNotFound - } - h.log.Errorw("Failed to get contract nonce", "err", err) - return nil, ErrInternal - } - - value, err := stateReader.ContractStorage(&address, &key) - if err != nil { - return nil, ErrContractNotFound - } - - return value, nil -} diff --git a/rpc/contract_test.go b/rpc/contract_test.go index c9abb5214e..522ab0bb62 100644 --- a/rpc/contract_test.go +++ b/rpc/contract_test.go @@ -86,89 +86,3 @@ func TestNonce(t *testing.T) { assert.Equal(t, expectedNonce, nonce) }) } - -func TestStorageAt(t *testing.T) { - mockCtrl := gomock.NewController(t) - t.Cleanup(mockCtrl.Finish) - - mockReader := mocks.NewMockReader(mockCtrl) - log := utils.NewNopZapLogger() - handler := rpc.New(mockReader, nil, nil, "", log) - - t.Run("empty blockchain", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) - }) - - t.Run("non-existent block hash", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(nil, nil, db.ErrKeyNotFound) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Hash: &felt.Zero}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) - }) - - t.Run("non-existent block number", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(nil, nil, db.ErrKeyNotFound) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Number: 0}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) - }) - - mockState := mocks.NewMockStateHistoryReader(mockCtrl) - - t.Run("non-existent contract", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, db.ErrKeyNotFound) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrContractNotFound, rpcErr) - }) - - t.Run("non-existent key", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(nil, db.ErrKeyNotFound) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrContractNotFound, rpcErr) - }) - - expectedStorage := new(felt.Felt).SetUint64(1) - - t.Run("blockID - latest", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) - require.Nil(t, rpcErr) - assert.Equal(t, expectedStorage, storage) - }) - - t.Run("blockID - hash", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Hash: &felt.Zero}) - require.Nil(t, rpcErr) - assert.Equal(t, expectedStorage, storage) - }) - - t.Run("blockID - number", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Number: 0}) - require.Nil(t, rpcErr) - assert.Equal(t, expectedStorage, storage) - }) -} diff --git a/rpc/handlers.go b/rpc/handlers.go index 1cf96b0c21..57d88f76e3 100644 --- a/rpc/handlers.go +++ b/rpc/handlers.go @@ -71,6 +71,11 @@ var ( // These errors can be only be returned by Juno-specific methods. ErrSubscriptionNotFound = &jsonrpc.Error{Code: 100, Message: "Subscription not found"} + + ErrStorageProofNotSupported = &jsonrpc.Error{ + Code: 42, + Message: "the node doesn't support storage proofs for blocks that are too far in the past. Use 'latest' as block id", + } ) const ( @@ -260,6 +265,13 @@ func (h *Handler) Methods() ([]jsonrpc.Method, string) { //nolint: funlen Params: []jsonrpc.Parameter{{Name: "contract_address"}, {Name: "key"}, {Name: "block_id"}}, Handler: h.StorageAt, }, + { + Name: "starknet_getStorageProof", + Params: []jsonrpc.Parameter{ + {Name: "block_id"}, {Name: "classes", Optional: true}, {Name: "contracts", Optional: true}, {Name: "storage_keys", Optional: true}, + }, + Handler: h.StorageProof, + }, { Name: "starknet_getClassHashAt", Params: []jsonrpc.Parameter{{Name: "block_id"}, {Name: "contract_address"}}, diff --git a/rpc/storage.go b/rpc/storage.go new file mode 100644 index 0000000000..d5cd31a799 --- /dev/null +++ b/rpc/storage.go @@ -0,0 +1,320 @@ +package rpc + +import ( + "errors" + + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/jsonrpc" + "github.com/NethermindEth/juno/utils" +) + +/**************************************************** + Storage Handlers +*****************************************************/ + +// StorageAt gets the value of the storage at the given address and key. +// +// It follows the specification defined here: +// https://github.com/starkware-libs/starknet-specs/blob/a789ccc3432c57777beceaa53a34a7ae2f25fda0/api/starknet_api_openrpc.json#L110 +func (h *Handler) StorageAt(address, key felt.Felt, id BlockID) (*felt.Felt, *jsonrpc.Error) { + stateReader, stateCloser, rpcErr := h.stateByBlockID(&id) + if rpcErr != nil { + return nil, rpcErr + } + defer h.callAndLogErr(stateCloser, "Error closing state reader in getStorageAt") + + // This checks if the contract exists because if a key doesn't exist in contract storage, + // the returned value is always zero and error is nil. + _, err := stateReader.ContractClassHash(&address) + if err != nil { + if errors.Is(err, db.ErrKeyNotFound) { + return nil, ErrContractNotFound + } + h.log.Errorw("Failed to get contract nonce", "err", err) + return nil, ErrInternal + } + + value, err := stateReader.ContractStorage(&address, &key) + if err != nil { + return nil, ErrContractNotFound + } + + return value, nil +} + +// StorageProof returns the merkle paths in one of the state tries: global state, classes, individual contract +// +// It follows the specification defined here: +// https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L910 +func (h *Handler) StorageProof( + id BlockID, + classes, contracts []felt.Felt, + storageKeys []StorageKeys, +) (*StorageProofResult, *jsonrpc.Error) { + if !id.Latest { + return nil, ErrStorageProofNotSupported + } + + stateReader, stateCloser, err := h.bcReader.HeadState() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + defer h.callAndLogErr(stateCloser, "Error closing state reader in getStorageProof") + + head, err := h.bcReader.Head() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + trieReader, stateCloser2, err := h.bcReader.HeadTrie() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + defer h.callAndLogErr(stateCloser2, "Error closing trie reader in getStorageProof") + + storageRoot, classRoot, err := trieReader.StateAndClassRoot() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + result := &StorageProofResult{ + GlobalRoots: &GlobalRoots{ + ContractsTreeRoot: storageRoot, + ClassesTreeRoot: classRoot, + BlockHash: head.Hash, + }, + } + + result.ClassesProof, err = getClassesProof(trieReader, classes) + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + result.ContractsProof, err = getContractsProof(stateReader, trieReader, contracts) + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + result.ContractsStorageProofs, err = getContractsStorageProofs(trieReader, storageKeys) + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + return result, nil +} + +// StorageKeys represents an item in `contracts_storage_keys. parameter +// https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L938 +type StorageKeys struct { + Contract felt.Felt `json:"contract_address"` + Keys []felt.Felt `json:"storage_keys"` +} + +// MerkleNode represents a proof node in a trie +// https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L3632 +// Implemented by MerkleBinaryNode, MerkleEdgeNode +type MerkleNode interface { + AsProofNode() trie.ProofNode +} + +// https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L3644 +type MerkleBinaryNode struct { + Left *felt.Felt `json:"left"` + Right *felt.Felt `json:"right"` +} + +func (mbn *MerkleBinaryNode) AsProofNode() trie.ProofNode { + return &trie.Binary{ + LeftHash: mbn.Left, + RightHash: mbn.Right, + } +} + +// https://github.com/starkware-libs/starknet-specs/blob/8cf463b79ba1dd876f67c7f637e5ea48beb07b5b/api/starknet_api_openrpc.json#L3720 +type MerkleEdgeNode struct { + Path string `json:"path"` + Length int `json:"length"` + Child *felt.Felt `json:"child"` +} + +func (men *MerkleEdgeNode) AsProofNode() trie.ProofNode { + f, _ := new(felt.Felt).SetString(men.Path) + pbs := f.Bytes() + path := trie.NewKey(uint8(men.Length), pbs[:]) + + return &trie.Edge{ + Path: &path, + Child: men.Child, + } +} + +// HashToNode represents an item in `NODE_HASH_TO_NODE_MAPPING` specified here +// https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L3667 +type HashToNode struct { + Hash *felt.Felt `json:"node_hash"` + Node MerkleNode `json:"node"` +} + +// https://github.com/starkware-libs/starknet-specs/blob/8cf463b79ba1dd876f67c7f637e5ea48beb07b5b/api/starknet_api_openrpc.json#L986 +type LeafData struct { + Nonce *felt.Felt `json:"nonce"` + ClassHash *felt.Felt `json:"class_hash"` +} + +// https://github.com/starkware-libs/starknet-specs/blob/8cf463b79ba1dd876f67c7f637e5ea48beb07b5b/api/starknet_api_openrpc.json#L979 +type ContractProof struct { + Nodes []*HashToNode `json:"nodes"` + LeavesData []*LeafData `json:"contract_leaves_data"` +} + +// https://github.com/starkware-libs/starknet-specs/blob/8cf463b79ba1dd876f67c7f637e5ea48beb07b5b/api/starknet_api_openrpc.json#L1011 +type GlobalRoots struct { + ContractsTreeRoot *felt.Felt `json:"contracts_tree_root"` + ClassesTreeRoot *felt.Felt `json:"classes_tree_root"` + BlockHash *felt.Felt `json:"block_hash"` +} + +// https://github.com/starkware-libs/starknet-specs/blob/8cf463b79ba1dd876f67c7f637e5ea48beb07b5b/api/starknet_api_openrpc.json#L970 +type StorageProofResult struct { + ClassesProof []*HashToNode `json:"classes_proof"` + ContractsProof *ContractProof `json:"contracts_proof"` + ContractsStorageProofs [][]*HashToNode `json:"contracts_storage_proofs"` + GlobalRoots *GlobalRoots `json:"global_roots"` +} + +func getClassesProof(reader core.TrieReader, classes []felt.Felt) ([]*HashToNode, error) { + cTrie, _, err := reader.ClassTrie() + if err != nil { + return nil, err + } + result := []*HashToNode{} + for _, class := range utils.Unique(classes) { + nodes, err := getProof(cTrie, &class) + if err != nil { + return nil, err + } + result = append(result, nodes...) + } + + return deduplicate(result), nil +} + +func getContractsProof(stReader core.StateReader, trReader core.TrieReader, contracts []felt.Felt) (*ContractProof, error) { + sTrie, _, err := trReader.StorageTrie() + if err != nil { + return nil, err + } + + result := &ContractProof{ + Nodes: []*HashToNode{}, + LeavesData: make([]*LeafData, 0, len(contracts)), + } + + for _, contract := range contracts { + leafData, err := getLeafData(stReader, &contract) + if err != nil { + return nil, err + } + result.LeavesData = append(result.LeavesData, leafData) + + nodes, err := getProof(sTrie, &contract) + if err != nil { + return nil, err + } + result.Nodes = append(result.Nodes, nodes...) + } + + result.Nodes = deduplicate(result.Nodes) + return result, nil +} + +func getLeafData(reader core.StateReader, contract *felt.Felt) (*LeafData, error) { + nonce, err := reader.ContractNonce(contract) + if errors.Is(err, db.ErrKeyNotFound) { + return nil, nil + } + if err != nil { + return nil, err + } + classHash, err := reader.ContractClassHash(contract) + if err != nil { + return nil, err + } + + return &LeafData{ + Nonce: nonce, + ClassHash: classHash, + }, nil +} + +func getContractsStorageProofs(reader core.TrieReader, keys []StorageKeys) ([][]*HashToNode, error) { + result := make([][]*HashToNode, 0, len(keys)) + + for _, key := range keys { + csTrie, err := reader.StorageTrieForAddr(&key.Contract) + if err != nil { + // Note: if contract does not exist, `StorageTrieForAddr()` returns an empty trie, not an error + return nil, err + } + + nodes := []*HashToNode{} + for _, slot := range utils.Unique(key.Keys) { + proof, err := getProof(csTrie, &slot) + if err != nil { + return nil, err + } + nodes = append(nodes, proof...) + } + result = append(result, deduplicate(nodes)) + } + + return result, nil +} + +func getProof(t *trie.Trie, elt *felt.Felt) ([]*HashToNode, error) { + proofSet := trie.NewProofNodeSet() + err := t.Prove(elt, proofSet) + if err != nil { + return nil, err + } + + // adapt proofs to the expected format + nodes := proofSet.List() + hashNodes := make([]*HashToNode, len(nodes)) + for i, node := range nodes { + var merkle MerkleNode + + if binary, ok := node.(*trie.Binary); ok { + merkle = &MerkleBinaryNode{ + Left: binary.LeftHash, + Right: binary.RightHash, + } + } + if edge, ok := node.(*trie.Edge); ok { + path := edge.Path + f := path.Felt() + merkle = &MerkleEdgeNode{ + Path: f.String(), + Length: int(edge.Len()), + Child: edge.Child, + } + } + + hashNodes[i] = &HashToNode{ + Hash: node.Hash(t.HashFunc()), + Node: merkle, + } + } + + return hashNodes, nil +} + +func deduplicate(proof []*HashToNode) []*HashToNode { + if len(proof) == 0 { + return proof + } + + keyF := func(node *HashToNode) felt.Felt { return *node.Hash } + return utils.UniqueFunc(proof, keyF) +} diff --git a/rpc/storage_test.go b/rpc/storage_test.go new file mode 100644 index 0000000000..5defb5bad2 --- /dev/null +++ b/rpc/storage_test.go @@ -0,0 +1,860 @@ +package rpc_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/NethermindEth/juno/blockchain" + "github.com/NethermindEth/juno/clients/feeder" + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/pebble" + "github.com/NethermindEth/juno/mocks" + "github.com/NethermindEth/juno/rpc" + adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" + "github.com/NethermindEth/juno/sync" + "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestStorageAt(t *testing.T) { + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + mockReader := mocks.NewMockReader(mockCtrl) + log := utils.NewNopZapLogger() + handler := rpc.New(mockReader, nil, nil, "", log) + + t.Run("empty blockchain", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) + }) + + t.Run("non-existent block hash", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(nil, nil, db.ErrKeyNotFound) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Hash: &felt.Zero}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) + }) + + t.Run("non-existent block number", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(nil, nil, db.ErrKeyNotFound) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Number: 0}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) + }) + + mockState := mocks.NewMockStateHistoryReader(mockCtrl) + + t.Run("non-existent contract", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, db.ErrKeyNotFound) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrContractNotFound, rpcErr) + }) + + t.Run("non-existent key", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(nil, db.ErrKeyNotFound) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrContractNotFound, rpcErr) + }) + + expectedStorage := new(felt.Felt).SetUint64(1) + + t.Run("blockID - latest", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) + require.Nil(t, rpcErr) + assert.Equal(t, expectedStorage, storage) + }) + + t.Run("blockID - hash", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Hash: &felt.Zero}) + require.Nil(t, rpcErr) + assert.Equal(t, expectedStorage, storage) + }) + + t.Run("blockID - number", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Number: 0}) + require.Nil(t, rpcErr) + assert.Equal(t, expectedStorage, storage) + }) +} + +func TestStorageProof(t *testing.T) { + t.Parallel() + + // dummy values + var ( + blkHash = utils.HexToFelt(t, "0x11ead") + clsRoot = utils.HexToFelt(t, "0xc1a55") + stgRoot = utils.HexToFelt(t, "0xc0ffee") + key = new(felt.Felt).SetUint64(1) + key2 = new(felt.Felt).SetUint64(8) + noSuchKey = new(felt.Felt).SetUint64(0) + value = new(felt.Felt).SetUint64(51) + value2 = new(felt.Felt).SetUint64(58) + blockLatest = rpc.BlockID{Latest: true} + blockNumber = uint64(1313) + nopCloser = func() error { + return nil + } + ) + + tempTrie := emptyTrie(t) + + _, err := tempTrie.Put(key, value) + require.NoError(t, err) + _, err = tempTrie.Put(key2, value2) + require.NoError(t, err) + require.NoError(t, tempTrie.Commit()) + + trieRoot, err := tempTrie.Root() + require.NoError(t, err) + + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + mockReader := mocks.NewMockReader(mockCtrl) + mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockTrie := mocks.NewMockTrieReader(mockCtrl) + + mockReader.EXPECT().HeadState().Return(mockState, func() error { + return nil + }, nil).AnyTimes() + mockReader.EXPECT().HeadTrie().Return(mockTrie, func() error { return nil }, nil).AnyTimes() + mockReader.EXPECT().Head().Return(&core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}}, nil).AnyTimes() + mockTrie.EXPECT().StateAndClassRoot().Return(stgRoot, clsRoot, nil).AnyTimes() + mockTrie.EXPECT().ClassTrie().Return(tempTrie, nopCloser, nil).AnyTimes() + mockTrie.EXPECT().StorageTrie().Return(tempTrie, nopCloser, nil).AnyTimes() + + log := utils.NewNopZapLogger() + handler := rpc.New(mockReader, nil, nil, "", log) + + t.Run("Trie proofs sanity check", func(t *testing.T) { + t.Parallel() + + proof := trie.NewProofNodeSet() + err := tempTrie.Prove(key, proof) + require.NoError(t, err) + root, err := tempTrie.Root() + require.NoError(t, err) + leaf, err := trie.VerifyProof(root, key, proof, tempTrie.HashFunc()) + require.NoError(t, err) + require.Equal(t, leaf, value) + + // non-membership test + proof = trie.NewProofNodeSet() + err = tempTrie.Prove(key, proof) + require.NoError(t, err) + leaf, err = trie.VerifyProof(root, noSuchKey, proof, tempTrie.HashFunc()) + require.NoError(t, err) + require.Equal(t, felt.Zero, *leaf) + }) + t.Run("global roots are filled", func(t *testing.T) { + t.Parallel() + + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) + require.Nil(t, rpcErr) + + require.NotNil(t, proof) + require.NotNil(t, proof.GlobalRoots) + require.Equal(t, blkHash, proof.GlobalRoots.BlockHash) + require.Equal(t, clsRoot, proof.GlobalRoots.ClassesTreeRoot) + require.Equal(t, stgRoot, proof.GlobalRoots.ContractsTreeRoot) + }) + t.Run("error is returned whenever not latest block is requested", func(t *testing.T) { + t.Parallel() + + proof, rpcErr := handler.StorageProof(rpc.BlockID{Number: 1}, nil, nil, nil) + assert.Equal(t, rpc.ErrStorageProofNotSupported, rpcErr) + require.Nil(t, proof) + }) + t.Run("error is returned even when blknum matches head", func(t *testing.T) { + t.Parallel() + + proof, rpcErr := handler.StorageProof(rpc.BlockID{Number: blockNumber}, nil, nil, nil) + assert.Equal(t, rpc.ErrStorageProofNotSupported, rpcErr) + require.Nil(t, proof) + }) + t.Run("empty request", func(t *testing.T) { + t.Parallel() + + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 0, 0, 0, 0) + }) + t.Run("class trie hash does not exist in a trie", func(t *testing.T) { + t.Parallel() + + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*noSuchKey}, nil, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 3, 0, 0, 0) + verifyIf(t, trieRoot, noSuchKey, nil, proof.ClassesProof, tempTrie.HashFunc()) + }) + t.Run("class trie hash exists in a trie", func(t *testing.T) { + t.Parallel() + + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key}, nil, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 3, 0, 0, 0) + verifyIf(t, trieRoot, key, value, proof.ClassesProof, tempTrie.HashFunc()) + }) + t.Run("only unique proof nodes are returned", func(t *testing.T) { + t.Parallel() + + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key, *key2}, nil, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + + rootNodes := utils.Filter(proof.ClassesProof, func(h *rpc.HashToNode) bool { + return h.Hash.Equal(trieRoot) + }) + require.Len(t, rootNodes, 1) + + // verify we can still prove any of the keys in query + verifyIf(t, trieRoot, key, value, proof.ClassesProof, tempTrie.HashFunc()) + verifyIf(t, trieRoot, key2, value2, proof.ClassesProof, tempTrie.HashFunc()) + }) + t.Run("storage trie address does not exist in a trie", func(t *testing.T) { + t.Parallel() + + mockState.EXPECT().ContractNonce(noSuchKey).Return(nil, db.ErrKeyNotFound).Times(1) + mockState.EXPECT().ContractClassHash(noSuchKey).Return(nil, db.ErrKeyNotFound).Times(0) + + proof, rpcErr := handler.StorageProof(blockLatest, nil, []felt.Felt{*noSuchKey}, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 0, 3, 1, 0) + require.Nil(t, proof.ContractsProof.LeavesData[0]) + + verifyIf(t, trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, tempTrie.HashFunc()) + }) + t.Run("storage trie address exists in a trie", func(t *testing.T) { + t.Parallel() + + nonce := new(felt.Felt).SetUint64(121) + mockState.EXPECT().ContractNonce(key).Return(nonce, nil).Times(1) + classHasah := new(felt.Felt).SetUint64(1234) + mockState.EXPECT().ContractClassHash(key).Return(classHasah, nil).Times(1) + + proof, rpcErr := handler.StorageProof(blockLatest, nil, []felt.Felt{*key}, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 0, 3, 1, 0) + + require.NotNil(t, proof.ContractsProof.LeavesData[0]) + ld := proof.ContractsProof.LeavesData[0] + require.Equal(t, nonce, ld.Nonce) + require.Equal(t, classHasah, ld.ClassHash) + + verifyIf(t, trieRoot, key, value, proof.ContractsProof.Nodes, tempTrie.HashFunc()) + }) + t.Run("contract storage trie address does not exist in a trie", func(t *testing.T) { + t.Parallel() + + contract := utils.HexToFelt(t, "0xdead") + mockTrie.EXPECT().StorageTrieForAddr(contract).Return(emptyTrie(t), nil).Times(1) + + storageKeys := []rpc.StorageKeys{{Contract: *contract, Keys: []felt.Felt{*key}}} + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, storageKeys) + require.NotNil(t, proof) + require.Nil(t, rpcErr) + arityTest(t, proof, 0, 0, 0, 1) + require.Len(t, proof.ContractsStorageProofs[0], 0) + }) + //nolint:dupl + t.Run("contract storage trie key slot does not exist in a trie", func(t *testing.T) { + t.Parallel() + + contract := utils.HexToFelt(t, "0xabcd") + mockTrie.EXPECT().StorageTrieForAddr(contract).Return(tempTrie, nil).Times(1) + + storageKeys := []rpc.StorageKeys{{Contract: *contract, Keys: []felt.Felt{*noSuchKey}}} + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, storageKeys) + require.NotNil(t, proof) + require.Nil(t, rpcErr) + arityTest(t, proof, 0, 0, 0, 1) + require.Len(t, proof.ContractsStorageProofs[0], 3) + + verifyIf(t, trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], tempTrie.HashFunc()) + }) + //nolint:dupl + t.Run("contract storage trie address/key exists in a trie", func(t *testing.T) { + t.Parallel() + + contract := utils.HexToFelt(t, "0xadd0") + mockTrie.EXPECT().StorageTrieForAddr(contract).Return(tempTrie, nil).Times(1) + + storageKeys := []rpc.StorageKeys{{Contract: *contract, Keys: []felt.Felt{*key}}} + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, storageKeys) + require.NotNil(t, proof) + require.Nil(t, rpcErr) + arityTest(t, proof, 0, 0, 0, 1) + require.Len(t, proof.ContractsStorageProofs[0], 3) + + verifyIf(t, trieRoot, key, value, proof.ContractsStorageProofs[0], tempTrie.HashFunc()) + }) + t.Run("class & storage tries proofs requested", func(t *testing.T) { + t.Parallel() + + nonce := new(felt.Felt).SetUint64(121) + mockState.EXPECT().ContractNonce(key).Return(nonce, nil) + classHasah := new(felt.Felt).SetUint64(1234) + mockState.EXPECT().ContractClassHash(key).Return(classHasah, nil) + + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key}, []felt.Felt{*key}, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 3, 3, 1, 0) + }) +} + +func TestStorageProofErrorHandling(t *testing.T) { + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + mockReader := mocks.NewMockReader(mockCtrl) + mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockTrie := mocks.NewMockTrieReader(mockCtrl) + handler := rpc.New(mockReader, nil, nil, "", utils.NewNopZapLogger()) + nopCloser := func() error { return nil } + + key := new(felt.Felt).SetUint64(1) + blockLatest := rpc.BlockID{Latest: true} + expectedErr := errors.New("expected error") + + t.Run("error handling HeadState", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(nil, nil, expectedErr).Times(1) + + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key}, nil, nil) + require.Nil(t, proof) + require.NotNil(t, rpcErr) + require.Equal(t, "Internal error", rpcErr.Message) + require.Equal(t, expectedErr, rpcErr.Data.(error)) + }) + t.Run("error handling Head()", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil).Times(1) + mockReader.EXPECT().Head().Return(nil, expectedErr).Times(1) + + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key}, nil, nil) + require.Nil(t, proof) + require.NotNil(t, rpcErr) + require.Equal(t, "Internal error", rpcErr.Message) + require.Equal(t, expectedErr, rpcErr.Data.(error)) + }) + t.Run("error handling HeadTrie", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil).Times(1) + mockReader.EXPECT().Head().Return(&core.Block{Header: &core.Header{Hash: new(felt.Felt), Number: 0}}, nil).Times(1) + mockReader.EXPECT().HeadTrie().Return(nil, nil, expectedErr).Times(1) + + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key}, nil, nil) + require.Nil(t, proof) + require.NotNil(t, rpcErr) + require.Equal(t, "Internal error", rpcErr.Message) + require.Equal(t, expectedErr, rpcErr.Data.(error)) + }) + t.Run("error handling StateAndClassRoot", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil).Times(1) + mockReader.EXPECT().Head().Return(&core.Block{Header: &core.Header{Hash: new(felt.Felt), Number: 0}}, nil).Times(1) + mockReader.EXPECT().HeadTrie().Return(mockTrie, nopCloser, nil).Times(1) + mockTrie.EXPECT().StateAndClassRoot().Return(nil, nil, expectedErr).Times(1) + + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key}, nil, nil) + require.Nil(t, proof) + require.NotNil(t, rpcErr) + require.Equal(t, "Internal error", rpcErr.Message) + require.Equal(t, expectedErr, rpcErr.Data.(error)) + }) + t.Run("error handling getClassesProof", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil).Times(1) + mockReader.EXPECT().Head().Return(&core.Block{Header: &core.Header{Hash: new(felt.Felt), Number: 0}}, nil).Times(1) + mockReader.EXPECT().HeadTrie().Return(mockTrie, nopCloser, nil).Times(1) + mockTrie.EXPECT().StateAndClassRoot().Return(new(felt.Felt), new(felt.Felt), nil).Times(1) + mockTrie.EXPECT().ClassTrie().Return(nil, nil, expectedErr).Times(1) + + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key}, nil, nil) + require.Nil(t, proof) + require.NotNil(t, rpcErr) + require.Equal(t, "Internal error", rpcErr.Message) + require.Equal(t, expectedErr, rpcErr.Data.(error)) + }) + t.Run("error handling getContractsProof", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil).Times(1) + mockReader.EXPECT().Head().Return(&core.Block{Header: &core.Header{Hash: new(felt.Felt), Number: 0}}, nil).Times(1) + mockReader.EXPECT().HeadTrie().Return(mockTrie, nopCloser, nil).Times(1) + mockTrie.EXPECT().StateAndClassRoot().Return(new(felt.Felt), new(felt.Felt), nil).Times(1) + mockTrie.EXPECT().ClassTrie().Return(new(trie.Trie), nopCloser, nil).Times(1) + mockTrie.EXPECT().StorageTrie().Return(nil, nil, expectedErr).Times(1) + + proof, rpcErr := handler.StorageProof(blockLatest, nil, []felt.Felt{*key}, nil) + require.Nil(t, proof) + require.NotNil(t, rpcErr) + require.Equal(t, "Internal error", rpcErr.Message) + require.Equal(t, expectedErr, rpcErr.Data.(error)) + }) + t.Run("error handling getContractsStorageProofs", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil).Times(1) + mockReader.EXPECT().Head().Return(&core.Block{Header: &core.Header{Hash: new(felt.Felt), Number: 0}}, nil).Times(1) + mockReader.EXPECT().HeadTrie().Return(mockTrie, nopCloser, nil).Times(1) + mockTrie.EXPECT().StateAndClassRoot().Return(new(felt.Felt), new(felt.Felt), nil).Times(1) + mockTrie.EXPECT().ClassTrie().Return(new(trie.Trie), nopCloser, nil).Times(1) + mockTrie.EXPECT().StorageTrie().Return(new(trie.Trie), nopCloser, nil).Times(1) + mockTrie.EXPECT().StorageTrieForAddr(key).Return(nil, expectedErr).Times(1) + + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, []rpc.StorageKeys{{Contract: *key, Keys: []felt.Felt{*key}}}) + require.Nil(t, proof) + require.NotNil(t, rpcErr) + require.Equal(t, "Internal error", rpcErr.Message) + require.Equal(t, expectedErr, rpcErr.Data.(error)) + }) +} + +func TestStorageRoots(t *testing.T) { + t.Parallel() + + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + client := feeder.NewTestClient(t, &utils.Mainnet) + gw := adaptfeeder.New(client) + + log := utils.NewNopZapLogger() + testDB := pebble.NewMemTest(t) + bc := blockchain.New(testDB, &utils.Mainnet, nil) + synchronizer := sync.New(bc, gw, log, time.Duration(0), false, testDB) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + + require.NoError(t, synchronizer.Run(ctx)) + cancel() + + var ( + expectedBlockHash = utils.HexToFelt(t, "0x4e1f77f39545afe866ac151ac908bd1a347a2a8a7d58bef1276db4f06fdf2f6") + expectedGlobalRoot = utils.HexToFelt(t, "0x3ceee867d50b5926bb88c0ec7e0b9c20ae6b537e74aac44b8fcf6bb6da138d9") + expectedClsRoot = utils.HexToFelt(t, "0x0") + expectedStgRoot = utils.HexToFelt(t, "0x3ceee867d50b5926bb88c0ec7e0b9c20ae6b537e74aac44b8fcf6bb6da138d9") + expectedContractAddress = utils.HexToFelt(t, "0x2d6c9569dea5f18628f1ef7c15978ee3093d2d3eec3b893aac08004e678ead3") + expectedContractLeaf = utils.HexToFelt(t, "0x7036d8dd68dc9539c6db8c88f72b1ab16e76d62b5f09118eca5ae78276b0ee4") + ) + + t.Run("sanity check - mainnet block 2", func(t *testing.T) { + t.Parallel() + + expectedBlockNumber := uint64(2) + + blk, err := bc.Head() + assert.NoError(t, err) + assert.Equal(t, expectedBlockNumber, blk.Number) + assert.Equal(t, expectedBlockHash, blk.Hash, blk.Hash.String()) + assert.Equal(t, expectedGlobalRoot, blk.GlobalStateRoot, blk.GlobalStateRoot.String()) + }) + + t.Run("check class and storage roots matches the global", func(t *testing.T) { + t.Parallel() + + reader, closer, err := bc.HeadTrie() + assert.NoError(t, err) + defer func() { _ = closer() }() + + stgRoot, clsRoot, err := reader.StateAndClassRoot() + assert.NoError(t, err) + + assert.Equal(t, expectedClsRoot, clsRoot, clsRoot.String()) + assert.Equal(t, expectedStgRoot, stgRoot, stgRoot.String()) + + verifyGlobalStateRoot(t, expectedGlobalRoot, clsRoot, stgRoot) + }) + + t.Run("check requested contract and storage slot exists", func(t *testing.T) { + t.Parallel() + + trieReader, closer, err := bc.HeadTrie() + assert.NoError(t, err) + defer func() { _ = closer() }() + + sTrie, sCloser, err := trieReader.StorageTrie() + assert.NoError(t, err) + defer func() { _ = sCloser() }() + + leaf, err := sTrie.Get(expectedContractAddress) + assert.NoError(t, err) + assert.Equal(t, leaf, expectedContractLeaf, leaf.String()) + + stateReader, stCloser, err := bc.HeadState() + assert.NoError(t, err) + defer func() { _ = stCloser() }() + + clsHash, err := stateReader.ContractClassHash(expectedContractAddress) + assert.NoError(t, err) + assert.Equal(t, clsHash, utils.HexToFelt(t, "0x10455c752b86932ce552f2b0fe81a880746649b9aee7e0d842bf3f52378f9f8"), clsHash.String()) + }) + + t.Run("get contract proof", func(t *testing.T) { + t.Parallel() + + handler := rpc.New(bc, nil, nil, "", log) + result, rpcErr := handler.StorageProof( + rpc.BlockID{Latest: true}, nil, []felt.Felt{*expectedContractAddress}, nil) + require.Nil(t, rpcErr) + + expectedResult := rpc.StorageProofResult{ + ClassesProof: []*rpc.HashToNode{}, + ContractsStorageProofs: [][]*rpc.HashToNode{}, + ContractsProof: &rpc.ContractProof{ + LeavesData: []*rpc.LeafData{ + { + Nonce: utils.HexToFelt(t, "0x0"), + ClassHash: utils.HexToFelt(t, "0x10455c752b86932ce552f2b0fe81a880746649b9aee7e0d842bf3f52378f9f8"), + }, + }, + Nodes: []*rpc.HashToNode{ + { + Hash: utils.HexToFelt(t, "0x3ceee867d50b5926bb88c0ec7e0b9c20ae6b537e74aac44b8fcf6bb6da138d9"), + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x4e1f289e55ac8a821fd463478e6f5543256beb934a871be91d00a0d3f2e7964"), + Right: utils.HexToFelt(t, "0x67d9833b51e7bf1cab0e71e68477bf7f0b704391d753f9d793008e4f6587c53"), + }, + }, + { + Hash: utils.HexToFelt(t, "0x4e1f289e55ac8a821fd463478e6f5543256beb934a871be91d00a0d3f2e7964"), + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x1ef87d62309ff1cad58d39e8f5480f9caa9acd78a43f139d87220a1babe38a4"), + Right: utils.HexToFelt(t, "0x9a258d24b3aeb7e263e910d68a18d85305703a2f20df2e806ecbb1fb28760f"), + }, + }, + { + Hash: utils.HexToFelt(t, "0x9a258d24b3aeb7e263e910d68a18d85305703a2f20df2e806ecbb1fb28760f"), + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x53f61d0cb8099e2e7ffc214c4ef7ac8520abb5327510f84affe90b1890d314c"), + Right: utils.HexToFelt(t, "0x45ca67f381dcd01fec774743a4aaed6b36e1bda979185cf5dce538ad0007914"), + }, + }, + { + Hash: utils.HexToFelt(t, "0x53f61d0cb8099e2e7ffc214c4ef7ac8520abb5327510f84affe90b1890d314c"), + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x17d6fc8431c48e41222a3ede441d1e2d91c31eb67a8aa9c030c99c510e9f34c"), + Right: utils.HexToFelt(t, "0x1cf95259ae39c038e87224fa5fdb7c7eeba6dd4263e05e80c9a8e27c3240f2c"), + }, + }, + { + Hash: utils.HexToFelt(t, "0x1cf95259ae39c038e87224fa5fdb7c7eeba6dd4263e05e80c9a8e27c3240f2c"), + Node: &rpc.MerkleEdgeNode{ + Path: "0x56c9569dea5f18628f1ef7c15978ee3093d2d3eec3b893aac08004e678ead3", + Length: 247, + Child: expectedContractLeaf, + }, + }, + }, + }, + GlobalRoots: &rpc.GlobalRoots{ + BlockHash: expectedBlockHash, + ClassesTreeRoot: expectedClsRoot, + ContractsTreeRoot: expectedStgRoot, + }, + } + + assert.Equal(t, expectedResult, *result) + }) +} + +func TestVerifyPathfinderResponse(t *testing.T) { + t.Parallel() + + // Pathfinder response for query: + // "method": "starknet_getStorageProof", + // "params": [ + // "latest", + // [], + // [ + // "0x5a03b82d726f9bb31ba41ea3a0c1143f90241e37c9a4a92174d168cda9c716d", + // "0x5fbaa249500be29fee38fdd90a7a2651a8d3935c14167570f6863f563d838f0" + // ] + // ], + // Sepolia, at block 10434 + result := rpc.StorageProofResult{ + ClassesProof: []*rpc.HashToNode{}, + ContractsProof: &rpc.ContractProof{ + LeavesData: []*rpc.LeafData{ + { + Nonce: utils.HexToFelt(t, "0x0"), + ClassHash: utils.HexToFelt(t, "0x772164c9d6179a89e7f1167f099219f47d752304b16ed01f081b6e0b45c93c3"), + }, + { + Nonce: utils.HexToFelt(t, "0x0"), + ClassHash: utils.HexToFelt(t, "0x78401746828463e2c3f92ebb261fc82f7d4d4c8d9a80a356c44580dab124cb0"), + }, + }, + Nodes: []*rpc.HashToNode{ + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x5c6be09d8faaa42a8525898b1047cebdd3526349b48decc2b767a4fa612263d"), + Right: utils.HexToFelt(t, "0xcd11aa7699c4157a287e5fe574df37e40c8b6a5ed5e1aee658fc2d634398ef"), + }, + Hash: utils.HexToFelt(t, "0x7884784e689e733c1ea2c4ee3b1f790c4ca4992b26d8aee31abb5d9270d4947"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x1cdf395ebbba2f3a6234ad9827b08453a4a0b7745e2d919fe7b07749efa5325"), + Right: utils.HexToFelt(t, "0xcdd37cf6cce8bc373e2c9d8d6754b057275ddd910a9d133b4d31086632d0f4"), + }, + Hash: utils.HexToFelt(t, "0x44fcfce222b7e5a098346615dc838d8ae90ff55da82db7cdce4303f34042ff6"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x2c55bc287a1b31a405c681c2bb720811dd9f33523241561ea4b356f717ff9f6"), + Right: utils.HexToFelt(t, "0x2012025c00174e3eb72baba21e58a56e5114e571f64cb1040f7de0c8daef618"), + }, + Hash: utils.HexToFelt(t, "0x7f2b62cf9713a0b635b967c2e2891282631519eebca6ea0bddaa1a1a804919f"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x211a80e63ac0b12b29279c3d57ea5771b5003ea464b055aeb8ad8618ff3cd69"), + Right: utils.HexToFelt(t, "0x44f55356be17913dcd79e0bb4dbc986d0642bb3f000e540bb54bfa2d4189a74"), + }, + Hash: utils.HexToFelt(t, "0x69e208899d9deeae0732e95ce9d68d123abd9b59f157435fc3554e1fa3a92a8"), + }, + { + Node: &rpc.MerkleEdgeNode{ + Child: utils.HexToFelt(t, "0x6b45780618ce075fb4543396b3a6949915c04962b2e411c4f1b2a6813d540da"), + Length: 239, + Path: "0x3b82d726f9bb31ba41ea3a0c1143f90241e37c9a4a92174d168cda9c716d", + }, + Hash: utils.HexToFelt(t, "0x2c55bc287a1b31a405c681c2bb720811dd9f33523241561ea4b356f717ff9f6"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x7be97a0f8a99126208712673c69c292a26273707c884e96e17c761ee7097ae5"), + Right: utils.HexToFelt(t, "0x3ae1731f598d03a9033c6f5d29871cd5a80c4eba36a7a0a73775ea9d8d522f3"), + }, + Hash: utils.HexToFelt(t, "0xcd11aa7699c4157a287e5fe574df37e40c8b6a5ed5e1aee658fc2d634398ef"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x7f2b62cf9713a0b635b967c2e2891282631519eebca6ea0bddaa1a1a804919f"), + Right: utils.HexToFelt(t, "0x77f807a73f0e7ccad122cd946d79d8f4ce9e02f01017467e7cf4ad993cfa482"), + }, + Hash: utils.HexToFelt(t, "0x326e52c7cba85fedb456bb1c25dda2075ebe3367a329eb297144cb7f8d1f7d9"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x35d32a880d122ffc43a46e280c0ff34a9de286c2cb2e3933229f419a6ceed8e"), + Right: utils.HexToFelt(t, "0x14c9f5368ebbe1cc8d1db2dde1f97d18cabf450bbc23f154985c7e15e15bdcf"), + }, + Hash: utils.HexToFelt(t, "0x1159575d44f9b716f2cfbb13da873f8e7d9824e6b7b615dac5ce9c7b0e2bffd"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x1e5dfbcf23a5e942208f5ccfa25db1147dbfb2984df32a692102851757998cd"), + Right: utils.HexToFelt(t, "0x69e208899d9deeae0732e95ce9d68d123abd9b59f157435fc3554e1fa3a92a8"), + }, + Hash: utils.HexToFelt(t, "0x2722e2a47b3f10db016928bcc7451cd2088a1caea2fbb5f08e1b71dfe1db1c2"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x2634833b52e930231b53d58286647d9818a276dd12ace8286dae63b896c3ba1"), + Right: utils.HexToFelt(t, "0x1f248a8796f18bc9d116e5f3c3956c47e091c05f1c9596453b2fefa2b725507"), + }, + Hash: utils.HexToFelt(t, "0x109e30040b25357cc51726d6041ba1f09ec02dd8b3ca2ffa686a858c9293796"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x7884784e689e733c1ea2c4ee3b1f790c4ca4992b26d8aee31abb5d9270d4947"), + Right: utils.HexToFelt(t, "0x4e354efe4fcc718d3454d532b50cd3c73ac84f05df918981433162c84650f6c"), + }, + Hash: utils.HexToFelt(t, "0x88648f7a7b355914ed41bb28101110cff8fb68f1a9b39958823c72992d8675"), + }, + { + Node: &rpc.MerkleEdgeNode{ + Child: utils.HexToFelt(t, "0x4169679eea4895011fb8e9029b4591a210b3b9e9aa23f12f25cf45cbcaadfe8"), + Length: 1, + Path: "0x1", + }, + Hash: utils.HexToFelt(t, "0x44f55356be17913dcd79e0bb4dbc986d0642bb3f000e540bb54bfa2d4189a74"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x192804e98b1f3fdad2d8fab79bfb922611edc5fb48dcd1e9db02cd46cfa9763"), + Right: utils.HexToFelt(t, "0x4717a5dd5048d62401bc7db57594d3bdbfd3c7b99788a83c5e77b6db9822149"), + }, + Hash: utils.HexToFelt(t, "0x14c9f5368ebbe1cc8d1db2dde1f97d18cabf450bbc23f154985c7e15e15bdcf"), + }, + { + Node: &rpc.MerkleEdgeNode{ + Child: utils.HexToFelt(t, "0x25790175fe1fbeed47cbf510a41fba8676bea20a0c8888d4b9090b8f5cf19b8"), + Length: 238, + Path: "0x2a249500be29fee38fdd90a7a2651a8d3935c14167570f6863f563d838f0", + }, + Hash: utils.HexToFelt(t, "0x331128166378265a07c0be65b242d47d1965e785b6f4f6e1bca3731de5d2d1d"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x331128166378265a07c0be65b242d47d1965e785b6f4f6e1bca3731de5d2d1d"), + Right: utils.HexToFelt(t, "0x12af5e7e95772777d98792be8ade3b18c06ab21aa492a1821d5be3ac291374a"), + }, + Hash: utils.HexToFelt(t, "0x4169679eea4895011fb8e9029b4591a210b3b9e9aa23f12f25cf45cbcaadfe8"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x485b298f33aa076113362f82f4bf64f23e2eb5b84209353a630a46cd20fdde5"), + Right: utils.HexToFelt(t, "0x1159575d44f9b716f2cfbb13da873f8e7d9824e6b7b615dac5ce9c7b0e2bffd"), + }, + Hash: utils.HexToFelt(t, "0x3ae1731f598d03a9033c6f5d29871cd5a80c4eba36a7a0a73775ea9d8d522f3"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x2358473807e0a43a66b918247c0fb0d0649c72a32f19eee8bcc76c090b37951"), + Right: utils.HexToFelt(t, "0x109e30040b25357cc51726d6041ba1f09ec02dd8b3ca2ffa686a858c9293796"), + }, + Hash: utils.HexToFelt(t, "0x485b298f33aa076113362f82f4bf64f23e2eb5b84209353a630a46cd20fdde5"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x326e52c7cba85fedb456bb1c25dda2075ebe3367a329eb297144cb7f8d1f7d9"), + Right: utils.HexToFelt(t, "0x41149879a9d24ba0a2ccfb56415c04bdabb1c51eb0900a17dee2c715d6b1c70"), + }, + Hash: utils.HexToFelt(t, "0x1cdf395ebbba2f3a6234ad9827b08453a4a0b7745e2d919fe7b07749efa5325"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x454a8b3fc492869e79b16e87461d0b5101eb5d25389f492039ef6a380878b39"), + Right: utils.HexToFelt(t, "0x5a99604af4e482d046afe656b6ebe7805c72a1b7979d00608f27b276eb33442"), + }, + Hash: utils.HexToFelt(t, "0x4717a5dd5048d62401bc7db57594d3bdbfd3c7b99788a83c5e77b6db9822149"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x2f6c0e4b8022b48461e54e4f9358c51d5444ae2e2253a31baa68d4cb0c938de"), + Right: utils.HexToFelt(t, "0x88648f7a7b355914ed41bb28101110cff8fb68f1a9b39958823c72992d8675"), + }, + Hash: utils.HexToFelt(t, "0x47182b7d8158a8f80ed15822719aa306af37383a0cf91518d21ba63e73fea13"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x44fcfce222b7e5a098346615dc838d8ae90ff55da82db7cdce4303f34042ff6"), + Right: utils.HexToFelt(t, "0xc3da9c726d244197963a8a7beb4a3aee353b3b663daf2aa1bcf1c087b5e20d"), + }, + Hash: utils.HexToFelt(t, "0x2634833b52e930231b53d58286647d9818a276dd12ace8286dae63b896c3ba1"), + }, + { + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x2722e2a47b3f10db016928bcc7451cd2088a1caea2fbb5f08e1b71dfe1db1c2"), + Right: utils.HexToFelt(t, "0x79c09acd32044c7d455299ca67e2a8fafce25afaf6d5e89ff4632b251dddc8d"), + }, + Hash: utils.HexToFelt(t, "0x5a99604af4e482d046afe656b6ebe7805c72a1b7979d00608f27b276eb33442"), + }, + }, + }, + ContractsStorageProofs: [][]*rpc.HashToNode{}, + GlobalRoots: &rpc.GlobalRoots{ + BlockHash: utils.HexToFelt(t, "0xae4cc763c8b350913e00e12cffd51fb7e3b730e29036864a8afd8ec323ecd6"), + ClassesTreeRoot: utils.HexToFelt(t, "0xea1568e1ca4e5b8c19cdf130dc3194f9cb8e5eee2fa5ec54a338a4dccfd6e3"), + ContractsTreeRoot: utils.HexToFelt(t, "0x47182b7d8158a8f80ed15822719aa306af37383a0cf91518d21ba63e73fea13"), + }, + } + + root := result.GlobalRoots.ContractsTreeRoot + + t.Run("first contract proof verification", func(t *testing.T) { + t.Parallel() + + firstContractAddr := utils.HexToFelt(t, "0x5a03b82d726f9bb31ba41ea3a0c1143f90241e37c9a4a92174d168cda9c716d") + firstContractLeaf := utils.HexToFelt(t, "0x6b45780618ce075fb4543396b3a6949915c04962b2e411c4f1b2a6813d540da") + verifyIf(t, root, firstContractAddr, firstContractLeaf, result.ContractsProof.Nodes, crypto.Pedersen) + }) + + t.Run("second contract proof verification", func(t *testing.T) { + t.Parallel() + + secondContractAddr := utils.HexToFelt(t, "0x5fbaa249500be29fee38fdd90a7a2651a8d3935c14167570f6863f563d838f0") + secondContractLeaf := utils.HexToFelt(t, "0x25790175fe1fbeed47cbf510a41fba8676bea20a0c8888d4b9090b8f5cf19b8") + verifyIf(t, root, secondContractAddr, secondContractLeaf, result.ContractsProof.Nodes, crypto.Pedersen) + }) +} + +func arityTest(t *testing.T, + proof *rpc.StorageProofResult, + classesProofArity int, + contractsProofNodesArity int, + contractsProofLeavesArity int, + contractStorageArity int, +) { + require.Len(t, proof.ClassesProof, classesProofArity) + require.Len(t, proof.ContractsStorageProofs, contractStorageArity) + require.NotNil(t, proof.ContractsProof) + require.Len(t, proof.ContractsProof.Nodes, contractsProofNodesArity) + require.Len(t, proof.ContractsProof.LeavesData, contractsProofLeavesArity) +} + +func emptyTrie(t *testing.T) *trie.Trie { + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) + return tempTrie +} + +func verifyIf( + t *testing.T, + root, key, value *felt.Felt, + proof []*rpc.HashToNode, + hashF trie.HashFunc, +) { + t.Helper() + + proofSet := trie.NewProofNodeSet() + for _, hn := range proof { + proofSet.Put(*hn.Hash, hn.Node.AsProofNode()) + } + + leaf, err := trie.VerifyProof(root, key, proofSet, hashF) + require.NoError(t, err) + + // non-membership test + if value == nil { + value = felt.Zero.Clone() + } + require.Equal(t, leaf, value) +} + +func verifyGlobalStateRoot(t *testing.T, globalStateRoot, classRoot, storageRoot *felt.Felt) { + stateVersion := new(felt.Felt).SetBytes([]byte(`STARKNET_STATE_V0`)) + if classRoot.IsZero() { + assert.Equal(t, globalStateRoot, storageRoot) + } else { + assert.Equal(t, globalStateRoot, crypto.PoseidonArray(stateVersion, storageRoot, classRoot)) + } +} diff --git a/utils/slices.go b/utils/slices.go index 789c06ba31..229311d4d2 100644 --- a/utils/slices.go +++ b/utils/slices.go @@ -1,6 +1,10 @@ package utils -import "slices" +import ( + "fmt" + "reflect" + "slices" +) func Map[T1, T2 any](slice []T1, f func(T1) T2) []T2 { if slice == nil { @@ -39,3 +43,30 @@ func AnyOf[T comparable](e T, values ...T) bool { } return false } + +// Unique returns a new slice with duplicates removed +func Unique[T comparable](slice []T) []T { + // check if not used with a pointer type + if len(slice) > 0 { + elt := slice[0] + if reflect.TypeOf(elt).Kind() == reflect.Ptr { + panic(fmt.Sprintf("Unique() cannot be used with a slice of pointers (%T). Use `UniqueFunc()` instead.", elt)) + } + } + + return UniqueFunc(slice, func(t T) T { return t }) +} + +// UniqueFunc returns a new slice with duplicates removed, using a key function +func UniqueFunc[T, K comparable](slice []T, key func(T) K) []T { + var result []T + seen := make(map[K]struct{}) + for _, e := range slice { + k := key(e) + if _, ok := seen[k]; !ok { + result = append(result, e) + seen[k] = struct{}{} + } + } + return result +} diff --git a/utils/slices_test.go b/utils/slices_test.go index 9ef6fcff66..91c98823c6 100644 --- a/utils/slices_test.go +++ b/utils/slices_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMap(t *testing.T) { @@ -73,3 +74,46 @@ func TestAnyOf(t *testing.T) { assert.False(t, AnyOf("9", "1", "2", "3", "4", "5", "6")) }) } + +func TestUnique(t *testing.T) { + t.Run("nil slice", func(t *testing.T) { + var input []int + actual := Unique(input) + assert.Nil(t, actual) + }) + t.Run("empty slice returns nil", func(t *testing.T) { + input := []int{} + actual := Unique(input) + assert.Nil(t, actual) + }) + t.Run("slice with data", func(t *testing.T) { + expected := []int{1, 2, 3} + input := expected + input = append(input, expected...) + actual := Unique(input) + assert.Equal(t, expected, actual) + }) + t.Run("panic when called on pointers", func(t *testing.T) { + defer func() { + r := recover() + assert.NotNil(t, r) + assert.Contains(t, r.(string), "Unique() cannot be used with a slice of pointers") + }() + input := []*int{new(int), new(int)} + Unique(input) + }) + t.Run("with key function", func(t *testing.T) { + type thing struct { + id int + name string + } + + things := []thing{ + {1, "one"}, {1, "two"}, {2, "one"}, {2, "two"}, + } + + require.Len(t, UniqueFunc(things, func(t thing) int { return t.id }), 2) + require.Len(t, UniqueFunc(things, func(t thing) string { return t.name }), 2) + require.Equal(t, things, UniqueFunc(things, func(t thing) thing { return t })) + }) +}