Skip to content

Commit

Permalink
Allow passing ctx in return values of submsg handler (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
codchen authored Jun 13, 2024
1 parent 2034943 commit d14d7fb
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 105 deletions.
68 changes: 35 additions & 33 deletions x/wasm/keeper/handler_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"

wasmvmtypes "github.com/CosmWasm/wasmvm/types"
"github.com/cosmos/cosmos-sdk/baseapp"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
Expand All @@ -21,9 +20,11 @@ type msgEncoder interface {
Encode(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) ([]sdk.Msg, error)
}

type MsgHandler = func(ctx sdk.Context, req sdk.Msg) (sdk.Context, *sdk.Result, error)

// MessageRouter ADR 031 request type routing
type MessageRouter interface {
Handler(msg sdk.Msg) baseapp.MsgServiceHandler
Handler(msg sdk.Msg) MsgHandler
}

// SDKMessageHandler can handles messages that can be encoded into sdk.Message types and routed.
Expand Down Expand Up @@ -59,16 +60,17 @@ func NewSDKMessageHandler(router MessageRouter, encoders msgEncoder) SDKMessageH
}
}

func (h SDKMessageHandler) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
func (h SDKMessageHandler) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
sdkMsgs, err := h.encoders.Encode(ctx, contractAddr, contractIBCPortID, msg, info, codeInfo)
if err != nil {
return nil, nil, err
return ctx, nil, nil, err
}
for _, sdkMsg := range sdkMsgs {
res, err := h.handleSdkMessage(ctx, contractAddr, sdkMsg)
resCtx, res, err := h.handleSdkMessage(ctx, contractAddr, sdkMsg)
if err != nil {
return nil, nil, err
return ctx, nil, nil, err
}
ctx = resCtx
// append data
data = append(data, res.Data)
// append events
Expand All @@ -81,29 +83,29 @@ func (h SDKMessageHandler) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddr
return
}

func (h SDKMessageHandler) handleSdkMessage(ctx sdk.Context, contractAddr sdk.Address, msg sdk.Msg) (*sdk.Result, error) {
func (h SDKMessageHandler) handleSdkMessage(ctx sdk.Context, contractAddr sdk.Address, msg sdk.Msg) (sdk.Context, *sdk.Result, error) {
if err := msg.ValidateBasic(); err != nil {
return nil, err
return ctx, nil, err
}
// make sure this account can send it
for _, acct := range msg.GetSigners() {
if !acct.Equals(contractAddr) {
return nil, sdkerrors.Wrap(sdkerrors.ErrUnauthorized, "contract doesn't have permission")
return ctx, nil, sdkerrors.Wrap(sdkerrors.ErrUnauthorized, "contract doesn't have permission")
}
}

// find the handler and execute it
if handler := h.router.Handler(msg); handler != nil {
// ADR 031 request type routing
msgResult, err := handler(ctx, msg)
return msgResult, err
resCtx, msgResult, err := handler(ctx, msg)
return resCtx, msgResult, err
}
// legacy sdk.Msg routing
// Assuming that the app developer has migrated all their Msgs to
// proto messages and has registered all `Msg services`, then this
// path should never be called, because all those Msgs should be
// registered within the `msgServiceRouter` already.
return nil, sdkerrors.Wrapf(sdkerrors.ErrUnknownRequest, "can't route message %+v", msg)
return ctx, nil, sdkerrors.Wrapf(sdkerrors.ErrUnknownRequest, "can't route message %+v", msg)
}

// MessageHandlerChain defines a chain of handlers that are called one by one until it can be handled.
Expand All @@ -125,19 +127,19 @@ func NewMessageHandlerChain(first Messenger, others ...Messenger) *MessageHandle
// order to find the right one to process given message. If a handler cannot
// process given message (returns ErrUnknownMsg), its result is ignored and the
// next handler is executed.
func (m MessageHandlerChain) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) ([]sdk.Event, [][]byte, error) {
func (m MessageHandlerChain) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (sdk.Context, []sdk.Event, [][]byte, error) {
for _, h := range m.handlers {
events, data, err := h.DispatchMsg(ctx, contractAddr, contractIBCPortID, msg, info, codeInfo)
resCtx, events, data, err := h.DispatchMsg(ctx, contractAddr, contractIBCPortID, msg, info, codeInfo)
switch {
case err == nil:
return events, data, nil
return resCtx, events, data, nil
case errors.Is(err, types.ErrUnknownMsg):
continue
default:
return events, data, err
return ctx, events, data, err
}
}
return nil, nil, sdkerrors.Wrap(types.ErrUnknownMsg, "no handler found")
return ctx, nil, nil, sdkerrors.Wrap(types.ErrUnknownMsg, "no handler found")
}

// IBCRawPacketHandler handels IBC.SendPacket messages which are published to an IBC channel.
Expand All @@ -151,32 +153,32 @@ func NewIBCRawPacketHandler(chk types.ChannelKeeper, cak types.CapabilityKeeper)
}

// DispatchMsg publishes a raw IBC packet onto the channel.
func (h IBCRawPacketHandler) DispatchMsg(ctx sdk.Context, _ sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, _ wasmvmtypes.MessageInfo, _ types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
func (h IBCRawPacketHandler) DispatchMsg(ctx sdk.Context, _ sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, _ wasmvmtypes.MessageInfo, _ types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
if msg.IBC == nil || msg.IBC.SendPacket == nil {
return nil, nil, types.ErrUnknownMsg
return ctx, nil, nil, types.ErrUnknownMsg
}
if contractIBCPortID == "" {
return nil, nil, sdkerrors.Wrapf(types.ErrUnsupportedForContract, "ibc not supported")
return ctx, nil, nil, sdkerrors.Wrapf(types.ErrUnsupportedForContract, "ibc not supported")
}
contractIBCChannelID := msg.IBC.SendPacket.ChannelID
if contractIBCChannelID == "" {
return nil, nil, sdkerrors.Wrapf(types.ErrEmpty, "ibc channel")
return ctx, nil, nil, sdkerrors.Wrapf(types.ErrEmpty, "ibc channel")
}

sequence, found := h.channelKeeper.GetNextSequenceSend(ctx, contractIBCPortID, contractIBCChannelID)
if !found {
return nil, nil, sdkerrors.Wrapf(channeltypes.ErrSequenceSendNotFound,
return ctx, nil, nil, sdkerrors.Wrapf(channeltypes.ErrSequenceSendNotFound,
"source port: %s, source channel: %s", contractIBCPortID, contractIBCChannelID,
)
}

channelInfo, ok := h.channelKeeper.GetChannel(ctx, contractIBCPortID, contractIBCChannelID)
if !ok {
return nil, nil, sdkerrors.Wrap(channeltypes.ErrInvalidChannel, "not found")
return ctx, nil, nil, sdkerrors.Wrap(channeltypes.ErrInvalidChannel, "not found")
}
channelCap, ok := h.capabilityKeeper.GetCapability(ctx, host.ChannelCapabilityPath(contractIBCPortID, contractIBCChannelID))
if !ok {
return nil, nil, sdkerrors.Wrap(channeltypes.ErrChannelCapabilityNotFound, "module does not own channel capability")
return ctx, nil, nil, sdkerrors.Wrap(channeltypes.ErrChannelCapabilityNotFound, "module does not own channel capability")
}
packet := channeltypes.NewPacket(
msg.IBC.SendPacket.Data,
Expand All @@ -188,36 +190,36 @@ func (h IBCRawPacketHandler) DispatchMsg(ctx sdk.Context, _ sdk.AccAddress, cont
ConvertWasmIBCTimeoutHeightToCosmosHeight(msg.IBC.SendPacket.Timeout.Block),
msg.IBC.SendPacket.Timeout.Timestamp,
)
return nil, nil, h.channelKeeper.SendPacket(ctx, channelCap, packet)
return ctx, nil, nil, h.channelKeeper.SendPacket(ctx, channelCap, packet)
}

var _ Messenger = MessageHandlerFunc(nil)

// MessageHandlerFunc is a helper to construct a function based message handler.
type MessageHandlerFunc func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (events []sdk.Event, data [][]byte, err error)
type MessageHandlerFunc func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error)

// DispatchMsg delegates dispatching of provided message into the MessageHandlerFunc.
func (m MessageHandlerFunc) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
func (m MessageHandlerFunc) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
return m(ctx, contractAddr, contractIBCPortID, msg, info, codeInfo)
}

// NewBurnCoinMessageHandler handles wasmvm.BurnMsg messages
func NewBurnCoinMessageHandler(burner types.Burner) MessageHandlerFunc {
return func(ctx sdk.Context, contractAddr sdk.AccAddress, _ string, msg wasmvmtypes.CosmosMsg, _ wasmvmtypes.MessageInfo, _ types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
return func(ctx sdk.Context, contractAddr sdk.AccAddress, _ string, msg wasmvmtypes.CosmosMsg, _ wasmvmtypes.MessageInfo, _ types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
if msg.Bank != nil && msg.Bank.Burn != nil {
coins, err := ConvertWasmCoinsToSdkCoins(msg.Bank.Burn.Amount)
if err != nil {
return nil, nil, err
return ctx, nil, nil, err
}
if err := burner.SendCoinsFromAccountToModule(ctx, contractAddr, types.ModuleName, coins); err != nil {
return nil, nil, sdkerrors.Wrap(err, "transfer to module")
return ctx, nil, nil, sdkerrors.Wrap(err, "transfer to module")
}
if err := burner.BurnCoins(ctx, types.ModuleName, coins); err != nil {
return nil, nil, sdkerrors.Wrap(err, "burn coins")
return ctx, nil, nil, sdkerrors.Wrap(err, "burn coins")
}
moduleLogger(ctx).Info("Burned", "amount", coins)
return nil, nil, nil
return ctx, nil, nil, nil
}
return nil, nil, types.ErrUnknownMsg
return ctx, nil, nil, types.ErrUnknownMsg
}
}
31 changes: 15 additions & 16 deletions x/wasm/keeper/handler_plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

wasmvm "github.com/CosmWasm/wasmvm"
wasmvmtypes "github.com/CosmWasm/wasmvm/types"
"github.com/cosmos/cosmos-sdk/baseapp"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
banktypes "github.com/cosmos/cosmos-sdk/x/bank/types"
Expand All @@ -25,13 +24,13 @@ func TestMessageHandlerChainDispatch(t *testing.T) {
capturingHandler, gotMsgs := wasmtesting.NewCapturingMessageHandler()

alwaysUnknownMsgHandler := &wasmtesting.MockMessageHandler{
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, _ types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
return nil, nil, types.ErrUnknownMsg
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, _ types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
return ctx, nil, nil, types.ErrUnknownMsg
},
}

assertNotCalledHandler := &wasmtesting.MockMessageHandler{
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, _ types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, _ types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
t.Fatal("not expected to be called")
return
},
Expand All @@ -54,18 +53,18 @@ func TestMessageHandlerChainDispatch(t *testing.T) {
},
"stops iteration on handler error": {
handlers: []Messenger{&wasmtesting.MockMessageHandler{
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, _ types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
return nil, nil, types.ErrInvalidMsg
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, _ types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
return ctx, nil, nil, types.ErrInvalidMsg
},
}, assertNotCalledHandler},
expErr: types.ErrInvalidMsg,
},
"return events when handle": {
handlers: []Messenger{
&wasmtesting.MockMessageHandler{
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
_, data, _ = capturingHandler.DispatchMsg(ctx, contractAddr, contractIBCPortID, msg, info, codeInfo)
return []sdk.Event{sdk.NewEvent("myEvent", sdk.NewAttribute("foo", "bar"))}, data, nil
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
resCtx, _, data, _ = capturingHandler.DispatchMsg(ctx, contractAddr, contractIBCPortID, msg, info, codeInfo)
return resCtx, []sdk.Event{sdk.NewEvent("myEvent", sdk.NewAttribute("foo", "bar"))}, data, nil
},
},
},
Expand All @@ -82,7 +81,7 @@ func TestMessageHandlerChainDispatch(t *testing.T) {

// when
h := MessageHandlerChain{spec.handlers}
gotEvents, gotData, gotErr := h.DispatchMsg(sdk.Context{}, RandomAccountAddress(t), "anyPort", myMsg, wasmvmtypes.MessageInfo{}, types.CodeInfo{})
_, gotEvents, gotData, gotErr := h.DispatchMsg(sdk.Context{}, RandomAccountAddress(t), "anyPort", myMsg, wasmvmtypes.MessageInfo{}, types.CodeInfo{})

// then
require.True(t, spec.expErr.Is(gotErr), "exp %v but got %#+v", spec.expErr, gotErr)
Expand All @@ -105,13 +104,13 @@ func TestSDKMessageHandlerDispatch(t *testing.T) {
}

var gotMsg []sdk.Msg
capturingMessageRouter := wasmtesting.MessageRouterFunc(func(msg sdk.Msg) baseapp.MsgServiceHandler {
return func(ctx sdk.Context, req sdk.Msg) (*sdk.Result, error) {
capturingMessageRouter := wasmtesting.MessageRouterFunc(func(msg sdk.Msg) MsgHandler {
return func(ctx sdk.Context, req sdk.Msg) (sdk.Context, *sdk.Result, error) {
gotMsg = append(gotMsg, msg)
return &myRouterResult, nil
return ctx, &myRouterResult, nil
}
})
noRouteMessageRouter := wasmtesting.MessageRouterFunc(func(msg sdk.Msg) baseapp.MsgServiceHandler {
noRouteMessageRouter := wasmtesting.MessageRouterFunc(func(msg sdk.Msg) MsgHandler {
return nil
})
myContractAddr := RandomAccountAddress(t)
Expand Down Expand Up @@ -204,7 +203,7 @@ func TestSDKMessageHandlerDispatch(t *testing.T) {
// when
ctx := sdk.Context{}
h := NewSDKMessageHandler(spec.srcRoute, MessageEncoders{Custom: spec.srcEncoder})
gotEvents, gotData, gotErr := h.DispatchMsg(ctx, myContractAddr, "myPort", myContractMessage, wasmvmtypes.MessageInfo{}, types.CodeInfo{})
_, gotEvents, gotData, gotErr := h.DispatchMsg(ctx, myContractAddr, "myPort", myContractMessage, wasmvmtypes.MessageInfo{}, types.CodeInfo{})

// then
require.True(t, spec.expErr.Is(gotErr), "exp %v but got %#+v", spec.expErr, gotErr)
Expand Down Expand Up @@ -308,7 +307,7 @@ func TestIBCRawPacketHandler(t *testing.T) {
capturedPacket = nil
// when
h := NewIBCRawPacketHandler(spec.chanKeeper, spec.capKeeper)
data, evts, gotErr := h.DispatchMsg(ctx, RandomAccountAddress(t), ibcPort, wasmvmtypes.CosmosMsg{IBC: &wasmvmtypes.IBCMsg{SendPacket: &spec.srcMsg}}, wasmvmtypes.MessageInfo{}, types.CodeInfo{})
_, data, evts, gotErr := h.DispatchMsg(ctx, RandomAccountAddress(t), ibcPort, wasmvmtypes.CosmosMsg{IBC: &wasmvmtypes.IBCMsg{SendPacket: &spec.srcMsg}}, wasmvmtypes.MessageInfo{}, types.CodeInfo{})
// then
require.True(t, spec.expErr.Is(gotErr), "exp %v but got %#+v", spec.expErr, gotErr)
if spec.expErr != nil {
Expand Down
18 changes: 16 additions & 2 deletions x/wasm/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/armon/go-metrics"
"github.com/cosmos/cosmos-sdk/baseapp"
"github.com/cosmos/cosmos-sdk/types/address"

wasmvm "github.com/CosmWasm/wasmvm"
Expand Down Expand Up @@ -89,6 +90,18 @@ type Keeper struct {
maxQueryStackSize uint32
}

type routerWithContext struct {
router *baseapp.MsgServiceRouter
}

func (rc routerWithContext) Handler(msg sdk.Msg) MsgHandler {
h := rc.router.Handler(msg)
return func(ctx sdk.Context, req sdk.Msg) (sdk.Context, *sdk.Result, error) {
result, err := h(ctx, msg)
return ctx, result, err
}
}

// NewKeeper creates a new contract Keeper instance
// If customEncoders is non-nil, we can use this to override some of the message handler, especially custom
func NewKeeper(
Expand All @@ -104,7 +117,7 @@ func NewKeeper(
portKeeper types.PortKeeper,
capabilityKeeper types.CapabilityKeeper,
portSource types.ICS20TransferPortSource,
router MessageRouter,
router *baseapp.MsgServiceRouter,
queryRouter GRPCQueryRouter,
homeDir string,
wasmConfig types.WasmConfig,
Expand All @@ -120,6 +133,7 @@ func NewKeeper(
paramSpace = paramSpace.WithKeyTable(types.ParamKeyTable())
}

routerWithCtx := routerWithContext{router}
keeper := &Keeper{
storeKey: storeKey,
cdc: cdc,
Expand All @@ -129,7 +143,7 @@ func NewKeeper(
bank: NewBankCoinTransferrer(bankKeeper),
portKeeper: portKeeper,
capabilityKeeper: capabilityKeeper,
messenger: NewDefaultMessageHandler(router, channelKeeper, capabilityKeeper, bankKeeper, cdc, portSource),
messenger: NewDefaultMessageHandler(routerWithCtx, channelKeeper, capabilityKeeper, bankKeeper, cdc, portSource),
queryGasLimit: wasmConfig.SmartQueryGasLimit,
paramSpace: paramSpace,
gasRegister: NewDefaultWasmGasRegister(),
Expand Down
Loading

0 comments on commit d14d7fb

Please sign in to comment.