Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

revert revert #59

Merged
merged 4 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 36 additions & 33 deletions x/wasm/keeper/handler_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"

wasmvmtypes "github.com/CosmWasm/wasmvm/types"
"github.com/cosmos/cosmos-sdk/baseapp"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
Expand All @@ -21,9 +20,11 @@ type msgEncoder interface {
Encode(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) ([]sdk.Msg, error)
}

type MsgHandler = func(ctx sdk.Context, req sdk.Msg) (sdk.Context, *sdk.Result, error)

// MessageRouter ADR 031 request type routing
type MessageRouter interface {
Handler(msg sdk.Msg) baseapp.MsgServiceHandler
Handler(msg sdk.Msg) MsgHandler
}

// SDKMessageHandler can handles messages that can be encoded into sdk.Message types and routed.
Expand Down Expand Up @@ -59,16 +60,18 @@ func NewSDKMessageHandler(router MessageRouter, encoders msgEncoder) SDKMessageH
}
}

func (h SDKMessageHandler) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
func (h SDKMessageHandler) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
sdkMsgs, err := h.encoders.Encode(ctx, contractAddr, contractIBCPortID, msg, info, codeInfo)
if err != nil {
return nil, nil, err
return ctx, nil, nil, err
}
for _, sdkMsg := range sdkMsgs {
res, err := h.handleSdkMessage(ctx, contractAddr, sdkMsg)
rCtx, res, err := h.handleSdkMessage(ctx, contractAddr, sdkMsg)
if err != nil {
return nil, nil, err
return ctx, nil, nil, err
}
ctx = rCtx
resCtx = rCtx
// append data
data = append(data, res.Data)
// append events
Expand All @@ -81,29 +84,29 @@ func (h SDKMessageHandler) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddr
return
}

func (h SDKMessageHandler) handleSdkMessage(ctx sdk.Context, contractAddr sdk.Address, msg sdk.Msg) (*sdk.Result, error) {
func (h SDKMessageHandler) handleSdkMessage(ctx sdk.Context, contractAddr sdk.Address, msg sdk.Msg) (sdk.Context, *sdk.Result, error) {
if err := msg.ValidateBasic(); err != nil {
return nil, err
return ctx, nil, err
}
// make sure this account can send it
for _, acct := range msg.GetSigners() {
if !acct.Equals(contractAddr) {
return nil, sdkerrors.Wrap(sdkerrors.ErrUnauthorized, "contract doesn't have permission")
return ctx, nil, sdkerrors.Wrap(sdkerrors.ErrUnauthorized, "contract doesn't have permission")
}
}

// find the handler and execute it
if handler := h.router.Handler(msg); handler != nil {
// ADR 031 request type routing
msgResult, err := handler(ctx, msg)
return msgResult, err
resCtx, msgResult, err := handler(ctx, msg)
return resCtx, msgResult, err
}
// legacy sdk.Msg routing
// Assuming that the app developer has migrated all their Msgs to
// proto messages and has registered all `Msg services`, then this
// path should never be called, because all those Msgs should be
// registered within the `msgServiceRouter` already.
return nil, sdkerrors.Wrapf(sdkerrors.ErrUnknownRequest, "can't route message %+v", msg)
return ctx, nil, sdkerrors.Wrapf(sdkerrors.ErrUnknownRequest, "can't route message %+v", msg)
}

// MessageHandlerChain defines a chain of handlers that are called one by one until it can be handled.
Expand All @@ -125,19 +128,19 @@ func NewMessageHandlerChain(first Messenger, others ...Messenger) *MessageHandle
// order to find the right one to process given message. If a handler cannot
// process given message (returns ErrUnknownMsg), its result is ignored and the
// next handler is executed.
func (m MessageHandlerChain) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) ([]sdk.Event, [][]byte, error) {
func (m MessageHandlerChain) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (sdk.Context, []sdk.Event, [][]byte, error) {
for _, h := range m.handlers {
events, data, err := h.DispatchMsg(ctx, contractAddr, contractIBCPortID, msg, info, codeInfo)
resCtx, events, data, err := h.DispatchMsg(ctx, contractAddr, contractIBCPortID, msg, info, codeInfo)
switch {
case err == nil:
return events, data, nil
return resCtx, events, data, nil
case errors.Is(err, types.ErrUnknownMsg):
continue
default:
return events, data, err
return ctx, events, data, err
}
}
return nil, nil, sdkerrors.Wrap(types.ErrUnknownMsg, "no handler found")
return ctx, nil, nil, sdkerrors.Wrap(types.ErrUnknownMsg, "no handler found")
}

// IBCRawPacketHandler handels IBC.SendPacket messages which are published to an IBC channel.
Expand All @@ -151,32 +154,32 @@ func NewIBCRawPacketHandler(chk types.ChannelKeeper, cak types.CapabilityKeeper)
}

// DispatchMsg publishes a raw IBC packet onto the channel.
func (h IBCRawPacketHandler) DispatchMsg(ctx sdk.Context, _ sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, _ wasmvmtypes.MessageInfo, _ types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
func (h IBCRawPacketHandler) DispatchMsg(ctx sdk.Context, _ sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, _ wasmvmtypes.MessageInfo, _ types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
if msg.IBC == nil || msg.IBC.SendPacket == nil {
return nil, nil, types.ErrUnknownMsg
return ctx, nil, nil, types.ErrUnknownMsg
}
if contractIBCPortID == "" {
return nil, nil, sdkerrors.Wrapf(types.ErrUnsupportedForContract, "ibc not supported")
return ctx, nil, nil, sdkerrors.Wrapf(types.ErrUnsupportedForContract, "ibc not supported")
}
contractIBCChannelID := msg.IBC.SendPacket.ChannelID
if contractIBCChannelID == "" {
return nil, nil, sdkerrors.Wrapf(types.ErrEmpty, "ibc channel")
return ctx, nil, nil, sdkerrors.Wrapf(types.ErrEmpty, "ibc channel")
}

sequence, found := h.channelKeeper.GetNextSequenceSend(ctx, contractIBCPortID, contractIBCChannelID)
if !found {
return nil, nil, sdkerrors.Wrapf(channeltypes.ErrSequenceSendNotFound,
return ctx, nil, nil, sdkerrors.Wrapf(channeltypes.ErrSequenceSendNotFound,
"source port: %s, source channel: %s", contractIBCPortID, contractIBCChannelID,
)
}

channelInfo, ok := h.channelKeeper.GetChannel(ctx, contractIBCPortID, contractIBCChannelID)
if !ok {
return nil, nil, sdkerrors.Wrap(channeltypes.ErrInvalidChannel, "not found")
return ctx, nil, nil, sdkerrors.Wrap(channeltypes.ErrInvalidChannel, "not found")
}
channelCap, ok := h.capabilityKeeper.GetCapability(ctx, host.ChannelCapabilityPath(contractIBCPortID, contractIBCChannelID))
if !ok {
return nil, nil, sdkerrors.Wrap(channeltypes.ErrChannelCapabilityNotFound, "module does not own channel capability")
return ctx, nil, nil, sdkerrors.Wrap(channeltypes.ErrChannelCapabilityNotFound, "module does not own channel capability")
}
packet := channeltypes.NewPacket(
msg.IBC.SendPacket.Data,
Expand All @@ -188,36 +191,36 @@ func (h IBCRawPacketHandler) DispatchMsg(ctx sdk.Context, _ sdk.AccAddress, cont
ConvertWasmIBCTimeoutHeightToCosmosHeight(msg.IBC.SendPacket.Timeout.Block),
msg.IBC.SendPacket.Timeout.Timestamp,
)
return nil, nil, h.channelKeeper.SendPacket(ctx, channelCap, packet)
return ctx, nil, nil, h.channelKeeper.SendPacket(ctx, channelCap, packet)
}

var _ Messenger = MessageHandlerFunc(nil)

// MessageHandlerFunc is a helper to construct a function based message handler.
type MessageHandlerFunc func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (events []sdk.Event, data [][]byte, err error)
type MessageHandlerFunc func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error)

// DispatchMsg delegates dispatching of provided message into the MessageHandlerFunc.
func (m MessageHandlerFunc) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
func (m MessageHandlerFunc) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
return m(ctx, contractAddr, contractIBCPortID, msg, info, codeInfo)
}

// NewBurnCoinMessageHandler handles wasmvm.BurnMsg messages
func NewBurnCoinMessageHandler(burner types.Burner) MessageHandlerFunc {
return func(ctx sdk.Context, contractAddr sdk.AccAddress, _ string, msg wasmvmtypes.CosmosMsg, _ wasmvmtypes.MessageInfo, _ types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
return func(ctx sdk.Context, contractAddr sdk.AccAddress, _ string, msg wasmvmtypes.CosmosMsg, _ wasmvmtypes.MessageInfo, _ types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
if msg.Bank != nil && msg.Bank.Burn != nil {
coins, err := ConvertWasmCoinsToSdkCoins(msg.Bank.Burn.Amount)
if err != nil {
return nil, nil, err
return ctx, nil, nil, err
}
if err := burner.SendCoinsFromAccountToModule(ctx, contractAddr, types.ModuleName, coins); err != nil {
return nil, nil, sdkerrors.Wrap(err, "transfer to module")
return ctx, nil, nil, sdkerrors.Wrap(err, "transfer to module")
}
if err := burner.BurnCoins(ctx, types.ModuleName, coins); err != nil {
return nil, nil, sdkerrors.Wrap(err, "burn coins")
return ctx, nil, nil, sdkerrors.Wrap(err, "burn coins")
}
moduleLogger(ctx).Info("Burned", "amount", coins)
return nil, nil, nil
return ctx, nil, nil, nil
}
return nil, nil, types.ErrUnknownMsg
return ctx, nil, nil, types.ErrUnknownMsg
}
}
31 changes: 15 additions & 16 deletions x/wasm/keeper/handler_plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

wasmvm "github.com/CosmWasm/wasmvm"
wasmvmtypes "github.com/CosmWasm/wasmvm/types"
"github.com/cosmos/cosmos-sdk/baseapp"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
banktypes "github.com/cosmos/cosmos-sdk/x/bank/types"
Expand All @@ -25,13 +24,13 @@ func TestMessageHandlerChainDispatch(t *testing.T) {
capturingHandler, gotMsgs := wasmtesting.NewCapturingMessageHandler()

alwaysUnknownMsgHandler := &wasmtesting.MockMessageHandler{
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, _ types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
return nil, nil, types.ErrUnknownMsg
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, _ types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
return ctx, nil, nil, types.ErrUnknownMsg
},
}

assertNotCalledHandler := &wasmtesting.MockMessageHandler{
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, _ types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, _ types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
t.Fatal("not expected to be called")
return
},
Expand All @@ -54,18 +53,18 @@ func TestMessageHandlerChainDispatch(t *testing.T) {
},
"stops iteration on handler error": {
handlers: []Messenger{&wasmtesting.MockMessageHandler{
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, _ types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
return nil, nil, types.ErrInvalidMsg
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, _ types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
return ctx, nil, nil, types.ErrInvalidMsg
},
}, assertNotCalledHandler},
expErr: types.ErrInvalidMsg,
},
"return events when handle": {
handlers: []Messenger{
&wasmtesting.MockMessageHandler{
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (events []sdk.Event, data [][]byte, err error) {
_, data, _ = capturingHandler.DispatchMsg(ctx, contractAddr, contractIBCPortID, msg, info, codeInfo)
return []sdk.Event{sdk.NewEvent("myEvent", sdk.NewAttribute("foo", "bar"))}, data, nil
DispatchMsgFn: func(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg, info wasmvmtypes.MessageInfo, codeInfo types.CodeInfo) (resCtx sdk.Context, events []sdk.Event, data [][]byte, err error) {
resCtx, _, data, _ = capturingHandler.DispatchMsg(ctx, contractAddr, contractIBCPortID, msg, info, codeInfo)
return resCtx, []sdk.Event{sdk.NewEvent("myEvent", sdk.NewAttribute("foo", "bar"))}, data, nil
},
},
},
Expand All @@ -82,7 +81,7 @@ func TestMessageHandlerChainDispatch(t *testing.T) {

// when
h := MessageHandlerChain{spec.handlers}
gotEvents, gotData, gotErr := h.DispatchMsg(sdk.Context{}, RandomAccountAddress(t), "anyPort", myMsg, wasmvmtypes.MessageInfo{}, types.CodeInfo{})
_, gotEvents, gotData, gotErr := h.DispatchMsg(sdk.Context{}, RandomAccountAddress(t), "anyPort", myMsg, wasmvmtypes.MessageInfo{}, types.CodeInfo{})

// then
require.True(t, spec.expErr.Is(gotErr), "exp %v but got %#+v", spec.expErr, gotErr)
Expand All @@ -105,13 +104,13 @@ func TestSDKMessageHandlerDispatch(t *testing.T) {
}

var gotMsg []sdk.Msg
capturingMessageRouter := wasmtesting.MessageRouterFunc(func(msg sdk.Msg) baseapp.MsgServiceHandler {
return func(ctx sdk.Context, req sdk.Msg) (*sdk.Result, error) {
capturingMessageRouter := wasmtesting.MessageRouterFunc(func(msg sdk.Msg) MsgHandler {
return func(ctx sdk.Context, req sdk.Msg) (sdk.Context, *sdk.Result, error) {
gotMsg = append(gotMsg, msg)
return &myRouterResult, nil
return ctx, &myRouterResult, nil
}
})
noRouteMessageRouter := wasmtesting.MessageRouterFunc(func(msg sdk.Msg) baseapp.MsgServiceHandler {
noRouteMessageRouter := wasmtesting.MessageRouterFunc(func(msg sdk.Msg) MsgHandler {
return nil
})
myContractAddr := RandomAccountAddress(t)
Expand Down Expand Up @@ -204,7 +203,7 @@ func TestSDKMessageHandlerDispatch(t *testing.T) {
// when
ctx := sdk.Context{}
h := NewSDKMessageHandler(spec.srcRoute, MessageEncoders{Custom: spec.srcEncoder})
gotEvents, gotData, gotErr := h.DispatchMsg(ctx, myContractAddr, "myPort", myContractMessage, wasmvmtypes.MessageInfo{}, types.CodeInfo{})
_, gotEvents, gotData, gotErr := h.DispatchMsg(ctx, myContractAddr, "myPort", myContractMessage, wasmvmtypes.MessageInfo{}, types.CodeInfo{})

// then
require.True(t, spec.expErr.Is(gotErr), "exp %v but got %#+v", spec.expErr, gotErr)
Expand Down Expand Up @@ -308,7 +307,7 @@ func TestIBCRawPacketHandler(t *testing.T) {
capturedPacket = nil
// when
h := NewIBCRawPacketHandler(spec.chanKeeper, spec.capKeeper)
data, evts, gotErr := h.DispatchMsg(ctx, RandomAccountAddress(t), ibcPort, wasmvmtypes.CosmosMsg{IBC: &wasmvmtypes.IBCMsg{SendPacket: &spec.srcMsg}}, wasmvmtypes.MessageInfo{}, types.CodeInfo{})
_, data, evts, gotErr := h.DispatchMsg(ctx, RandomAccountAddress(t), ibcPort, wasmvmtypes.CosmosMsg{IBC: &wasmvmtypes.IBCMsg{SendPacket: &spec.srcMsg}}, wasmvmtypes.MessageInfo{}, types.CodeInfo{})
// then
require.True(t, spec.expErr.Is(gotErr), "exp %v but got %#+v", spec.expErr, gotErr)
if spec.expErr != nil {
Expand Down
18 changes: 16 additions & 2 deletions x/wasm/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/armon/go-metrics"
"github.com/cosmos/cosmos-sdk/baseapp"
"github.com/cosmos/cosmos-sdk/types/address"

wasmvm "github.com/CosmWasm/wasmvm"
Expand Down Expand Up @@ -89,6 +90,18 @@ type Keeper struct {
maxQueryStackSize uint32
}

type routerWithContext struct {
router *baseapp.MsgServiceRouter
}

func (rc routerWithContext) Handler(msg sdk.Msg) MsgHandler {
h := rc.router.Handler(msg)
return func(ctx sdk.Context, req sdk.Msg) (sdk.Context, *sdk.Result, error) {
result, err := h(ctx, msg)
return ctx, result, err
}
}

// NewKeeper creates a new contract Keeper instance
// If customEncoders is non-nil, we can use this to override some of the message handler, especially custom
func NewKeeper(
Expand All @@ -104,7 +117,7 @@ func NewKeeper(
portKeeper types.PortKeeper,
capabilityKeeper types.CapabilityKeeper,
portSource types.ICS20TransferPortSource,
router MessageRouter,
router *baseapp.MsgServiceRouter,
queryRouter GRPCQueryRouter,
homeDir string,
wasmConfig types.WasmConfig,
Expand All @@ -120,6 +133,7 @@ func NewKeeper(
paramSpace = paramSpace.WithKeyTable(types.ParamKeyTable())
}

routerWithCtx := routerWithContext{router}
keeper := &Keeper{
storeKey: storeKey,
cdc: cdc,
Expand All @@ -129,7 +143,7 @@ func NewKeeper(
bank: NewBankCoinTransferrer(bankKeeper),
portKeeper: portKeeper,
capabilityKeeper: capabilityKeeper,
messenger: NewDefaultMessageHandler(router, channelKeeper, capabilityKeeper, bankKeeper, cdc, portSource),
messenger: NewDefaultMessageHandler(routerWithCtx, channelKeeper, capabilityKeeper, bankKeeper, cdc, portSource),
queryGasLimit: wasmConfig.SmartQueryGasLimit,
paramSpace: paramSpace,
gasRegister: NewDefaultWasmGasRegister(),
Expand Down
Loading
Loading