Skip to content

Commit

Permalink
Enhance: implement atomic.Value for strict logic + update example
Browse files Browse the repository at this point in the history
  • Loading branch information
yunginnanet committed Sep 24, 2021
1 parent fa5679a commit 27822b0
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 97 deletions.
232 changes: 140 additions & 92 deletions _examples/rated.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/binary"
"fmt"
"net"
"os"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -58,12 +59,35 @@ type Client struct {

// UniqueKey is an implementation of our Identity interface, in short: Rate5 doesn't care where you derive the string used for ratelimiting
func (c Client) UniqueKey() string {
var err error
var host string
if c.loggedin {
return c.ID
}
if host, _, err = net.SplitHostPort(c.Conn.RemoteAddr().String()); err == nil {
return host
}
panic(err)
}

func argParse() {
if len(os.Args) < 1 {
return
}
for i, arg := range os.Args {
switch arg {
case "-e":
fallthrough
case "--exempt":
if len(os.Args) <= i+1 {
return
}
srv.Exempt[os.Args[i+1]] = true
default:
continue
}

host, _, _ := net.SplitHostPort(c.Conn.RemoteAddr().String())
return host
}
}

func init() {
Expand All @@ -82,97 +106,70 @@ func init() {
mu: &sync.RWMutex{},
}

//srv.Exempt["127.0.0.1"] = true
argParse()


rd := Rater.DebugChannel()
rrd := RegRater.DebugChannel()
crd := CmdRater.DebugChannel()
go watchDebug(rd, rrd, crd)
}

func watchDebug(rd, rrd, crd chan string) {
pre := "[Rate5] "
go func() {
var lastcount = 0
var count = 0
for {
select {
case msg := <-rd:
fmt.Printf("%s Limit: %s \n", pre, msg)
count++
case msg := <-rrd:
fmt.Printf("%s RegLimit: %s \n", pre, msg)
count++
case msg := <-crd:
fmt.Printf("%s CmdLimit: %s \n", pre, msg)
count++
default:
if count - lastcount >= 25 {
lastcount = count
fmt.Println("Rater: ", Rater.GetGrandTotalRated())
fmt.Println("RegRater: ", RegRater.GetGrandTotalRated())
fmt.Println("CmdRater: ", CmdRater.GetGrandTotalRated())
}
time.Sleep(time.Duration(10) * time.Millisecond)
var lastcount = 0
var count = 0
for {
select {
case msg := <-rd:
fmt.Printf("%s Limit: %s \n", pre, msg)
count++
case msg := <-rrd:
fmt.Printf("%s RegLimit: %s \n", pre, msg)
count++
case msg := <-crd:
fmt.Printf("%s CmdLimit: %s \n", pre, msg)
count++
default:
if count-lastcount >= 25 {
lastcount = count
fmt.Println("Rater: ", Rater.GetGrandTotalRated())
fmt.Println("RegRater: ", RegRater.GetGrandTotalRated())
fmt.Println("CmdRater: ", CmdRater.GetGrandTotalRated())
}
time.Sleep(time.Duration(10) * time.Millisecond)
}
}()
}

func (s *Server) handleTCP(c *Client) {
if err := c.Conn.(*net.TCPConn).SetLinger(0); err != nil {
fmt.Println("error while setting setlinger:", err.Error())
}
}

// skip ratelimit checking for exempt clients
srv.mu.RLock()
_, exempt := srv.Exempt[c.UniqueKey()]
srv.mu.RUnlock()

defer func() {
c.Conn.Close()
println("closed: " + c.Conn.RemoteAddr().String())
}()

// Returns true if ratelimited
if Rater.Check(c) {
c.Conn.Write([]byte("too many connections"))
println(c.UniqueKey() + " ratelimited")
func (s *Server) preLogin(c *Client) {
c.send("Auth: ")
in := c.recv()
switch {
case s.authCheck(c, in):
c.loggedin = true
c.deadline = time.Duration(480) * time.Second
c.send("successful login")
return
}

c.read = bufio.NewReader(c.Conn)

c.Conn.Write(loginBanner())

for {
if !c.connected {
case in == "register":
// no exemption for strict ratelimiter (rate5 testing)
if RegRater.Check(c) {
c.send("you already registered recently\n")
return
}
println("new registration from " + c.UniqueKey())
s.setID(c, s.getUnusedID())
c.send("\nregistration success\n[New ID]: " + c.ID)
return
default:
c.send("invalid. type 'REGISTER' to register a new ID\n")
return
}

time.Sleep(time.Duration(25) * time.Millisecond)
if !c.loggedin {
c.send("Auth: ")
in := c.recv()
switch {
case s.authCheck(c, in):
c.loggedin = true
c.deadline = time.Duration(480) * time.Second
c.send("successful login")
continue
case in == "register":
if !RegRater.Check(c) || exempt {
println("new registration from " + c.UniqueKey())
s.setID(c, s.getUnusedID())
c.send("\nregistration success\n[New ID]: " + c.ID)
return
} else {
c.send("you already registered recently\n")
}
continue
default:
c.send("invalid. type 'REGISTER' to register a new ID\n")
continue
}
}
}

func (s *Server) mainPrompt(c *Client) {
for c.connected {
c.send("\nRate5 > ")
switch c.recv() {
case "history":
Expand All @@ -190,10 +187,60 @@ func (s *Server) handleTCP(c *Client) {
case "logout":
c.loggedin = false
return
default:
c.send("unknown command, are you lost?")
continue
}
}
}

func isExempt(c *Client) bool {
srv.mu.RLock()
_, exempt := srv.Exempt[c.UniqueKey()]
srv.mu.RUnlock()
return exempt
}

func connRateCheck(c *Client) bool {
if isExempt(c) {
return false
}
if Rater.Check(c) {
c.send("too many connections")
println(c.UniqueKey() + " ratelimited")
return true
}
return false
}

func closeConn(c *Client) {
if err := c.Conn.Close(); err != nil {
println(err.Error())
}
println("closed: " + c.Conn.RemoteAddr().String())
}

func (s *Server) handleTCP(c *Client) {
if err := c.Conn.(*net.TCPConn).SetLinger(0); err != nil {
fmt.Println("error while setting setlinger:", err.Error())
}
defer closeConn(c)
if rated := connRateCheck(c); rated {
return
}
c.read = bufio.NewReader(c.Conn)
if _, err := c.Conn.Write(loginBanner()); err != nil {
return
}
for !c.loggedin {
if !c.connected {
return
}
s.preLogin(c)
}
s.mainPrompt(c)
}

func (c *Client) send(data string) {
if err := c.Conn.SetReadDeadline(time.Now().Add(c.deadline)); err != nil {
fmt.Println("error while setting deadline:", err.Error())
Expand All @@ -208,18 +255,16 @@ func (c *Client) recv() string {
fmt.Println("error while setting deadline:", err.Error())
}

// skip ratelimit checking for exempt clients
srv.mu.RLock()
_, ok := srv.Exempt[c.UniqueKey()]
srv.mu.RUnlock()

if CmdRater.Check(c) && !ok {
if !c.loggedin {
// if they hit the ratelimiter during log-in, disconnect them
c.connected = false
if !isExempt(c) {
if CmdRater.Check(c) {
if !c.loggedin {
// if they hit the ratelimiter during log-in, disconnect them
c.connected = false
}
time.Sleep(time.Duration(1250) * time.Millisecond)
}
time.Sleep(time.Duration(1250) * time.Millisecond)
}

in, err := c.read.ReadString('\n')
if err != nil {
println(c.UniqueKey() + ": " + err.Error())
Expand Down Expand Up @@ -289,20 +334,23 @@ func (s *Server) authCheck(c *Client, id string) bool {
if old, ok := s.Map[id]; ok {
s.mu.RUnlock()
old.connected = false
old.Conn.Close()
closeConn(old)
s.replaceSession(c, id)
return true
}

s.mu.RUnlock()
return false

}

func loginBanner() []byte {
var data []byte
var err error
login := "CnwgG1s5MDs0MG1SG1swbRtbMG0gG1s5Nzs0MG3DhhtbMG0bWzBtIBtbOTc7NDBtzpMbWzBtG1swbSAbWzk3OzQwbc6jG1swbRtbMG0gG1swbRtbOTc7MzJtNRtbMG0bWzBtIHwKCg=="
data, _ := base64.StdEncoding.DecodeString(login)
return data
if data, err = base64.StdEncoding.DecodeString(login); err == nil {
return data
}
panic(err)
}

func main() {
Expand Down
7 changes: 6 additions & 1 deletion models.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package rate5

import (
"sync"
"sync/atomic"

"github.com/patrickmn/go-cache"
)
Expand All @@ -20,6 +21,10 @@ type Identity interface {
UniqueKey() string
}

type rated struct {
seen atomic.Value
}

// Limiter implements an Enforcer to create an arbitrary ratelimiter.
type Limiter struct {
Source Identity
Expand All @@ -32,7 +37,7 @@ type Limiter struct {
Debug bool

count int
known map[interface{}]int
known map[interface{}]*rated
mu *sync.RWMutex
}

Expand Down
19 changes: 15 additions & 4 deletions ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rate5
import (
"fmt"
"sync"
"sync/atomic"
"time"

"github.com/patrickmn/go-cache"
Expand Down Expand Up @@ -53,7 +54,7 @@ func newLimiter(policy Policy) *Limiter {
q := new(Limiter)
q.Ruleset = policy
q.Patrons = cache.New(time.Duration(q.Ruleset.Window)*time.Second, 5*time.Second)
q.known = make(map[interface{}]int)
q.known = make(map[interface{}]*rated)
q.mu = &sync.RWMutex{}
return q
}
Expand All @@ -68,14 +69,24 @@ func (q *Limiter) DebugChannel() chan string {
return debugChannel
}

func (s *rated) inc() {
if s.seen.Load() == nil {
s.seen.Store(1)
return
}
s.seen.Store(s.seen.Load().(int) + 1)
}

func (q *Limiter) strictLogic(src string, count int) {
q.mu.Lock()
if _, ok := q.known[src]; !ok {
q.known[src] = 1
q.known[src]=&rated{
seen: atomic.Value{},
}
}

q.known[src]++
extwindow := q.Ruleset.Window + q.known[src]
q.known[src].inc()
extwindow := q.Ruleset.Window + q.known[src].seen.Load().(int)

if err := q.Patrons.Replace(src, count, time.Duration(extwindow)*time.Second); err != nil {
q.debugPrint("Rate5: " + err.Error())
Expand Down

0 comments on commit 27822b0

Please sign in to comment.