diff --git a/delivery/websocket/config.go b/delivery/websocket/config.go index 8e84d87..650e0cc 100644 --- a/delivery/websocket/config.go +++ b/delivery/websocket/config.go @@ -9,8 +9,8 @@ type Limitation struct { MaxSubidLength int32 MinPowDifficulty int32 AuthRequired bool - PaymentRequired bool // todo. - RestrictedWrites bool // todo. + PaymentRequired bool + RestrictedWrites bool MaxEventTags int32 MaxContentLength int32 CreatedAtLowerLimit int64 diff --git a/delivery/websocket/event_handler.go b/delivery/websocket/event_handler.go index 84302ae..471d529 100644 --- a/delivery/websocket/event_handler.go +++ b/delivery/websocket/event_handler.go @@ -6,15 +6,16 @@ import ( "strconv" "time" + "github.com/dezh-tech/immortal/infrastructure/redis" "github.com/dezh-tech/immortal/pkg/logger" "github.com/dezh-tech/immortal/pkg/utils" "github.com/dezh-tech/immortal/types/message" "github.com/gorilla/websocket" + gredis "github.com/redis/go-redis/v9" ) // handleEvent handles new incoming EVENT messages from client. -// todo::: too much complexity. -func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint +func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { s.mu.Lock() defer s.mu.Unlock() defer measureLatency(s.metrics.EventLatency)() @@ -59,6 +60,14 @@ func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint bloomCheckCmd := pipe.BFExists(qCtx, s.redis.BloomFilterName, eID[:]) + var whiteListCheckCmd *gredis.BoolCmd + + if s.config.Limitation.RestrictedWrites { + whiteListCheckCmd = pipe.CFExists(qCtx, s.redis.WhiteListFilterName, msg.Event.PublicKey) + } + + blackListCheckCmd := pipe.CFExists(qCtx, s.redis.BlackListFilterName, msg.Event.PublicKey) + _, err := pipe.Exec(qCtx) if err != nil { logger.Error("checking bloom filter", "err", err.Error()) @@ -81,147 +90,84 @@ func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint return } - client, ok := s.conns[conn] - if !ok { - _ = conn.WriteMessage(1, message.MakeOK(false, - msg.Event.ID, - fmt.Sprintf("error: can't find connection %s", - conn.RemoteAddr()))) - - status = serverFail - - return - } - - if s.config.Limitation.AuthRequired && !*client.isKnown { - client.challenge = utils.GenerateChallenge(10) - authm := message.MakeAuth(client.challenge) - - okm := message.MakeOK(false, - msg.Event.ID, - "auth-required: we only accept events from authenticated users.", - ) - + isBlackListed, err := blackListCheckCmd.Result() + if err != nil { + okm := message.MakeOK(false, msg.Event.ID, "error: internal error") _ = conn.WriteMessage(1, okm) - _ = conn.WriteMessage(1, authm) - status = authFail + status = serverFail return } - - if msg.Event.IsProtected() && msg.Event.PublicKey != *client.pubkey { - client.challenge = utils.GenerateChallenge(10) - authm := message.MakeAuth(client.challenge) - - okm := message.MakeOK(false, - msg.Event.ID, - "auth-required: this event may only be published by its author.", - ) - - _ = conn.WriteMessage(1, authm) - + if isBlackListed { + okm := message.MakeOK(false, msg.Event.ID, "blocked: pubkey is blocked, contact support for more details.") _ = conn.WriteMessage(1, okm) - status = authFail + status = limitsFail return } - expirationTag := msg.Event.Tags.GetValue("expiration") - - if expirationTag != "" { - expiration, err := strconv.ParseInt(expirationTag, 10, 64) + if s.config.Limitation.RestrictedWrites { + isWhiteListed, err := whiteListCheckCmd.Result() if err != nil { - okm := message.MakeOK(false, - msg.Event.ID, - fmt.Sprintf("invalid: expiration tag %s.", expirationTag), - ) - - _ = conn.WriteMessage(1, okm) - - status = invalidFail - - return - } - - if time.Now().Unix() >= expiration { - okm := message.MakeOK(false, - msg.Event.ID, - fmt.Sprintf("invalid: this event was expired in %s.", time.Unix(expiration, 0).String()), - ) - + okm := message.MakeOK(false, msg.Event.ID, "error: internal error") _ = conn.WriteMessage(1, okm) - status = invalidFail + status = serverFail return } - if err := s.redis.AddDelayedTask(expirationTaskListName, - fmt.Sprintf("%s:%d", msg.Event.ID, msg.Event.Kind), time.Until(time.Unix(expiration, 0))); err != nil { - okm := message.MakeOK(false, - msg.Event.ID, "error: can't add event to expiration queue.", - ) - + if !isWhiteListed { + okm := message.MakeOK(false, msg.Event.ID, "restricted: not allowed to write.") _ = conn.WriteMessage(1, okm) - status = invalidFail + status = limitsFail return } } - if len(msg.Event.Content) > int(s.config.Limitation.MaxContentLength) { - okm := message.MakeOK(false, + client, ok := s.conns[conn] + if !ok { + _ = conn.WriteMessage(1, message.MakeOK(false, msg.Event.ID, - fmt.Sprintf("error: max limit of content length is %d", s.config.Limitation.MaxContentLength), - ) - - _ = conn.WriteMessage(1, okm) + fmt.Sprintf("error: can't find connection %s", + conn.RemoteAddr()))) - status = limitsFail + status = serverFail return } - if msg.Event.Difficulty() < int(s.config.Limitation.MinPowDifficulty) { - okm := message.MakeOK(false, - msg.Event.ID, - fmt.Sprintf("error: min pow required is %d", s.config.Limitation.MinPowDifficulty), - ) - - _ = conn.WriteMessage(1, okm) - - status = limitsFail - - return - } + accepted, authFail, failType, resp := checkLimitations(client, s.redis, *s.config.Limitation, *msg) + if !accepted && authFail { + client.challenge = utils.GenerateChallenge(10) + authm := message.MakeAuth(client.challenge) - if len(msg.Event.Tags) > int(s.config.Limitation.MaxEventTags) { okm := message.MakeOK(false, msg.Event.ID, - fmt.Sprintf("error: max limit of tags count is %d", s.config.Limitation.MaxEventTags), + resp, ) _ = conn.WriteMessage(1, okm) - status = limitsFail + _ = conn.WriteMessage(1, authm) + status = failType return } - if msg.Event.CreatedAt < s.config.Limitation.CreatedAtLowerLimit || - msg.Event.CreatedAt > s.config.Limitation.CreatedAtUpperLimit { + if !accepted { okm := message.MakeOK(false, msg.Event.ID, - fmt.Sprintf("error: created at must be as least %d and at most %d", - s.config.Limitation.CreatedAtLowerLimit, s.config.Limitation.CreatedAtUpperLimit), + resp, ) _ = conn.WriteMessage(1, okm) - status = limitsFail + status = failType return } @@ -261,3 +207,55 @@ func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint client.Unlock() } } + +func checkLimitations(c clientState, r *redis.Redis, + limits Limitation, msg message.Event) (bool, bool, + string, string, +) { + if limits.AuthRequired && !*c.isKnown { + return false, true, authFail, "auth-required: we only accept events from authenticated users." + } + + if msg.Event.IsProtected() && msg.Event.PublicKey != *c.pubkey { + return false, true, authFail, "auth-required: this event may only be published by its author." + } + + expirationTag := msg.Event.Tags.GetValue("expiration") + + if expirationTag != "" { + expiration, err := strconv.ParseInt(expirationTag, 10, 64) + if err != nil { + return false, false, serverFail, fmt.Sprintf("invalid: expiration tag %s.", expirationTag) + } + + if time.Now().Unix() >= expiration { + return false, false, invalidFail, fmt.Sprintf("invalid: this event was expired in %s.", + time.Unix(expiration, 0).String()) + } + + if err := r.AddDelayedTask(expirationTaskListName, + fmt.Sprintf("%s:%d", msg.Event.ID, msg.Event.Kind), time.Until(time.Unix(expiration, 0))); err != nil { + return false, false, serverFail, "error: can't add event to expiration queue." + } + } + + if len(msg.Event.Content) > int(limits.MaxContentLength) { + return false, false, limitsFail, fmt.Sprintf("error: max limit of content length is %d", limits.MaxContentLength) + } + + if msg.Event.Difficulty() < int(limits.MinPowDifficulty) { + return false, false, limitsFail, fmt.Sprintf("error: min pow required is %d", limits.MinPowDifficulty) + } + + if len(msg.Event.Tags) > int(limits.MaxEventTags) { + return false, false, limitsFail, fmt.Sprintf("error: max limit of tags count is %d", limits.MaxEventTags) + } + + if msg.Event.CreatedAt < limits.CreatedAtLowerLimit || + msg.Event.CreatedAt > limits.CreatedAtUpperLimit { + return false, false, limitsFail, fmt.Sprintf("error: created at must be as least %d and at most %d", + limits.CreatedAtLowerLimit, limits.CreatedAtUpperLimit) + } + + return true, false, "", "" +}