Skip to content

Commit

Permalink
implement a mempool for the sequencer (#2341)
Browse files Browse the repository at this point in the history
* implement a mempool for the sequencer
  • Loading branch information
rianhughes authored Jan 24, 2025
1 parent 6bf445a commit 61d810a
Show file tree
Hide file tree
Showing 4 changed files with 595 additions and 0 deletions.
4 changes: 4 additions & 0 deletions db/buckets.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ const (
Temporary // used temporarily for migrations
SchemaIntermediateState
L1HandlerTxnHashByMsgHash // maps l1 handler msg hash to l1 handler txn hash
MempoolHead // key of the head node
MempoolTail // key of the tail node
MempoolLength // number of transactions
MempoolNode
)

// Key flattens a prefix and series of byte arrays into a single []byte.
Expand Down
63 changes: 63 additions & 0 deletions mempool/db_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package mempool

import (
"errors"
"math/big"

"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/db"
"github.com/NethermindEth/juno/encoder"
)

func headValue(txn db.Transaction, head *felt.Felt) error {
return txn.Get(db.MempoolHead.Key(), func(b []byte) error {
head.SetBytes(b)
return nil
})
}

func tailValue(txn db.Transaction, tail *felt.Felt) error {
return txn.Get(db.MempoolTail.Key(), func(b []byte) error {
tail.SetBytes(b)
return nil
})
}

func updateHead(txn db.Transaction, head *felt.Felt) error {
return txn.Set(db.MempoolHead.Key(), head.Marshal())
}

func updateTail(txn db.Transaction, tail *felt.Felt) error {
return txn.Set(db.MempoolTail.Key(), tail.Marshal())
}

func readTxn(txn db.Transaction, itemKey *felt.Felt) (dbPoolTxn, error) {
var item dbPoolTxn
keyBytes := itemKey.Bytes()
err := txn.Get(db.MempoolNode.Key(keyBytes[:]), func(b []byte) error {
return encoder.Unmarshal(b, &item)
})
return item, err
}

func setTxn(txn db.Transaction, item *dbPoolTxn) error {
itemBytes, err := encoder.Marshal(item)
if err != nil {
return err
}
keyBytes := item.Txn.Transaction.Hash().Bytes()
return txn.Set(db.MempoolNode.Key(keyBytes[:]), itemBytes)
}

func lenDB(txn db.Transaction) (int, error) {
var l int
err := txn.Get(db.MempoolLength.Key(), func(b []byte) error {
l = int(new(big.Int).SetBytes(b).Int64())
return nil
})

if err != nil && errors.Is(err, db.ErrKeyNotFound) {
return 0, nil
}
return l, err
}
302 changes: 302 additions & 0 deletions mempool/mempool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
package mempool

import (
"errors"
"fmt"
"math/big"
"sync"

"github.com/NethermindEth/juno/core"
"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/db"
"github.com/NethermindEth/juno/utils"
)

var ErrTxnPoolFull = errors.New("transaction pool is full")

type BroadcastedTransaction struct {
Transaction core.Transaction
DeclaredClass core.Class
}

// runtime mempool txn
type memPoolTxn struct {
Txn BroadcastedTransaction
Next *memPoolTxn
}

// persistent db txn value
type dbPoolTxn struct {
Txn BroadcastedTransaction
NextHash *felt.Felt
}

// memTxnList represents a linked list of user transactions at runtime
type memTxnList struct {
head *memPoolTxn
tail *memPoolTxn
len int
mu sync.Mutex
}

func (t *memTxnList) push(newNode *memPoolTxn) {
t.mu.Lock()
defer t.mu.Unlock()
if t.tail != nil {
t.tail.Next = newNode
t.tail = newNode
} else {
t.head = newNode
t.tail = newNode
}
t.len++
}

func (t *memTxnList) pop() (BroadcastedTransaction, error) {
t.mu.Lock()
defer t.mu.Unlock()

if t.head == nil {
return BroadcastedTransaction{}, errors.New("transaction pool is empty")
}

headNode := t.head
t.head = headNode.Next
if t.head == nil {
t.tail = nil
}
t.len--
return headNode.Txn, nil
}

// Pool represents a blockchain mempool, managing transactions using both an
// in-memory and persistent database.
type Pool struct {
log utils.SimpleLogger
state core.StateReader
db db.DB // to store the persistent mempool
txPushed chan struct{}
memTxnList *memTxnList
maxNumTxns int
dbWriteChan chan *BroadcastedTransaction
wg sync.WaitGroup
}

// New initialises the Pool and starts the database writer goroutine.
// It is the responsibility of the caller to execute the closer function.
func New(mainDB db.DB, state core.StateReader, maxNumTxns int, log utils.SimpleLogger) (*Pool, func() error) {
pool := &Pool{
log: log,
state: state,
db: mainDB, // todo: txns should be deleted everytime a new block is stored (builder responsibility)
txPushed: make(chan struct{}, 1),
memTxnList: &memTxnList{},
maxNumTxns: maxNumTxns,
dbWriteChan: make(chan *BroadcastedTransaction, maxNumTxns),
}
closer := func() error {
close(pool.dbWriteChan)
pool.wg.Wait()
if err := pool.db.Close(); err != nil {
return fmt.Errorf("failed to close mempool database: %v", err)
}
return nil
}
pool.dbWriter()
return pool, closer
}

func (p *Pool) dbWriter() {
p.wg.Add(1)
go func() {
defer p.wg.Done()
for txn := range p.dbWriteChan {
err := p.writeToDB(txn)
if err != nil {
p.log.Errorw("error in handling user transaction in persistent mempool", "err", err)
}
}
}()
}

// LoadFromDB restores the in-memory transaction pool from the database
func (p *Pool) LoadFromDB() error {
return p.db.View(func(txn db.Transaction) error {
headVal := new(felt.Felt)
err := headValue(txn, headVal)
if err != nil {
if errors.Is(err, db.ErrKeyNotFound) {
return nil
}
return err
}
// loop through the persistent pool and push nodes to the in-memory pool
currentHash := headVal
for currentHash != nil {
curTxn, err := readTxn(txn, currentHash)
if err != nil {
return err
}
newMemPoolTxn := &memPoolTxn{
Txn: curTxn.Txn,
}
if curTxn.NextHash != nil {
nextDBTxn, err := readTxn(txn, curTxn.NextHash)
if err != nil {
return err
}
newMemPoolTxn.Next = &memPoolTxn{
Txn: nextDBTxn.Txn,
}
}
p.memTxnList.push(newMemPoolTxn)
currentHash = curTxn.NextHash
}
return nil
})
}

// writeToDB adds the transaction to the persistent pool db
func (p *Pool) writeToDB(userTxn *BroadcastedTransaction) error {
return p.db.Update(func(dbTxn db.Transaction) error {
tailVal := new(felt.Felt)
if err := tailValue(dbTxn, tailVal); err != nil {
if !errors.Is(err, db.ErrKeyNotFound) {
return err
}
tailVal = nil
}
if err := setTxn(dbTxn, &dbPoolTxn{Txn: *userTxn}); err != nil {
return err
}
if tailVal != nil {
// Update old tail to point to the new item
var oldTailElem dbPoolTxn
oldTailElem, err := readTxn(dbTxn, tailVal)
if err != nil {
return err
}
oldTailElem.NextHash = userTxn.Transaction.Hash()
if err = setTxn(dbTxn, &oldTailElem); err != nil {
return err
}
} else {
// Empty list, make new item both the head and the tail
if err := updateHead(dbTxn, userTxn.Transaction.Hash()); err != nil {
return err
}
}
if err := updateTail(dbTxn, userTxn.Transaction.Hash()); err != nil {
return err
}
pLen, err := lenDB(dbTxn)
if err != nil {
return err
}
return dbTxn.Set(db.MempoolLength.Key(), new(big.Int).SetInt64(int64(pLen+1)).Bytes())
})
}

// Push queues a transaction to the pool
func (p *Pool) Push(userTxn *BroadcastedTransaction) error {
err := p.validate(userTxn)
if err != nil {
return err
}

select {
case p.dbWriteChan <- userTxn:
default:
select {
case _, ok := <-p.dbWriteChan:
if !ok {
p.log.Errorw("cannot store user transasction in persistent pool, database write channel is closed")
}
p.log.Errorw("cannot store user transasction in persistent pool, database is full")
default:
p.log.Errorw("cannot store user transasction in persistent pool, database is full")
}
}

newNode := &memPoolTxn{Txn: *userTxn, Next: nil}
p.memTxnList.push(newNode)

select {
case p.txPushed <- struct{}{}:
default:
}

return nil
}

func (p *Pool) validate(userTxn *BroadcastedTransaction) error {
if p.memTxnList.len+1 >= p.maxNumTxns {
return ErrTxnPoolFull
}

switch t := userTxn.Transaction.(type) {
case *core.DeployTransaction:
return fmt.Errorf("deploy transactions are not supported")
case *core.DeployAccountTransaction:
if !t.Nonce.IsZero() {
return fmt.Errorf("validation failed, received non-zero nonce %s", t.Nonce)
}
case *core.DeclareTransaction:
nonce, err := p.state.ContractNonce(t.SenderAddress)
if err != nil {
return fmt.Errorf("validation failed, error when retrieving nonce, %v", err)
}
if nonce.Cmp(t.Nonce) > 0 {
return fmt.Errorf("validation failed, existing nonce %s, but received nonce %s", nonce, t.Nonce)
}
case *core.InvokeTransaction:
if t.TxVersion().Is(0) { // cant verify nonce since SenderAddress was only added in v1
return fmt.Errorf("invoke v0 transactions not supported")
}
nonce, err := p.state.ContractNonce(t.SenderAddress)
if err != nil {
return fmt.Errorf("validation failed, error when retrieving nonce, %v", err)
}
if nonce.Cmp(t.Nonce) > 0 {
return fmt.Errorf("validation failed, existing nonce %s, but received nonce %s", nonce, t.Nonce)
}
case *core.L1HandlerTransaction:
// todo: verification of the L1 handler nonce requires checking the
// message nonce on the L1 Core Contract.
}
return nil
}

// Pop returns the transaction with the highest priority from the in-memory pool
func (p *Pool) Pop() (BroadcastedTransaction, error) {
return p.memTxnList.pop()
}

// Remove removes a set of transactions from the pool
// todo: should be called by the builder to remove txns from the db everytime a new block is stored.
// todo: in the consensus+p2p world, the txns should also be removed from the in-memory pool.
func (p *Pool) Remove(hash ...*felt.Felt) error {
return errors.New("not implemented")
}

// Len returns the number of transactions in the in-memory pool
func (p *Pool) Len() int {
return p.memTxnList.len
}

func (p *Pool) Wait() <-chan struct{} {
return p.txPushed
}

// Len returns the number of transactions in the persistent pool
func (p *Pool) LenDB() (int, error) {
txn, err := p.db.NewTransaction(false)
if err != nil {
return 0, err
}
lenDB, err := lenDB(txn)
if err != nil {
return 0, err
}
return lenDB, txn.Discard()
}
Loading

0 comments on commit 61d810a

Please sign in to comment.