diff --git a/_examples/rated.go b/_examples/rated.go index 7d1cc3f..3eb03bd 100644 --- a/_examples/rated.go +++ b/_examples/rated.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "fmt" "net" + "os" "strings" "sync" "time" @@ -58,12 +59,35 @@ type Client struct { // UniqueKey is an implementation of our Identity interface, in short: Rate5 doesn't care where you derive the string used for ratelimiting func (c Client) UniqueKey() string { + var err error + var host string if c.loggedin { return c.ID } + if host, _, err = net.SplitHostPort(c.Conn.RemoteAddr().String()); err == nil { + return host + } + panic(err) +} + +func argParse() { + if len(os.Args) < 1 { + return + } + for i, arg := range os.Args { + switch arg { + case "-e": + fallthrough + case "--exempt": + if len(os.Args) <= i+1 { + return + } + srv.Exempt[os.Args[i+1]] = true + default: + continue + } - host, _, _ := net.SplitHostPort(c.Conn.RemoteAddr().String()) - return host + } } func init() { @@ -82,97 +106,70 @@ func init() { mu: &sync.RWMutex{}, } - //srv.Exempt["127.0.0.1"] = true + argParse() + rd := Rater.DebugChannel() rrd := RegRater.DebugChannel() crd := CmdRater.DebugChannel() + go watchDebug(rd, rrd, crd) +} +func watchDebug(rd, rrd, crd chan string) { pre := "[Rate5] " - go func() { - var lastcount = 0 - var count = 0 - for { - select { - case msg := <-rd: - fmt.Printf("%s Limit: %s \n", pre, msg) - count++ - case msg := <-rrd: - fmt.Printf("%s RegLimit: %s \n", pre, msg) - count++ - case msg := <-crd: - fmt.Printf("%s CmdLimit: %s \n", pre, msg) - count++ - default: - if count - lastcount >= 25 { - lastcount = count - fmt.Println("Rater: ", Rater.GetGrandTotalRated()) - fmt.Println("RegRater: ", RegRater.GetGrandTotalRated()) - fmt.Println("CmdRater: ", CmdRater.GetGrandTotalRated()) - } - time.Sleep(time.Duration(10) * time.Millisecond) + var lastcount = 0 + var count = 0 + for { + select { + case msg := <-rd: + fmt.Printf("%s Limit: %s \n", pre, msg) + count++ + case msg := <-rrd: + fmt.Printf("%s RegLimit: %s \n", pre, msg) + count++ + case msg := <-crd: + fmt.Printf("%s CmdLimit: %s \n", pre, msg) + count++ + default: + if count-lastcount >= 25 { + lastcount = count + fmt.Println("Rater: ", Rater.GetGrandTotalRated()) + fmt.Println("RegRater: ", RegRater.GetGrandTotalRated()) + fmt.Println("CmdRater: ", CmdRater.GetGrandTotalRated()) } + time.Sleep(time.Duration(10) * time.Millisecond) } - }() -} - -func (s *Server) handleTCP(c *Client) { - if err := c.Conn.(*net.TCPConn).SetLinger(0); err != nil { - fmt.Println("error while setting setlinger:", err.Error()) } +} - // skip ratelimit checking for exempt clients - srv.mu.RLock() - _, exempt := srv.Exempt[c.UniqueKey()] - srv.mu.RUnlock() - - defer func() { - c.Conn.Close() - println("closed: " + c.Conn.RemoteAddr().String()) - }() - - // Returns true if ratelimited - if Rater.Check(c) { - c.Conn.Write([]byte("too many connections")) - println(c.UniqueKey() + " ratelimited") +func (s *Server) preLogin(c *Client) { + c.send("Auth: ") + in := c.recv() + switch { + case s.authCheck(c, in): + c.loggedin = true + c.deadline = time.Duration(480) * time.Second + c.send("successful login") return - } - - c.read = bufio.NewReader(c.Conn) - - c.Conn.Write(loginBanner()) - - for { - if !c.connected { + case in == "register": + // no exemption for strict ratelimiter (rate5 testing) + if RegRater.Check(c) { + c.send("you already registered recently\n") return } + println("new registration from " + c.UniqueKey()) + s.setID(c, s.getUnusedID()) + c.send("\nregistration success\n[New ID]: " + c.ID) + return + default: + c.send("invalid. type 'REGISTER' to register a new ID\n") + return + } - time.Sleep(time.Duration(25) * time.Millisecond) - if !c.loggedin { - c.send("Auth: ") - in := c.recv() - switch { - case s.authCheck(c, in): - c.loggedin = true - c.deadline = time.Duration(480) * time.Second - c.send("successful login") - continue - case in == "register": - if !RegRater.Check(c) || exempt { - println("new registration from " + c.UniqueKey()) - s.setID(c, s.getUnusedID()) - c.send("\nregistration success\n[New ID]: " + c.ID) - return - } else { - c.send("you already registered recently\n") - } - continue - default: - c.send("invalid. type 'REGISTER' to register a new ID\n") - continue - } - } +} +func (s *Server) mainPrompt(c *Client) { + for c.connected { c.send("\nRate5 > ") switch c.recv() { case "history": @@ -190,10 +187,60 @@ func (s *Server) handleTCP(c *Client) { case "logout": c.loggedin = false return + default: + c.send("unknown command, are you lost?") + continue } } } +func isExempt(c *Client) bool { + srv.mu.RLock() + _, exempt := srv.Exempt[c.UniqueKey()] + srv.mu.RUnlock() + return exempt +} + +func connRateCheck(c *Client) bool { + if isExempt(c) { + return false + } + if Rater.Check(c) { + c.send("too many connections") + println(c.UniqueKey() + " ratelimited") + return true + } + return false +} + +func closeConn(c *Client) { + if err := c.Conn.Close(); err != nil { + println(err.Error()) + } + println("closed: " + c.Conn.RemoteAddr().String()) +} + +func (s *Server) handleTCP(c *Client) { + if err := c.Conn.(*net.TCPConn).SetLinger(0); err != nil { + fmt.Println("error while setting setlinger:", err.Error()) + } + defer closeConn(c) + if rated := connRateCheck(c); rated { + return + } + c.read = bufio.NewReader(c.Conn) + if _, err := c.Conn.Write(loginBanner()); err != nil { + return + } + for !c.loggedin { + if !c.connected { + return + } + s.preLogin(c) + } + s.mainPrompt(c) +} + func (c *Client) send(data string) { if err := c.Conn.SetReadDeadline(time.Now().Add(c.deadline)); err != nil { fmt.Println("error while setting deadline:", err.Error()) @@ -208,18 +255,16 @@ func (c *Client) recv() string { fmt.Println("error while setting deadline:", err.Error()) } - // skip ratelimit checking for exempt clients - srv.mu.RLock() - _, ok := srv.Exempt[c.UniqueKey()] - srv.mu.RUnlock() - - if CmdRater.Check(c) && !ok { - if !c.loggedin { - // if they hit the ratelimiter during log-in, disconnect them - c.connected = false + if !isExempt(c) { + if CmdRater.Check(c) { + if !c.loggedin { + // if they hit the ratelimiter during log-in, disconnect them + c.connected = false + } + time.Sleep(time.Duration(1250) * time.Millisecond) } - time.Sleep(time.Duration(1250) * time.Millisecond) } + in, err := c.read.ReadString('\n') if err != nil { println(c.UniqueKey() + ": " + err.Error()) @@ -289,20 +334,23 @@ func (s *Server) authCheck(c *Client, id string) bool { if old, ok := s.Map[id]; ok { s.mu.RUnlock() old.connected = false - old.Conn.Close() + closeConn(old) s.replaceSession(c, id) return true } s.mu.RUnlock() return false - } func loginBanner() []byte { + var data []byte + var err error login := "CnwgG1s5MDs0MG1SG1swbRtbMG0gG1s5Nzs0MG3DhhtbMG0bWzBtIBtbOTc7NDBtzpMbWzBtG1swbSAbWzk3OzQwbc6jG1swbRtbMG0gG1swbRtbOTc7MzJtNRtbMG0bWzBtIHwKCg==" - data, _ := base64.StdEncoding.DecodeString(login) - return data + if data, err = base64.StdEncoding.DecodeString(login); err == nil { + return data + } + panic(err) } func main() { diff --git a/models.go b/models.go index 42f5e72..229c06d 100644 --- a/models.go +++ b/models.go @@ -2,6 +2,7 @@ package rate5 import ( "sync" + "sync/atomic" "github.com/patrickmn/go-cache" ) @@ -20,6 +21,10 @@ type Identity interface { UniqueKey() string } +type rated struct { + seen atomic.Value +} + // Limiter implements an Enforcer to create an arbitrary ratelimiter. type Limiter struct { Source Identity @@ -32,7 +37,7 @@ type Limiter struct { Debug bool count int - known map[interface{}]int + known map[interface{}]*rated mu *sync.RWMutex } diff --git a/ratelimiter.go b/ratelimiter.go index 8e7f3da..e2827f0 100644 --- a/ratelimiter.go +++ b/ratelimiter.go @@ -3,6 +3,7 @@ package rate5 import ( "fmt" "sync" + "sync/atomic" "time" "github.com/patrickmn/go-cache" @@ -53,7 +54,7 @@ func newLimiter(policy Policy) *Limiter { q := new(Limiter) q.Ruleset = policy q.Patrons = cache.New(time.Duration(q.Ruleset.Window)*time.Second, 5*time.Second) - q.known = make(map[interface{}]int) + q.known = make(map[interface{}]*rated) q.mu = &sync.RWMutex{} return q } @@ -68,14 +69,24 @@ func (q *Limiter) DebugChannel() chan string { return debugChannel } +func (s *rated) inc() { + if s.seen.Load() == nil { + s.seen.Store(1) + return + } + s.seen.Store(s.seen.Load().(int) + 1) +} + func (q *Limiter) strictLogic(src string, count int) { q.mu.Lock() if _, ok := q.known[src]; !ok { - q.known[src] = 1 + q.known[src]=&rated{ + seen: atomic.Value{}, + } } - q.known[src]++ - extwindow := q.Ruleset.Window + q.known[src] + q.known[src].inc() + extwindow := q.Ruleset.Window + q.known[src].seen.Load().(int) if err := q.Patrons.Replace(src, count, time.Duration(extwindow)*time.Second); err != nil { q.debugPrint("Rate5: " + err.Error())