Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ws): white/black list pubkey. #109

Merged
merged 3 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions delivery/websocket/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ type Limitation struct {
MaxSubidLength int32
MinPowDifficulty int32
AuthRequired bool
PaymentRequired bool // todo.
RestrictedWrites bool // todo.
PaymentRequired bool
RestrictedWrites bool
MaxEventTags int32
MaxContentLength int32
CreatedAtLowerLimit int64
Expand Down
188 changes: 93 additions & 95 deletions delivery/websocket/event_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ import (
"strconv"
"time"

"github.com/dezh-tech/immortal/infrastructure/redis"
"github.com/dezh-tech/immortal/pkg/logger"
"github.com/dezh-tech/immortal/pkg/utils"
"github.com/dezh-tech/immortal/types/message"
"github.com/gorilla/websocket"
gredis "github.com/redis/go-redis/v9"
)

// handleEvent handles new incoming EVENT messages from client.
// todo::: too much complexity.
func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint
func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) {
s.mu.Lock()
defer s.mu.Unlock()
defer measureLatency(s.metrics.EventLatency)()
Expand Down Expand Up @@ -59,6 +60,14 @@ func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint

bloomCheckCmd := pipe.BFExists(qCtx, s.redis.BloomFilterName, eID[:])

var whiteListCheckCmd *gredis.BoolCmd

if s.config.Limitation.RestrictedWrites {
whiteListCheckCmd = pipe.CFExists(qCtx, s.redis.WhiteListFilterName, msg.Event.PublicKey)
}

blackListCheckCmd := pipe.CFExists(qCtx, s.redis.BlackListFilterName, msg.Event.PublicKey)

_, err := pipe.Exec(qCtx)
if err != nil {
logger.Error("checking bloom filter", "err", err.Error())
Expand All @@ -81,147 +90,84 @@ func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint
return
}

client, ok := s.conns[conn]
if !ok {
_ = conn.WriteMessage(1, message.MakeOK(false,
msg.Event.ID,
fmt.Sprintf("error: can't find connection %s",
conn.RemoteAddr())))

status = serverFail

return
}

if s.config.Limitation.AuthRequired && !*client.isKnown {
client.challenge = utils.GenerateChallenge(10)
authm := message.MakeAuth(client.challenge)

okm := message.MakeOK(false,
msg.Event.ID,
"auth-required: we only accept events from authenticated users.",
)

isBlackListed, err := blackListCheckCmd.Result()
if err != nil {
okm := message.MakeOK(false, msg.Event.ID, "error: internal error")
_ = conn.WriteMessage(1, okm)

_ = conn.WriteMessage(1, authm)
status = authFail
status = serverFail

return
}

if msg.Event.IsProtected() && msg.Event.PublicKey != *client.pubkey {
client.challenge = utils.GenerateChallenge(10)
authm := message.MakeAuth(client.challenge)

okm := message.MakeOK(false,
msg.Event.ID,
"auth-required: this event may only be published by its author.",
)

_ = conn.WriteMessage(1, authm)

if isBlackListed {
okm := message.MakeOK(false, msg.Event.ID, "blocked: pubkey is blocked, contact support for more details.")
_ = conn.WriteMessage(1, okm)

status = authFail
status = limitsFail

return
}

expirationTag := msg.Event.Tags.GetValue("expiration")

if expirationTag != "" {
expiration, err := strconv.ParseInt(expirationTag, 10, 64)
if s.config.Limitation.RestrictedWrites {
isWhiteListed, err := whiteListCheckCmd.Result()
if err != nil {
okm := message.MakeOK(false,
msg.Event.ID,
fmt.Sprintf("invalid: expiration tag %s.", expirationTag),
)

_ = conn.WriteMessage(1, okm)

status = invalidFail

return
}

if time.Now().Unix() >= expiration {
okm := message.MakeOK(false,
msg.Event.ID,
fmt.Sprintf("invalid: this event was expired in %s.", time.Unix(expiration, 0).String()),
)

okm := message.MakeOK(false, msg.Event.ID, "error: internal error")
_ = conn.WriteMessage(1, okm)

status = invalidFail
status = serverFail

return
}

if err := s.redis.AddDelayedTask(expirationTaskListName,
fmt.Sprintf("%s:%d", msg.Event.ID, msg.Event.Kind), time.Until(time.Unix(expiration, 0))); err != nil {
okm := message.MakeOK(false,
msg.Event.ID, "error: can't add event to expiration queue.",
)

if !isWhiteListed {
okm := message.MakeOK(false, msg.Event.ID, "restricted: not allowed to write.")
_ = conn.WriteMessage(1, okm)

status = invalidFail
status = limitsFail

return
}
}

if len(msg.Event.Content) > int(s.config.Limitation.MaxContentLength) {
okm := message.MakeOK(false,
client, ok := s.conns[conn]
if !ok {
_ = conn.WriteMessage(1, message.MakeOK(false,
msg.Event.ID,
fmt.Sprintf("error: max limit of content length is %d", s.config.Limitation.MaxContentLength),
)

_ = conn.WriteMessage(1, okm)
fmt.Sprintf("error: can't find connection %s",
conn.RemoteAddr())))

status = limitsFail
status = serverFail

return
}

if msg.Event.Difficulty() < int(s.config.Limitation.MinPowDifficulty) {
okm := message.MakeOK(false,
msg.Event.ID,
fmt.Sprintf("error: min pow required is %d", s.config.Limitation.MinPowDifficulty),
)

_ = conn.WriteMessage(1, okm)

status = limitsFail

return
}
accepted, authFail, failType, resp := checkLimitations(client, s.redis, *s.config.Limitation, *msg)
if !accepted && authFail {
client.challenge = utils.GenerateChallenge(10)
authm := message.MakeAuth(client.challenge)

if len(msg.Event.Tags) > int(s.config.Limitation.MaxEventTags) {
okm := message.MakeOK(false,
msg.Event.ID,
fmt.Sprintf("error: max limit of tags count is %d", s.config.Limitation.MaxEventTags),
resp,
)

_ = conn.WriteMessage(1, okm)

status = limitsFail
_ = conn.WriteMessage(1, authm)
status = failType

return
}

if msg.Event.CreatedAt < s.config.Limitation.CreatedAtLowerLimit ||
msg.Event.CreatedAt > s.config.Limitation.CreatedAtUpperLimit {
if !accepted {
okm := message.MakeOK(false,
msg.Event.ID,
fmt.Sprintf("error: created at must be as least %d and at most %d",
s.config.Limitation.CreatedAtLowerLimit, s.config.Limitation.CreatedAtUpperLimit),
resp,
)

_ = conn.WriteMessage(1, okm)

status = limitsFail
status = failType

return
}
Expand Down Expand Up @@ -261,3 +207,55 @@ func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint
client.Unlock()
}
}

func checkLimitations(c clientState, r *redis.Redis,
limits Limitation, msg message.Event) (bool, bool,
string, string,
) {
if limits.AuthRequired && !*c.isKnown {
return false, true, authFail, "auth-required: we only accept events from authenticated users."
}

if msg.Event.IsProtected() && msg.Event.PublicKey != *c.pubkey {
return false, true, authFail, "auth-required: this event may only be published by its author."
}

expirationTag := msg.Event.Tags.GetValue("expiration")

if expirationTag != "" {
expiration, err := strconv.ParseInt(expirationTag, 10, 64)
if err != nil {
return false, false, serverFail, fmt.Sprintf("invalid: expiration tag %s.", expirationTag)
}

if time.Now().Unix() >= expiration {
return false, false, invalidFail, fmt.Sprintf("invalid: this event was expired in %s.",
time.Unix(expiration, 0).String())
}

if err := r.AddDelayedTask(expirationTaskListName,
fmt.Sprintf("%s:%d", msg.Event.ID, msg.Event.Kind), time.Until(time.Unix(expiration, 0))); err != nil {
return false, false, serverFail, "error: can't add event to expiration queue."
}
}

if len(msg.Event.Content) > int(limits.MaxContentLength) {
return false, false, limitsFail, fmt.Sprintf("error: max limit of content length is %d", limits.MaxContentLength)
}

if msg.Event.Difficulty() < int(limits.MinPowDifficulty) {
return false, false, limitsFail, fmt.Sprintf("error: min pow required is %d", limits.MinPowDifficulty)
}

if len(msg.Event.Tags) > int(limits.MaxEventTags) {
return false, false, limitsFail, fmt.Sprintf("error: max limit of tags count is %d", limits.MaxEventTags)
}

if msg.Event.CreatedAt < limits.CreatedAtLowerLimit ||
msg.Event.CreatedAt > limits.CreatedAtUpperLimit {
return false, false, limitsFail, fmt.Sprintf("error: created at must be as least %d and at most %d",
limits.CreatedAtLowerLimit, limits.CreatedAtUpperLimit)
}

return true, false, "", ""
}
Loading