From 887ea62f9edeec7e51e71eeac3f967761c31b576 Mon Sep 17 00:00:00 2001 From: Zig Blathazar <42387185+ZigBalthazar@users.noreply.github.com> Date: Thu, 2 Jan 2025 01:09:23 +0330 Subject: [PATCH] feat(websocket): add white/black list filter. (#99) --- config/config.yml | 12 +++++-- relay/redis/redis.go | 23 +++++++----- relay/redis/redis_config.go | 10 +++--- server/websocket/event_handler.go | 60 +++++++++++++++++++++++++++++-- 4 files changed, 87 insertions(+), 18 deletions(-) diff --git a/config/config.yml b/config/config.yml index 1322b54..ffab3a2 100644 --- a/config/config.yml +++ b/config/config.yml @@ -66,6 +66,14 @@ redis: # default is 5000. connection_timeout_in_ms: 5000 - # bloom_name specifies the name of bloom filter key + # bloom_filter_name specifies the name of bloom filter key # default is IMMO_BLOOM. - bloom_name: IMMO_BLOOM + bloom_filter_name: IMMO_BLOOM + + # black_list_filter_name specifies the name of blacklist cuckoo filter key + # default is IMMO_BLACK_LIST. + black_list_filter_name: IMMO_BLACK_LIST + + # white_list_filter_name specifies the name of whitelist cuckoo filter key + # default is IMMO_WHITE_LIST. + white_list_filter_name: IMMO_WHITE_LIST diff --git a/relay/redis/redis.go b/relay/redis/redis.go index 6532cd2..3045a5a 100644 --- a/relay/redis/redis.go +++ b/relay/redis/redis.go @@ -9,9 +9,12 @@ import ( ) type Redis struct { - Client *redis.Client - BloomName string - QueryTimeout time.Duration + Client *redis.Client + BloomFilterName string + WhiteListFilterName string + BlackListFilterName string + Name string + QueryTimeout time.Duration } func New(cfg Config) (*Redis, error) { @@ -32,14 +35,16 @@ func New(cfg Config) (*Redis, error) { } return &Redis{ - Client: rc, - BloomName: cfg.BloomName, - QueryTimeout: time.Duration(cfg.QueryTimeout) * time.Millisecond, + Client: rc, + BloomFilterName: cfg.BloomFilterName, + WhiteListFilterName: cfg.WhiteListFilterName, + BlackListFilterName: cfg.BlackListFilterName, + QueryTimeout: time.Duration(cfg.QueryTimeout) * time.Millisecond, }, nil } // ! note: delayed tasks probably are not concurrent safe at the moment. -func (r Redis) AddDelayedTask(listName string, +func (r *Redis) AddDelayedTask(listName string, data string, delay time.Duration, ) error { taskReadyInSeconds := time.Now().Add(delay).Unix() @@ -59,7 +64,7 @@ func (r Redis) AddDelayedTask(listName string, return nil } -func (r Redis) GetReadyTasks(listName string) ([]string, error) { +func (r *Redis) GetReadyTasks(listName string) ([]string, error) { maxTime := time.Now().Unix() opt := &redis.ZRangeBy{ @@ -84,7 +89,7 @@ func (r Redis) GetReadyTasks(listName string) ([]string, error) { return resultSet, nil } -func (r Redis) RemoveTasks(listName string, tasks []string) error { +func (r *Redis) RemoveTasks(listName string, tasks []string) error { if len(tasks) == 0 { return nil } diff --git a/relay/redis/redis_config.go b/relay/redis/redis_config.go index 4773046..ac2fce8 100644 --- a/relay/redis/redis_config.go +++ b/relay/redis/redis_config.go @@ -1,8 +1,10 @@ package redis type Config struct { - URI string - BloomName string `yaml:"bloom_name"` - ConnectionTimeout int16 `yaml:"connection_timeout_in_ms"` - QueryTimeout int16 `yaml:"query_timeout_in_ms"` + URI string + BloomFilterName string `yaml:"bloom_filter_name"` + BlackListFilterName string `yaml:"black_list_filter_name"` + WhiteListFilterName string `yaml:"white_list_filter_name"` + ConnectionTimeout int16 `yaml:"connection_timeout_in_ms"` + QueryTimeout int16 `yaml:"query_timeout_in_ms"` } diff --git a/server/websocket/event_handler.go b/server/websocket/event_handler.go index 9f9671c..b5efc71 100644 --- a/server/websocket/event_handler.go +++ b/server/websocket/event_handler.go @@ -38,15 +38,33 @@ func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint } eID := msg.Event.GetRawID() + pubkey := msg.Event.PublicKey qCtx, cancel := context.WithTimeout(context.Background(), s.redis.QueryTimeout) defer cancel() - exists, err := s.redis.Client.BFExists(qCtx, s.redis.BloomName, eID[:]).Result() + pipe := s.redis.Client.Pipeline() + + bloomCheckCmd := pipe.BFExists(qCtx, s.redis.BloomFilterName, eID[:]) + + // TODO::: check config to enable filter checks + whiteListCheckCmd := pipe.CFExists(qCtx, s.redis.WhiteListFilterName, pubkey) + blackListCheckCmd := pipe.CFExists(qCtx, s.redis.BlackListFilterName, pubkey) + + _, err := pipe.Exec(qCtx) if err != nil { - log.Printf("error: checking bloom filter: %s", err.Error()) + log.Printf("error: checking filters: %s", err.Error()) } + exists, err := bloomCheckCmd.Result() + if err != nil { + okm := message.MakeOK(false, msg.Event.ID, "error: internal error") + _ = conn.WriteMessage(1, okm) + + status = serverFail + + return + } if exists { okm := message.MakeOK(true, msg.Event.ID, "") _ = conn.WriteMessage(1, okm) @@ -54,6 +72,42 @@ func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint return } + notAllowedToWrite, err := blackListCheckCmd.Result() + if err != nil { + okm := message.MakeOK(false, msg.Event.ID, "error: internal error") + _ = conn.WriteMessage(1, okm) + + status = serverFail + + return + } + if notAllowedToWrite { + okm := message.MakeOK(false, msg.Event.ID, "blocked: pubkey is blocked, contact support for more details.") + _ = conn.WriteMessage(1, okm) + + status = limitsFail + + return + } + + allowedToWrite, err := whiteListCheckCmd.Result() + if err != nil { + okm := message.MakeOK(false, msg.Event.ID, "error: internal error") + _ = conn.WriteMessage(1, okm) + + status = serverFail + + return + } + if !allowedToWrite { + okm := message.MakeOK(false, msg.Event.ID, "restricted: not allowed to write.") + _ = conn.WriteMessage(1, okm) + + status = limitsFail + + return + } + client, ok := s.conns[conn] if !ok { _ = conn.WriteMessage(1, message.MakeOK(false, @@ -229,7 +283,7 @@ func (s *Server) handleEvent(conn *websocket.Conn, m message.Message) { //nolint _ = conn.WriteMessage(1, message.MakeOK(true, msg.Event.ID, "")) } - _, err = s.redis.Client.BFAdd(qCtx, s.redis.BloomName, eID[:]).Result() + _, err = s.redis.Client.BFAdd(qCtx, s.redis.BloomFilterName, eID[:]).Result() if err != nil { log.Printf("error: adding event to bloom filter.") }