Skip to content

Commit

Permalink
Extend ext.Context to store bot information (#198)
Browse files Browse the repository at this point in the history
* Add the bot's userinfo to the context struct to ensure we have all the necessary information to determine update ownership at runtime

* Use botID instead of full bot info

* Improve overall data
  • Loading branch information
PaulSonOfLars authored Nov 3, 2024
1 parent 6e09c7f commit 616c537
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 34 deletions.
27 changes: 25 additions & 2 deletions bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@ import (
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
)

//go:generate go run ./scripts/generate

var (
ErrNilBotClient = errors.New("nil BotClient")
ErrInvalidTokenFormat = errors.New("invalid token format")
)

// Bot is the default Bot struct used to send and receive messages to the telegram API.
type Bot struct {
// Token stores the bot's secret token obtained from t.me/BotFather, and used to interact with telegram's API.
Expand Down Expand Up @@ -76,6 +83,24 @@ func NewBot(token string, opts *BotOpts) (*Bot, error) {
return nil, fmt.Errorf("failed to check bot token: %w", err)
}
b.User = *botUser
} else {
// If token checks are disabled, we populate the bot's ID from the token.
split := strings.Split(token, ":")
if len(split) != 2 {
return nil, fmt.Errorf("%w: expected '123:abcd', got %s", ErrInvalidTokenFormat, token)
}

id, err := strconv.ParseInt(split[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse bot ID from token: %w", err)
}
b.User = User{
Id: id,
IsBot: true,
// We mark these fields as missing so we can know why they're not available
FirstName: "<missing>",
Username: "<missing>",
}
}

return &b, nil
Expand All @@ -89,8 +114,6 @@ func (bot *Bot) UseMiddleware(mw func(client BotClient) BotClient) *Bot {
return bot
}

var ErrNilBotClient = errors.New("nil BotClient")

func (bot *Bot) Request(method string, params map[string]string, data map[string]FileReader, opts *RequestOpts) (json.RawMessage, error) {
return bot.RequestWithContext(context.Background(), method, params, data, opts)
}
Expand Down
8 changes: 6 additions & 2 deletions ext/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (
type Context struct {
// gotgbot.Update is inlined so that we can access all fields immediately if necessary.
*gotgbot.Update
// Bot represents gotgbot.User behind the Bot that received this update, so we can keep track of update ownership.
// Note: this information may be incomplete in the case where token validation is disabled.
Bot gotgbot.User
// Data represents update-local storage.
// This can be used to pass data across handlers - for example, to cache operations relevant to the current update,
// such as admin checks.
Expand All @@ -35,9 +38,9 @@ type Context struct {
EffectiveSender *gotgbot.Sender
}

// NewContext populates a context with the relevant fields from the current update.
// NewContext populates a context with the relevant fields from the current bot and update.
// It takes a data field in the case where custom data needs to be passed.
func NewContext(update *gotgbot.Update, data map[string]interface{}) *Context {
func NewContext(b *gotgbot.Bot, update *gotgbot.Update, data map[string]interface{}) *Context {
var msg *gotgbot.Message
var chat *gotgbot.Chat
var user *gotgbot.User
Expand Down Expand Up @@ -162,6 +165,7 @@ func NewContext(update *gotgbot.Update, data map[string]interface{}) *Context {

return &Context{
Update: update,
Bot: b.User,
Data: data,
EffectiveMessage: msg,
EffectiveChat: chat,
Expand Down
2 changes: 1 addition & 1 deletion ext/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func (d *Dispatcher) processRawUpdate(b *gotgbot.Bot, r json.RawMessage) error {
// ProcessUpdate iterates over the list of groups to execute the matching handlers.
// This is also where we recover from any panics that are thrown by user code, to avoid taking down the bot.
func (d *Dispatcher) ProcessUpdate(b *gotgbot.Bot, u *gotgbot.Update, data map[string]interface{}) (err error) {
ctx := NewContext(u, data)
ctx := NewContext(b, u, data)

defer func() {
if r := recover(); r != nil {
Expand Down
2 changes: 1 addition & 1 deletion ext/dispatcher_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestDispatcher(t *testing.T) {
}

t.Log("Processing one update...")
err := d.ProcessUpdate(nil, &gotgbot.Update{
err := d.ProcessUpdate(&gotgbot.Bot{}, &gotgbot.Update{
Message: &gotgbot.Message{Text: "test text"},
}, nil)
if err != nil {
Expand Down
14 changes: 7 additions & 7 deletions ext/handlers/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func NewTestBot() *gotgbot.Bot {
return &gotgbot.Bot{
Token: "use-me",
User: gotgbot.User{
Id: 0,
Id: rand.Int63(),
IsBot: false,
FirstName: "gobot",
LastName: "",
Expand All @@ -33,13 +33,13 @@ func NewTestBot() *gotgbot.Bot {
}
}

func NewMessage(userId int64, chatId int64, message string) *ext.Context {
return newMessage(userId, chatId, message, nil)
func NewMessage(b *gotgbot.Bot, userId int64, chatId int64, message string) *ext.Context {
return newMessage(b, userId, chatId, message, nil)
}

func NewCommandMessage(userId int64, chatId int64, command string, args []string) *ext.Context {
func NewCommandMessage(b *gotgbot.Bot, userId int64, chatId int64, command string, args []string) *ext.Context {
msg, ents := buildCommand(command, args)
return newMessage(userId, chatId, msg, ents)
return newMessage(b, userId, chatId, msg, ents)
}

func buildCommand(cmd string, args []string) (string, []gotgbot.MessageEntity) {
Expand All @@ -53,13 +53,13 @@ func buildCommand(cmd string, args []string) (string, []gotgbot.MessageEntity) {
}
}

func newMessage(userId int64, chatId int64, message string, entities []gotgbot.MessageEntity) *ext.Context {
func newMessage(b *gotgbot.Bot, userId int64, chatId int64, message string, entities []gotgbot.MessageEntity) *ext.Context {
chatType := "supergroup"
if userId == chatId {
chatType = "private"
}

return ext.NewContext(&gotgbot.Update{
return ext.NewContext(b, &gotgbot.Update{
UpdateId: rand.Int63(), // should this be consistent?
Message: &gotgbot.Message{
MessageId: rand.Int63(), // should this be consistent?
Expand Down
7 changes: 3 additions & 4 deletions ext/handlers/conversation/key_strategies.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package conversation
import (
"errors"
"fmt"
"strconv"

"github.com/PaulSonOfLars/gotgbot/v2/ext"
)
Expand All @@ -27,23 +26,23 @@ func KeyStrategySenderAndChat(ctx *ext.Context) (string, error) {
if ctx.EffectiveSender == nil || ctx.EffectiveChat == nil {
return "", fmt.Errorf("missing sender or chat fields: %w", ErrEmptyKey)
}
return fmt.Sprintf("%d/%d", ctx.EffectiveSender.Id(), ctx.EffectiveChat.Id), nil
return fmt.Sprintf("%d/%d/%d", ctx.Bot.Id, ctx.EffectiveSender.Id(), ctx.EffectiveChat.Id), nil
}

// KeyStrategySender gives a unique conversation to each sender, and that single conversation is available in all chats.
func KeyStrategySender(ctx *ext.Context) (string, error) {
if ctx.EffectiveSender == nil {
return "", fmt.Errorf("missing sender field: %w", ErrEmptyKey)
}
return strconv.FormatInt(ctx.EffectiveSender.Id(), 10), nil
return fmt.Sprintf("%d/%d", ctx.Bot.Id, ctx.EffectiveSender.Id()), nil
}

// KeyStrategyChat gives a unique conversation to each chat, which all senders can interact in together.
func KeyStrategyChat(ctx *ext.Context) (string, error) {
if ctx.EffectiveChat == nil {
return "", fmt.Errorf("missing chat field: %w", ErrEmptyKey)
}
return strconv.FormatInt(ctx.EffectiveChat.Id, 10), nil
return fmt.Sprintf("%d/%d", ctx.Bot.Id, ctx.EffectiveChat.Id), nil
}

// StateKey provides a sane default for handling incoming updates.
Expand Down
45 changes: 28 additions & 17 deletions ext/handlers/conversation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ func TestBasicConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
startCommand := NewCommandMessage(userId, chatId, "start", []string{})
startCommand := NewCommandMessage(b, userId, chatId, "start", []string{})
runHandler(t, b, &conv, startCommand, "", nextStep)
if !started {
t.Fatalf("expected the entrypoint handler to have run")
}

// Emulate sending the "message" text, triggering the internal handler (and causing it to "end").
textMessage := NewMessage(userId, chatId, "message")
textMessage := NewMessage(b, userId, chatId, "message")
runHandler(t, b, &conv, textMessage, nextStep, "")
if !ended {
t.Fatalf("expected the internal handler to have run")
Expand Down Expand Up @@ -79,8 +79,8 @@ func TestBasicKeyedConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
startFromUserOne := NewCommandMessage(userIdOne, chatId, "start", []string{})
messageFromTwo := NewMessage(userIdTwo, chatId, "message")
startFromUserOne := NewCommandMessage(b, userIdOne, chatId, "start", []string{})
messageFromTwo := NewMessage(b, userIdTwo, chatId, "message")

runHandler(t, b, &conv, startFromUserOne, "", nextStep)

Expand All @@ -89,6 +89,11 @@ func TestBasicKeyedConversation(t *testing.T) {

// But user two doesnt exist
checkExpectedState(t, &conv, messageFromTwo, "")

b2 := NewTestBot()
messageTo2 := NewMessage(b2, userIdOne, chatId, "message")
// And bot two hasn't changed either
checkExpectedState(t, &conv, messageTo2, "")
}

func TestBasicConversationExit(t *testing.T) {
Expand Down Expand Up @@ -121,14 +126,14 @@ func TestBasicConversationExit(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint, and starting the conversation.
startCommand := NewCommandMessage(userId, chatId, "start", []string{})
startCommand := NewCommandMessage(b, userId, chatId, "start", []string{})
runHandler(t, b, &conv, startCommand, "", nextStep)
if !started {
t.Fatalf("expected the entrypoint handler to have run")
}

// Emulate sending the "cancel" command, triggering the exitpoint, and immediately ending the conversation.
cancelCommand := NewCommandMessage(userId, chatId, "cancel", []string{})
cancelCommand := NewCommandMessage(b, userId, chatId, "cancel", []string{})
runHandler(t, b, &conv, cancelCommand, nextStep, "")
if !ended {
t.Fatalf("expected the cancel command to have run")
Expand All @@ -138,7 +143,7 @@ func TestBasicConversationExit(t *testing.T) {
checkExpectedState(t, &conv, cancelCommand, "")

// Emulate sending the "message" text, which now should not interact with the conversation.
textMessage := NewMessage(userId, chatId, "message")
textMessage := NewMessage(b, userId, chatId, "message")
if conv.CheckUpdate(b, textMessage) {
t.Fatalf("did not expect the internal handler to run")
}
Expand Down Expand Up @@ -177,14 +182,14 @@ func TestFallbackConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
startCommand := NewCommandMessage(userId, chatId, "start", []string{})
startCommand := NewCommandMessage(b, userId, chatId, "start", []string{})
runHandler(t, b, &conv, startCommand, "", nextStep)
if !started {
t.Fatalf("expected the entrypoint handler to have run")
}

// Emulate sending the "cancel" command, triggering the fallback handler (and causing it to "end").
cancelCommand := NewCommandMessage(userId, chatId, "cancel", []string{})
cancelCommand := NewCommandMessage(b, userId, chatId, "cancel", []string{})
runHandler(t, b, &conv, cancelCommand, nextStep, "")
if !fallback {
t.Fatalf("expected the fallback handler to have run")
Expand Down Expand Up @@ -220,14 +225,14 @@ func TestReEntryConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
startCommand := NewCommandMessage(userId, chatId, "start", []string{})
startCommand := NewCommandMessage(b, userId, chatId, "start", []string{})
runHandler(t, b, &conv, startCommand, "", nextStep)
if startCount != 1 {
t.Fatalf("expected the entrypoint handler to have run")
}

// Send a message which matches both the entrypoint, and the "nextStep" state.
cancelCommand := NewCommandMessage(userId, chatId, "start", []string{"message"})
cancelCommand := NewCommandMessage(b, userId, chatId, "start", []string{"message"})
runHandler(t, b, &conv, cancelCommand, nextStep, nextStep) // Should hit
if startCount != 2 {
t.Fatalf("expected the entrypoint handler to have run a second time")
Expand Down Expand Up @@ -285,20 +290,20 @@ func TestNestedConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
start := NewCommandMessage(userId, chatId, startCmd, []string{})
start := NewCommandMessage(b, userId, chatId, startCmd, []string{})
runHandler(t, b, &conv, start, "", firstStep)

// Emulate sending the "message" text, triggering the internal handler (and causing it to "end").
textMessage := NewMessage(userId, chatId, messageText)
textMessage := NewMessage(b, userId, chatId, messageText)
runHandler(t, b, &conv, textMessage, firstStep, secondStep)

// Emulate sending the "nested_start" command, triggering the entrypoint of the nested conversation.
nestedStart := NewCommandMessage(userId, chatId, nestedStartCmd, []string{})
nestedStart := NewCommandMessage(b, userId, chatId, nestedStartCmd, []string{})
willRunHandler(t, b, &nestedConv, nestedStart, "")
runHandler(t, b, &conv, nestedStart, secondStep, secondStep)

// Emulate sending the "nested_start" command, triggering the entrypoint of the nested conversation.
nestedFinish := NewMessage(userId, chatId, finishNestedText)
nestedFinish := NewMessage(b, userId, chatId, finishNestedText)
willRunHandler(t, b, &nestedConv, nestedFinish, nestedStep)
runHandler(t, b, &conv, nestedFinish, secondStep, thirdStep)

Expand All @@ -307,7 +312,7 @@ func TestNestedConversation(t *testing.T) {
t.Log("Nested conversation finished")

// Emulate sending the "message" text, triggering the internal handler (and causing it to "end").
finish := NewMessage(userId, chatId, finishText)
finish := NewMessage(b, userId, chatId, finishText)
runHandler(t, b, &conv, finish, thirdStep, "")

checkExpectedState(t, &conv, textMessage, "")
Expand All @@ -329,7 +334,7 @@ func TestEmptyKeyConversation(t *testing.T) {
)

// Run an empty
pollUpd := ext.NewContext(&gotgbot.Update{
pollUpd := ext.NewContext(b, &gotgbot.Update{
UpdateId: rand.Int63(), // should this be consistent?
Poll: &gotgbot.Poll{
Id: "some_id",
Expand Down Expand Up @@ -358,6 +363,8 @@ func TestEmptyKeyConversation(t *testing.T) {

// runHandler ensures that the incoming update will trigger the conversation.
func runHandler(t *testing.T, b *gotgbot.Bot, conv *handlers.Conversation, message *ext.Context, currentState string, nextState string) {
t.Helper()

willRunHandler(t, b, conv, message, currentState)
if err := conv.HandleUpdate(b, message); err != nil {
t.Fatalf("unexpected error from handler: %s", err.Error())
Expand All @@ -368,6 +375,8 @@ func runHandler(t *testing.T, b *gotgbot.Bot, conv *handlers.Conversation, messa

// willRunHandler ensures that the incoming update will trigger the conversation.
func willRunHandler(t *testing.T, b *gotgbot.Bot, conv *handlers.Conversation, message *ext.Context, expectedState string) {
t.Helper()

t.Logf("conv %p: checking message for %d in %d with text: %s", conv, message.EffectiveSender.Id(), message.EffectiveChat.Id, message.Message.Text)

checkExpectedState(t, conv, message, expectedState)
Expand All @@ -378,6 +387,8 @@ func willRunHandler(t *testing.T, b *gotgbot.Bot, conv *handlers.Conversation, m
}

func checkExpectedState(t *testing.T, conv *handlers.Conversation, message *ext.Context, nextState string) {
t.Helper()

currentState, err := conv.StateStorage.Get(message)
if err != nil {
if nextState == "" && errors.Is(err, conversation.ErrKeyNotFound) {
Expand Down

0 comments on commit 616c537

Please sign in to comment.