Skip to content

Commit

Permalink
Add support for websocket rate limiting
Browse files Browse the repository at this point in the history
Should be enabled with `enable_websockets_rate_limiting`
  • Loading branch information
buger committed Oct 11, 2024
1 parent 849d346 commit ac857d3
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 4 deletions.
2 changes: 2 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@ type HttpServerOptionsConfig struct {
// Enabled WebSockets and server side events support
EnableWebSockets bool `json:"enable_websockets"`

EnableWebSocketRateLimiting bool `json:"enable_websockets_rate_limiting"`

// Deprecated. SSL certificates used by Gateway server.
Certificates CertsData `json:"certificates"`

Expand Down
151 changes: 147 additions & 4 deletions gateway/reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -1544,20 +1545,155 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
return fmt.Errorf("response flush: %w", err)
}
errc := make(chan error, 1)
spc := switchProtocolCopier{user: conn, backend: backConn}
spc := switchProtocolCopier{user: conn, backend: backConn, req: req, gw: p.Gw, spec: p.TykAPISpec}
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
<-errc

res.Body = ioutil.NopCloser(strings.NewReader(""))
select {
case err := <-errc:
if rateLimitErr, ok := err.(*RateLimitError); ok {
statusCode := 4000
reason := rateLimitErr.Message

switch reason {
case "Rate limit exceeded":
statusCode = 4000
case "Quota exceeded":
statusCode = 4001
case "Access denied":
statusCode = 4003
}

// Construct the payload (status code + reason)
payload := make([]byte, 2+len(reason))
binary.BigEndian.PutUint16(payload[0:2], uint16(statusCode))
copy(payload[2:], reason)

// Construct the close frame
frameSize := 1 + 1 + len(payload) // 1 byte for opcode, 1 byte for payload length, payload
if len(payload) > 125 {
frameSize += 2 // 2 additional bytes for extended payload length
}
closeFrame := make([]byte, frameSize)

// Set opcode
closeFrame[0] = 0x88 // WebSocket close opcode

// Set payload length
if len(payload) <= 125 {
closeFrame[1] = byte(len(payload))
} else {
closeFrame[1] = 126 // Extended payload length (2 bytes)
binary.BigEndian.PutUint16(closeFrame[2:4], uint16(len(payload)))
}

// Copy payload (status code + reason)
copy(closeFrame[frameSize-len(payload):], payload)

// Write the close frame
_, err := conn.Write(closeFrame)
if err != nil {
// Handle write error
log.Printf("Error writing close frame: %v", err)
}

// Close the connection
conn.Close()

return nil
}
}

res.Body = io.NopCloser(strings.NewReader(""))

return nil
}

type RateLimitedReader struct {
r io.Reader
req *http.Request
gw *Gateway
session *user.SessionState
rateLimitKey string
quotaKey string
spec *APISpec
}

func NewRateLimitedReader(r io.Reader, req *http.Request, gw *Gateway, spec *APISpec) (*RateLimitedReader, error) {
session := ctxGetSession(req)
rateLimitKey := ctxGetAuthToken(req)
quotaKey := ""

if session == nil {
return nil, errors.New("There is no session and limits found")
}

if pattern, found := session.MetaData["rate_limit_pattern"]; found {
if patternString, ok := pattern.(string); ok && patternString != "" {
if customKeyValue := gw.replaceTykVariables(req, patternString, false); customKeyValue != "" {
rateLimitKey = customKeyValue
quotaKey = customKeyValue
}
}
}

return &RateLimitedReader{
r: r,
req: req,
gw: gw,
session: session,
rateLimitKey: rateLimitKey,
quotaKey: quotaKey,
spec: spec,
}, nil
}

// RateLimitError represents an error due to rate limiting
type RateLimitError struct {
Message string
}

func (e *RateLimitError) Error() string {
return e.Message
}

func (r *RateLimitedReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)

if n > 0 {
reason := r.gw.SessionLimiter.ForwardMessage(
r.req, // We don't have access to the original request here
r.session,
r.rateLimitKey,
r.quotaKey,
r.gw.GlobalSessionManager.Store(),
!r.spec.DisableRateLimit,
!r.spec.DisableQuota,
r.spec,
false, // Not a dry run
)

switch reason {
case sessionFailNone:
// Continue as normal
case sessionFailRateLimit:
return n, &RateLimitError{Message: "rate limit exceeded"}
case sessionFailQuota:
return n, &RateLimitError{Message: "quota exceeded"}
default:
return n, &RateLimitError{Message: "access denied"}
}
}
return
}

// switchProtocolCopier exists so goroutines proxying data back and
// forth have nice names in stacks.
type switchProtocolCopier struct {
user, backend io.ReadWriter
req *http.Request
gw *Gateway
spec *APISpec
}

func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
Expand All @@ -1566,7 +1702,14 @@ func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
}

func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
_, err := io.Copy(c.backend, c.user)
var input io.Reader

input, err := NewRateLimitedReader(c.user, c.req, c.gw, c.spec)
if err != nil {
input = c.user
}

_, err = io.Copy(c.backend, input)
errc <- err
}

Expand Down

0 comments on commit ac857d3

Please sign in to comment.