Skip to content

Commit

Permalink
feat(ws): white/black list pubkey.
Browse files Browse the repository at this point in the history
  • Loading branch information
kehiy committed Jan 10, 2025
1 parent 165c5d0 commit 7e098d3
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 97 deletions.
4 changes: 2 additions & 2 deletions delivery/websocket/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ type Limitation struct {
MaxSubidLength int32
MinPowDifficulty int32
AuthRequired bool
PaymentRequired bool // todo.
RestrictedWrites bool // todo.
PaymentRequired bool
RestrictedWrites bool
MaxEventTags int32
MaxContentLength int32
CreatedAtLowerLimit int64
Expand Down
188 changes: 93 additions & 95 deletions delivery/websocket/event_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ import (
"strconv"
"time"

"github.com/dezh-tech/immortal/infrastructure/redis"
"github.com/dezh-tech/immortal/pkg/logger"
"github.com/dezh-tech/immortal/pkg/utils"
"github.com/dezh-tech/immortal/types/message"
"github.com/gorilla/websocket"
gredis "github.com/redis/go-redis/v9"
)

// handleEvent handles new incoming EVENT messages from client.
// todo::: too much complexity.
func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint
func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) {
s.mu.Lock()
defer s.mu.Unlock()
defer measureLatency(s.metrics.EventLatency)()
Expand Down Expand Up @@ -59,6 +60,14 @@ func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint

bloomCheckCmd := pipe.BFExists(qCtx, s.redis.BloomFilterName, eID[:])

var whiteListCheckCmd *gredis.BoolCmd = nil

Check failure on line 63 in delivery/websocket/event_handler.go

View workflow job for this annotation

GitHub Actions / lint

var-declaration: should drop = nil from declaration of var whiteListCheckCmd; it is the zero value (revive)

if s.config.Limitation.RestrictedWrites {
whiteListCheckCmd = pipe.CFExists(qCtx, s.redis.WhiteListFilterName, msg.Event.PublicKey)
}

blackListCheckCmd := pipe.CFExists(qCtx, s.redis.BlackListFilterName, msg.Event.PublicKey)

_, err := pipe.Exec(qCtx)
if err != nil {
logger.Error("checking bloom filter", "err", err.Error())
Expand All @@ -81,147 +90,84 @@ func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint
return
}

client, ok := s.conns[conn]
if !ok {
_ = conn.WriteMessage(1, message.MakeOK(false,
msg.Event.ID,
fmt.Sprintf("error: can't find connection %s",
conn.RemoteAddr())))

status = serverFail

return
}

if s.config.Limitation.AuthRequired && !*client.isKnown {
client.challenge = utils.GenerateChallenge(10)
authm := message.MakeAuth(client.challenge)

okm := message.MakeOK(false,
msg.Event.ID,
"auth-required: we only accept events from authenticated users.",
)

isBlackListed, err := blackListCheckCmd.Result()
if err != nil {
okm := message.MakeOK(false, msg.Event.ID, "error: internal error")
_ = conn.WriteMessage(1, okm)

_ = conn.WriteMessage(1, authm)
status = authFail
status = serverFail

return
}

if msg.Event.IsProtected() && msg.Event.PublicKey != *client.pubkey {
client.challenge = utils.GenerateChallenge(10)
authm := message.MakeAuth(client.challenge)

okm := message.MakeOK(false,
msg.Event.ID,
"auth-required: this event may only be published by its author.",
)

_ = conn.WriteMessage(1, authm)

if isBlackListed {
okm := message.MakeOK(false, msg.Event.ID, "blocked: pubkey is blocked, contact support for more details.")
_ = conn.WriteMessage(1, okm)

status = authFail
status = limitsFail

return
}

expirationTag := msg.Event.Tags.GetValue("expiration")

if expirationTag != "" {
expiration, err := strconv.ParseInt(expirationTag, 10, 64)
if s.config.Limitation.RestrictedWrites {
isWhiteListed, err := whiteListCheckCmd.Result()
if err != nil {
okm := message.MakeOK(false,
msg.Event.ID,
fmt.Sprintf("invalid: expiration tag %s.", expirationTag),
)

_ = conn.WriteMessage(1, okm)

status = invalidFail

return
}

if time.Now().Unix() >= expiration {
okm := message.MakeOK(false,
msg.Event.ID,
fmt.Sprintf("invalid: this event was expired in %s.", time.Unix(expiration, 0).String()),
)

okm := message.MakeOK(false, msg.Event.ID, "error: internal error")
_ = conn.WriteMessage(1, okm)

status = invalidFail
status = serverFail

return
}

if err := s.redis.AddDelayedTask(expirationTaskListName,
fmt.Sprintf("%s:%d", msg.Event.ID, msg.Event.Kind), time.Until(time.Unix(expiration, 0))); err != nil {
okm := message.MakeOK(false,
msg.Event.ID, "error: can't add event to expiration queue.",
)

if !isWhiteListed {
okm := message.MakeOK(false, msg.Event.ID, "restricted: not allowed to write.")
_ = conn.WriteMessage(1, okm)

status = invalidFail
status = limitsFail

return
}
}

if len(msg.Event.Content) > int(s.config.Limitation.MaxContentLength) {
okm := message.MakeOK(false,
client, ok := s.conns[conn]
if !ok {
_ = conn.WriteMessage(1, message.MakeOK(false,
msg.Event.ID,
fmt.Sprintf("error: max limit of content length is %d", s.config.Limitation.MaxContentLength),
)

_ = conn.WriteMessage(1, okm)
fmt.Sprintf("error: can't find connection %s",
conn.RemoteAddr())))

status = limitsFail
status = serverFail

return
}

if msg.Event.Difficulty() < int(s.config.Limitation.MinPowDifficulty) {
okm := message.MakeOK(false,
msg.Event.ID,
fmt.Sprintf("error: min pow required is %d", s.config.Limitation.MinPowDifficulty),
)

_ = conn.WriteMessage(1, okm)

status = limitsFail

return
}
accepted, authFail, failType, resp := checkLimitations(client, s.redis, *s.config.Limitation, *msg)
if !accepted && authFail {
client.challenge = utils.GenerateChallenge(10)
authm := message.MakeAuth(client.challenge)

if len(msg.Event.Tags) > int(s.config.Limitation.MaxEventTags) {
okm := message.MakeOK(false,
msg.Event.ID,
fmt.Sprintf("error: max limit of tags count is %d", s.config.Limitation.MaxEventTags),
resp,
)

_ = conn.WriteMessage(1, okm)

status = limitsFail
_ = conn.WriteMessage(1, authm)
status = failType

return
}

if msg.Event.CreatedAt < s.config.Limitation.CreatedAtLowerLimit ||
msg.Event.CreatedAt > s.config.Limitation.CreatedAtUpperLimit {
if !accepted {
okm := message.MakeOK(false,
msg.Event.ID,
fmt.Sprintf("error: created at must be as least %d and at most %d",
s.config.Limitation.CreatedAtLowerLimit, s.config.Limitation.CreatedAtUpperLimit),
resp,
)

_ = conn.WriteMessage(1, okm)

status = limitsFail
status = failType

return
}
Expand Down Expand Up @@ -261,3 +207,55 @@ func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint
client.Unlock()
}
}

func checkLimitations(c clientState, redis *redis.Redis,

Check failure on line 211 in delivery/websocket/event_handler.go

View workflow job for this annotation

GitHub Actions / lint

importShadow: shadow of imported from 'github.com/dezh-tech/immortal/infrastructure/redis' package 'redis' (gocritic)
limits Limitation, msg message.Event) (accepted bool, isAuthFail bool,
failType string, resp string,
) {
if limits.AuthRequired && !*c.isKnown {
return false, true, authFail, "auth-required: we only accept events from authenticated users."
}

if msg.Event.IsProtected() && msg.Event.PublicKey != *c.pubkey {
return false, true, authFail, "auth-required: this event may only be published by its author."
}

expirationTag := msg.Event.Tags.GetValue("expiration")

if expirationTag != "" {
expiration, err := strconv.ParseInt(expirationTag, 10, 64)
if err != nil {
return false, false, serverFail, fmt.Sprintf("invalid: expiration tag %s.", expirationTag)
}

if time.Now().Unix() >= expiration {
return false, false, invalidFail, fmt.Sprintf("invalid: this event was expired in %s.",
time.Unix(expiration, 0).String())
}

if err := redis.AddDelayedTask(expirationTaskListName,
fmt.Sprintf("%s:%d", msg.Event.ID, msg.Event.Kind), time.Until(time.Unix(expiration, 0))); err != nil {
return false, false, serverFail, "error: can't add event to expiration queue."
}
}

if len(msg.Event.Content) > int(limits.MaxContentLength) {
return false, false, limitsFail, fmt.Sprintf("error: max limit of content length is %d", limits.MaxContentLength)
}

if msg.Event.Difficulty() < int(limits.MinPowDifficulty) {
return false, false, limitsFail, fmt.Sprintf("error: min pow required is %d", limits.MinPowDifficulty)
}

if len(msg.Event.Tags) > int(limits.MaxEventTags) {
return false, false, limitsFail, fmt.Sprintf("error: max limit of tags count is %d", limits.MaxEventTags)
}

if msg.Event.CreatedAt < limits.CreatedAtLowerLimit ||
msg.Event.CreatedAt > limits.CreatedAtUpperLimit {
return false, false, limitsFail, fmt.Sprintf("error: created at must be as least %d and at most %d",
limits.CreatedAtLowerLimit, limits.CreatedAtUpperLimit)
}

return true, false, "", ""
}

0 comments on commit 7e098d3

Please sign in to comment.