From a8483bb12dae8f911f7a52e3d9718ea8efc639fa Mon Sep 17 00:00:00 2001 From: kayos Date: Tue, 25 Jun 2024 23:29:01 -0700 Subject: [PATCH] Perf: Increase speed and reduce memory allocations (#14) --- debug.go | 18 ++- models.go | 3 +- ratelimiter.go | 96 ++++++++---- ratelimiter_test.go | 56 +++++-- speedometer/speedometer.go | 108 +++++++++++-- speedometer/speedometer_test.go | 258 ++++++++++++++++++++++++++------ 6 files changed, 432 insertions(+), 107 deletions(-) diff --git a/debug.go b/debug.go index 906693e..3092b71 100644 --- a/debug.go +++ b/debug.go @@ -5,10 +5,24 @@ import ( "sync/atomic" ) +const ( + msgRateLimitExpired = "ratelimit (expired): %s | last count [%d]" + msgDebugEnabled = "rate5 debug enabled" + msgRateLimitedRst = "ratelimit for %s has been reset" + msgRateLimitedNew = "ratelimit %s (new) " + msgRateLimited = "ratelimit %s: last count %d. time: %s" + msgRateLimitStrict = "%s ratelimit for %s: last count %d. time: %s" +) + func (q *Limiter) debugPrintf(format string, a ...interface{}) { if atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugDisabled) { return } + if len(a) == 2 { + if _, ok := a[1].(*atomic.Int64); ok { + a[1] = a[1].(*atomic.Int64).Load() + } + } msg := fmt.Sprintf(format, a...) select { case q.debugChannel <- msg: @@ -21,7 +35,7 @@ func (q *Limiter) debugPrintf(format string, a ...interface{}) { func (q *Limiter) setDebugEvict() { q.Patrons.OnEvicted(func(src string, count interface{}) { - q.debugPrintf("ratelimit (expired): %s | last count [%d]", src, count) + q.debugPrintf(msgRateLimitExpired, src, count.(*atomic.Int64).Load()) }) } @@ -29,7 +43,7 @@ func (q *Limiter) SetDebug(on bool) { switch on { case true: if atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugEnabled) { - q.debugPrintf("rate5 debug enabled") + q.debugPrintf(msgDebugEnabled) } case false: atomic.CompareAndSwapUint32(&q.debug, DebugEnabled, DebugDisabled) diff --git a/models.go b/models.go index 48fd4c9..e5449cd 100644 --- a/models.go +++ b/models.go @@ -3,6 +3,7 @@ package rate5 import ( "fmt" "sync" + "sync/atomic" "github.com/patrickmn/go-cache" ) @@ -46,7 +47,7 @@ type Limiter struct { debug uint32 debugChannel chan string debugLost int64 - known map[interface{}]*int64 + known map[interface{}]*atomic.Int64 debugMutex *sync.RWMutex *sync.RWMutex } diff --git a/ratelimiter.go b/ratelimiter.go index bebfbd4..fa9e6ca 100644 --- a/ratelimiter.go +++ b/ratelimiter.go @@ -9,6 +9,29 @@ import ( "github.com/patrickmn/go-cache" ) +const ( + strictPrefix = "strict" + hardcorePrefix = "hardcore" +) + +var _counters = &sync.Pool{ + New: func() interface{} { + i := &atomic.Int64{} + i.Store(0) + return i + }, +} + +func getCounter() *atomic.Int64 { + got := _counters.Get().(*atomic.Int64) + got.Store(0) + return got +} + +func putCounter(i *atomic.Int64) { + _counters.Put(i) +} + /*NewDefaultLimiter returns a ratelimiter with default settings without Strict mode. * Default window: 25 seconds * Default burst: 25 requests */ @@ -70,28 +93,40 @@ func NewHardcoreLimiter(window int, burst int) *Limiter { return l } +// ResetItem removes an Identity from the limiter's cache. +// This effectively resets the rate limit for the Identity. 
func (q *Limiter) ResetItem(from Identity) { q.Patrons.Delete(from.UniqueKey()) - q.debugPrintf("ratelimit for %s has been reset", from.UniqueKey()) + q.debugPrintf(msgRateLimitedRst, from.UniqueKey()) +} + +func (q *Limiter) onEvict(src string, count interface{}) { + q.debugPrintf(msgRateLimitExpired, src, count) + putCounter(count.(*atomic.Int64)) + } func newLimiter(policy Policy) *Limiter { window := time.Duration(policy.Window) * time.Second - return &Limiter{ + q := &Limiter{ Ruleset: policy, Patrons: cache.New(window, time.Duration(policy.Window)*time.Second), - known: make(map[interface{}]*int64), + known: make(map[interface{}]*atomic.Int64), RWMutex: &sync.RWMutex{}, debugMutex: &sync.RWMutex{}, debug: DebugDisabled, } + q.Patrons.OnEvicted(q.onEvict) + return q } -func intPtr(i int64) *int64 { - return &i +func intPtr(i int64) *atomic.Int64 { + a := getCounter() + a.Store(i) + return a } -func (q *Limiter) getHitsPtr(src string) *int64 { +func (q *Limiter) getHitsPtr(src string) *atomic.Int64 { q.RLock() if _, ok := q.known[src]; ok { oldPtr := q.known[src] @@ -100,29 +135,29 @@ func (q *Limiter) getHitsPtr(src string) *int64 { } q.RUnlock() q.Lock() - newPtr := intPtr(0) + newPtr := getCounter() q.known[src] = newPtr q.Unlock() return newPtr } -func (q *Limiter) strictLogic(src string, count int64) { +func (q *Limiter) strictLogic(src string, count *atomic.Int64) { knownHits := q.getHitsPtr(src) - atomic.AddInt64(knownHits, 1) + knownHits.Add(1) var extwindow int64 - prefix := "hardcore" + prefix := hardcorePrefix switch { case q.Ruleset.Hardcore && q.Ruleset.Window > 1: - extwindow = atomic.LoadInt64(knownHits) * q.Ruleset.Window + extwindow = knownHits.Load() * q.Ruleset.Window case q.Ruleset.Hardcore && q.Ruleset.Window <= 1: - extwindow = atomic.LoadInt64(knownHits) * 2 + extwindow = knownHits.Load() * 2 case !q.Ruleset.Hardcore: - prefix = "strict" - extwindow = atomic.LoadInt64(knownHits) + q.Ruleset.Window + prefix = strictPrefix + extwindow = knownHits.Load() + q.Ruleset.Window } exttime := time.Duration(extwindow) * time.Second _ = q.Patrons.Replace(src, count, exttime) - q.debugPrintf("%s ratelimit for %s: last count %d. time: %s", prefix, src, count, exttime) + q.debugPrintf(msgRateLimitStrict, prefix, src, count.Load(), exttime) } func (q *Limiter) CheckStringer(from fmt.Stringer) bool { @@ -133,33 +168,32 @@ func (q *Limiter) CheckStringer(from fmt.Stringer) bool { // Check checks and increments an Identities UniqueKey() output against a list of cached strings to determine and raise it's ratelimitting status. func (q *Limiter) Check(from Identity) (limited bool) { var count int64 - var err error - src := from.UniqueKey() - count, err = q.Patrons.IncrementInt64(src, 1) - if err != nil { - // IncrementInt64 should only error if the value is not an int64, so we can assume it's a new key. - q.debugPrintf("ratelimit %s (new) ", src) + aval, ok := q.Patrons.Get(from.UniqueKey()) + switch { + case !ok: + q.debugPrintf(msgRateLimitedNew, from.UniqueKey()) + aval = intPtr(1) // We can't reproduce this throwing an error, we can only assume that the key is new. 
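+		// intPtr hands back a pooled *atomic.Int64 (see getCounter), so the cache
+		// stores the counter by reference; onEvict recycles it via putCounter.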
- _ = q.Patrons.Add(src, int64(1), time.Duration(q.Ruleset.Window)*time.Second) - return false - } - if count < q.Ruleset.Burst { + _ = q.Patrons.Add(from.UniqueKey(), aval, time.Duration(q.Ruleset.Window)*time.Second) return false + case aval != nil: + count = aval.(*atomic.Int64).Add(1) + if count < q.Ruleset.Burst { + return false + } } if q.Ruleset.Strict { - q.strictLogic(src, count) - } else { - q.debugPrintf("ratelimit %s: last count %d. time: %s", - src, count, time.Duration(q.Ruleset.Window)*time.Second) + q.strictLogic(from.UniqueKey(), aval.(*atomic.Int64)) + return true } + q.debugPrintf(msgRateLimited, from.UniqueKey(), count, time.Duration(q.Ruleset.Window)*time.Second) return true } // Peek checks an Identities UniqueKey() output against a list of cached strings to determine ratelimitting status without adding to its request count. func (q *Limiter) Peek(from Identity) bool { - q.Patrons.DeleteExpired() if ct, ok := q.Patrons.Get(from.UniqueKey()); ok { - count := ct.(int64) + count := ct.(*atomic.Int64).Load() if count > q.Ruleset.Burst { return true } diff --git a/ratelimiter_test.go b/ratelimiter_test.go index 57ba7f6..2098583 100644 --- a/ratelimiter_test.go +++ b/ratelimiter_test.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "runtime" "sync" + "sync/atomic" "testing" "time" ) @@ -95,6 +96,7 @@ func peekCheckLimited(t *testing.T, limiter *Limiter, shouldbe, stringer bool) { // this test exists here for coverage, we are simulating the debug channel overflowing and then invoking println(). func Test_debugPrintf(t *testing.T) { + t.Parallel() limiter := NewLimiter(1, 1) _ = limiter.DebugChannel() for n := 0; n < 50; n++ { @@ -126,6 +128,7 @@ func Test_ResetItem(t *testing.T) { } func Test_NewDefaultLimiter(t *testing.T) { + t.Parallel() limiter := NewDefaultLimiter() limiter.Check(dummyTicker) peekCheckLimited(t, limiter, false, false) @@ -136,6 +139,7 @@ func Test_NewDefaultLimiter(t *testing.T) { } func Test_CheckAndPeekStringer(t *testing.T) { + t.Parallel() limiter := NewDefaultLimiter() limiter.CheckStringer(dummyTicker) peekCheckLimited(t, limiter, false, true) @@ -146,6 +150,7 @@ func Test_CheckAndPeekStringer(t *testing.T) { } func Test_NewLimiter(t *testing.T) { + t.Parallel() limiter := NewLimiter(5, 1) limiter.Check(dummyTicker) peekCheckLimited(t, limiter, false, false) @@ -154,9 +159,10 @@ func Test_NewLimiter(t *testing.T) { } func Test_NewDefaultStrictLimiter(t *testing.T) { + t.Parallel() limiter := NewDefaultStrictLimiter() - ctx, cancel := context.WithCancel(context.Background()) - go watchDebug(ctx, limiter, t) + // ctx, cancel := context.WithCancel(context.Background()) + // go watchDebug(ctx, limiter, t) time.Sleep(25 * time.Millisecond) for n := 0; n < 25; n++ { limiter.Check(dummyTicker) @@ -164,14 +170,15 @@ func Test_NewDefaultStrictLimiter(t *testing.T) { peekCheckLimited(t, limiter, false, false) limiter.Check(dummyTicker) peekCheckLimited(t, limiter, true, false) - cancel() + // cancel() limiter = nil } func Test_NewStrictLimiter(t *testing.T) { + t.Parallel() limiter := NewStrictLimiter(5, 1) - ctx, cancel := context.WithCancel(context.Background()) - go watchDebug(ctx, limiter, t) + // ctx, cancel := context.WithCancel(context.Background()) + // go watchDebug(ctx, limiter, t) limiter.Check(dummyTicker) peekCheckLimited(t, limiter, false, false) limiter.Check(dummyTicker) @@ -190,11 +197,12 @@ func Test_NewStrictLimiter(t *testing.T) { peekCheckLimited(t, limiter, true, false) time.Sleep(8 * time.Second) peekCheckLimited(t, limiter, false, 
false) - cancel() + // cancel() limiter = nil } func Test_NewHardcoreLimiter(t *testing.T) { + t.Parallel() limiter := NewHardcoreLimiter(1, 5) ctx, cancel := context.WithCancel(context.Background()) go watchDebug(ctx, limiter, t) @@ -305,7 +313,7 @@ testloop: if ci, ok = limiter.Patrons.Get(rp.UniqueKey()); !ok { t.Fatal("randomPatron does not exist in ratelimiter at all!") } - ct := ci.(int64) + ct := ci.(*atomic.Int64).Load() if limiter.Peek(rp) && !shouldLimit { t.Logf("(%d goroutines running)", runtime.NumGoroutine()) // runtime.Breakpoint() @@ -323,16 +331,19 @@ testloop: } func Test_ConcurrentShouldNotLimit(t *testing.T) { + t.Parallel() concurrentTest(t, 50, 20, 20, false) concurrentTest(t, 50, 50, 50, false) } func Test_ConcurrentShouldLimit(t *testing.T) { + t.Parallel() concurrentTest(t, 50, 21, 20, true) concurrentTest(t, 50, 51, 50, true) } func Test_debugChannelOverflow(t *testing.T) { + t.Parallel() limiter := NewDefaultLimiter() _ = limiter.DebugChannel() for n := 0; n != 78; n++ { @@ -347,10 +358,23 @@ func Test_debugChannelOverflow(t *testing.T) { } } +func TestDebugPrintfTypeAssertion(t *testing.T) { + t.Parallel() + limiter := NewDefaultLimiter() + limiter.SetDebug(true) + asdf := new(atomic.Int64) + asdf.Store(5) + limiter.debugChannel = make(chan string, 1) + limiter.debugPrintf("test %d %d", 1, asdf) + if <-limiter.debugChannel != "test 1 5" { + t.Fatalf("failed to type assert atomic.Int64") + } +} + func BenchmarkCheck(b *testing.B) { b.StopTimer() - b.ReportAllocs() limiter := NewDefaultLimiter() + b.ReportAllocs() b.StartTimer() for n := 0; n < b.N; n++ { limiter.Check(dummyTicker) @@ -359,8 +383,8 @@ func BenchmarkCheck(b *testing.B) { func BenchmarkCheckHardcore(b *testing.B) { b.StopTimer() - b.ReportAllocs() limiter := NewHardcoreLimiter(25, 25) + b.ReportAllocs() b.StartTimer() for n := 0; n < b.N; n++ { limiter.Check(dummyTicker) @@ -369,8 +393,8 @@ func BenchmarkCheckHardcore(b *testing.B) { func BenchmarkCheckStrict(b *testing.B) { b.StopTimer() - b.ReportAllocs() limiter := NewStrictLimiter(25, 25) + b.ReportAllocs() b.StartTimer() for n := 0; n < b.N; n++ { limiter.Check(dummyTicker) @@ -379,8 +403,8 @@ func BenchmarkCheckStrict(b *testing.B) { func BenchmarkCheckStringer(b *testing.B) { b.StopTimer() - b.ReportAllocs() limiter := NewDefaultLimiter() + b.ReportAllocs() b.StartTimer() for n := 0; n < b.N; n++ { limiter.CheckStringer(dummyTicker) @@ -389,8 +413,8 @@ func BenchmarkCheckStringer(b *testing.B) { func BenchmarkPeek(b *testing.B) { b.StopTimer() - b.ReportAllocs() limiter := NewDefaultLimiter() + b.ReportAllocs() b.StartTimer() for n := 0; n < b.N; n++ { limiter.Peek(dummyTicker) @@ -399,8 +423,8 @@ func BenchmarkPeek(b *testing.B) { func BenchmarkConcurrentCheck(b *testing.B) { b.StopTimer() - b.ReportAllocs() limiter := NewDefaultLimiter() + b.ReportAllocs() b.StartTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { @@ -411,8 +435,8 @@ func BenchmarkConcurrentCheck(b *testing.B) { func BenchmarkConcurrentSetAndCheckHardcore(b *testing.B) { b.StopTimer() - b.ReportAllocs() limiter := NewHardcoreLimiter(25, 25) + b.ReportAllocs() b.StartTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { @@ -423,8 +447,8 @@ func BenchmarkConcurrentSetAndCheckHardcore(b *testing.B) { func BenchmarkConcurrentSetAndCheckStrict(b *testing.B) { b.StopTimer() - b.ReportAllocs() limiter := NewDefaultStrictLimiter() + b.ReportAllocs() b.StartTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { @@ -435,8 +459,8 @@ func 
BenchmarkConcurrentSetAndCheckStrict(b *testing.B) { func BenchmarkConcurrentPeek(b *testing.B) { b.StopTimer() - b.ReportAllocs() limiter := NewDefaultLimiter() + b.ReportAllocs() b.StartTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { diff --git a/speedometer/speedometer.go b/speedometer/speedometer.go index 06b4f1d..837bf70 100644 --- a/speedometer/speedometer.go +++ b/speedometer/speedometer.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io" + "net" "sync" "sync/atomic" "time" @@ -25,6 +26,8 @@ type Speedometer struct { speedLimit *SpeedLimit internal atomics w io.Writer + r io.Reader + c io.Closer } type atomics struct { @@ -87,10 +90,7 @@ func regulateSpeedLimit(speedLimit *SpeedLimit) (*SpeedLimit, error) { return speedLimit, nil } -func newSpeedometer(w io.Writer, speedLimit *SpeedLimit, ceiling int64) (*Speedometer, error) { - if w == nil { - return nil, errors.New("writer cannot be nil") - } +func newSpeedometer(target any, speedLimit *SpeedLimit, ceiling int64) (*Speedometer, error) { var err error if speedLimit != nil { if speedLimit, err = regulateSpeedLimit(speedLimit); err != nil { @@ -98,12 +98,35 @@ func newSpeedometer(w io.Writer, speedLimit *SpeedLimit, ceiling int64) (*Speedo } } - return &Speedometer{ - w: w, + spd := &Speedometer{ ceiling: ceiling, speedLimit: speedLimit, internal: newAtomics(), - }, nil + } + + switch t := target.(type) { + case io.ReadWriteCloser: + spd.w = t + spd.r = t + spd.c = t + case io.ReadWriter: + spd.w = t + spd.r = t + case io.WriteCloser: + spd.w = t + spd.c = t + case io.ReadCloser: + spd.r = t + spd.c = t + case io.Writer: + spd.w = t + case io.Reader: + spd.r = t + default: + return nil, errors.New("invalid target") + } + + return spd, nil } // NewSpeedometer creates a new Speedometer that wraps the given io.Writer. @@ -112,6 +135,10 @@ func NewSpeedometer(w io.Writer) (*Speedometer, error) { return newSpeedometer(w, nil, -1) } +func NewReadingSpeedometer(r io.Reader) (*Speedometer, error) { + return newSpeedometer(r, nil, -1) +} + // NewLimitedSpeedometer creates a new Speedometer that wraps the given io.Writer. // If the speed limit is exceeded, writes to the underlying writer will be limited. // See SpeedLimit for more information. @@ -160,14 +187,23 @@ func (s *Speedometer) Close() error { if s.internal.closed.Load() { return io.ErrClosedPipe } + + var err error + s.internal.stop.Do(func() { s.internal.closed.Store(true) stopped := time.Now() birth := s.internal.birth.Load() duration := stopped.Sub(*birth) s.internal.duration.Store(&duration) + if s.c != nil { + if cErr := s.c.Close(); cErr != nil && !errors.Is(cErr, net.ErrClosed) { + err = cErr + } + } }) - return nil + + return err } // Rate returns the bytes per second rate at which data is being written to the underlying writer. @@ -205,11 +241,48 @@ func (s *Speedometer) slowDown() error { return nil } -// Write writes p to the underlying writer, following all defined speed limits. 
-func (s *Speedometer) Write(p []byte) (n int, err error) { +var ( + ErrWriteOnly = errors.New("not a reader") + ErrReadOnly = errors.New("not a writer") +) + +type ioType int + +const ( + ioWriter ioType = iota + ioReader +) + +type actor func(p []byte) (n int, err error) + +func (s *Speedometer) chkIOType(t ioType) (actor, error) { if s.internal.closed.Load() { - return 0, io.ErrClosedPipe + return nil, io.ErrClosedPipe + } + + switch t { + case ioWriter: + if s.w == nil { + return nil, ErrReadOnly + } + return s.w.Write, nil + case ioReader: + if s.r == nil { + return nil, ErrWriteOnly + } + return s.r.Read, nil + default: + panic("invalid ioType") } + +} + +func (s *Speedometer) do(t ioType, p []byte) (n int, err error) { + var ioActor actor + if ioActor, err = s.chkIOType(t); err != nil { + return 0, err + } + s.internal.start.Do(func() { now := time.Now() s.internal.birth.Store(&now) @@ -217,7 +290,7 @@ func (s *Speedometer) Write(p []byte) (n int, err error) { // if no speed limit, just write and record if s.speedLimit == nil { - n, err = s.w.Write(p) + n, err = ioActor(p) if err != nil { return n, fmt.Errorf("error writing to underlying writer: %w", err) } @@ -237,8 +310,17 @@ func (s *Speedometer) Write(p []byte) (n int, err error) { _ = s.slowDown() var iErr error - if n, iErr = s.w.Write(p[:accepted]); iErr != nil { + if n, iErr = ioActor(p[:accepted]); iErr != nil { return n, fmt.Errorf("error writing to underlying writer: %w", iErr) } return } + +// Write writes p to the underlying writer, following all defined speed limits. +func (s *Speedometer) Write(p []byte) (n int, err error) { + return s.do(ioWriter, p) +} + +func (s *Speedometer) Read(p []byte) (n int, err error) { + return s.do(ioReader, p) +} diff --git a/speedometer/speedometer_test.go b/speedometer/speedometer_test.go index d0aeadf..5dfb022 100644 --- a/speedometer/speedometer_test.go +++ b/speedometer/speedometer_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "strings" "sync" "sync/atomic" "testing" @@ -22,6 +23,11 @@ func (w *testWriter) Write(p []byte) (n int, err error) { return len(p), nil } +func (w *testWriter) Read(p []byte) (n int, err error) { + atomic.AddInt64(&w.total, int64(len(p))) + return len(p), nil +} + func writeStuff(t *testing.T, target io.Writer, count int) error { t.Helper() write := func() error { @@ -47,37 +53,115 @@ func writeStuff(t *testing.T, target io.Writer, count int) error { return nil } -//nolint:funlen -func Test_Speedometer(t *testing.T) { - t.Parallel() - type results struct { - total int64 - written int - rate float64 - err error +func readStuff(t *testing.T, target io.Reader, count int) error { + t.Helper() + read := func() error { + _, err := target.Read(make([]byte, 1)) + if err != nil { + return fmt.Errorf("error reading: %w", err) + } + return nil } - isIt := func(want, have results) { - t.Helper() - if have.total != want.total { - t.Errorf("total: want %d, have %d", want.total, have.total) - } - if have.written != want.written { - t.Errorf("written: want %d, have %d", want.written, have.written) - } - if have.rate != want.rate { - t.Errorf("rate: want %f, have %f", want.rate, have.rate) + if count < 0 { + var err error + for err = read(); err == nil; err = read() { + time.Sleep(5 * time.Millisecond) } - if !errors.Is(have.err, want.err) { - t.Errorf("wantErr: want %v, have %v", want.err, have.err) + return err + } + for i := 0; i < count; i++ { + if err := read(); err != nil { + return err } } + return nil +} + +type results struct { + total int64 + written int + 
rate float64
+	err error
+}
+
+func isIt(want, have results, t *testing.T) {
+	t.Helper()
+	if have.total != want.total {
+		t.Errorf("total: want %d, have %d", want.total, have.total)
+	}
+	if have.written != want.written {
+		t.Errorf("written: want %d, have %d", want.written, have.written)
+	}
+	if have.rate != want.rate {
+		t.Errorf("rate: want %f, have %f", want.rate, have.rate)
+	}
+	if !errors.Is(have.err, want.err) {
+		t.Errorf("wantErr: want %v, have %v", want.err, have.err)
+	}
+}
+
+//nolint:funlen
+func Test_Speedometer(t *testing.T) {
+	t.Parallel()
 	var (
 		errChan = make(chan error, 10)
 	)
 	t.Run("EarlyClose", func(t *testing.T) {
+		t.Parallel()
+		t.Run("Write", func(t *testing.T) {
+			var (
+				err error
+				cnt int
+			)
+			t.Parallel()
+			sp, nerr := NewSpeedometer(&testWriter{t: t})
+			if nerr != nil {
+				t.Errorf("unexpected error: %v", nerr)
+			}
+			go func() {
+				errChan <- writeStuff(t, sp, -1)
+			}()
+			time.Sleep(1 * time.Second)
+			if closeErr := sp.Close(); closeErr != nil {
+				t.Errorf("wantErr: want %v, have %v", nil, closeErr)
+			}
+			err = <-errChan
+			if !errors.Is(err, io.ErrClosedPipe) {
+				t.Errorf("wantErr: want %v, have %v", io.ErrClosedPipe, err)
+			}
+			cnt, err = sp.Write([]byte("a"))
+			isIt(results{err: io.ErrClosedPipe, written: 0}, results{err: err, written: cnt}, t)
+		})
+		t.Run("Read", func(t *testing.T) {
+			var (
+				err error
+				cnt int
+			)
+			t.Parallel()
+			sp, nerr := NewSpeedometer(&testWriter{t: t})
+			if nerr != nil {
+				t.Errorf("unexpected error: %v", nerr)
+			}
+			go func() {
+				errChan <- readStuff(t, sp, -1)
+			}()
+			time.Sleep(1 * time.Second)
+			if closeErr := sp.Close(); closeErr != nil {
+				t.Errorf("wantErr: want %v, have %v", nil, closeErr)
+			}
+			err = <-errChan
+			if !errors.Is(err, io.ErrClosedPipe) {
+				t.Errorf("wantErr: want %v, have %v", io.ErrClosedPipe, err)
+			}
+			cnt, err = sp.Read(make([]byte, 1))
+			isIt(results{err: io.ErrClosedPipe, written: 0}, results{err: err, written: cnt}, t)
+		})
+	})
+
+	t.Run("EarlyCloseReader", func(t *testing.T) {
 		var (
 			err error
 			cnt int
@@ -88,7 +172,7 @@ func Test_Speedometer(t *testing.T) {
 			t.Errorf("unexpected error: %v", nerr)
 		}
 		go func() {
-			errChan <- writeStuff(t, sp, -1)
+			errChan <- readStuff(t, sp, -1)
 		}()
 		time.Sleep(1 * time.Second)
 		if closeErr := sp.Close(); closeErr != nil {
@@ -98,8 +182,8 @@ func Test_Speedometer(t *testing.T) {
 		if !errors.Is(err, io.ErrClosedPipe) {
 			t.Errorf("wantErr: want %v, have %v", io.ErrClosedPipe, err)
 		}
-		cnt, err = sp.Write([]byte("a"))
-		isIt(results{err: io.ErrClosedPipe, written: 0}, results{err: err, written: cnt})
+		cnt, err = sp.Read(make([]byte, 1))
+		isIt(results{err: io.ErrClosedPipe, written: 0}, results{err: err, written: cnt}, t)
 	})

 	t.Run("Basic", func(t *testing.T) {
@@ -113,13 +197,13 @@ func Test_Speedometer(t *testing.T) {
 			t.Errorf("unexpected error: %v", nerr)
 		}
 		cnt, err = sp.Write([]byte("a"))
-		isIt(results{err: nil, written: 1, total: 1}, results{err: err, written: cnt, total: sp.Total()})
+		isIt(results{err: nil, written: 1, total: 1}, results{err: err, written: cnt, total: sp.Total()}, t)
 		cnt, err = sp.Write([]byte("aa"))
-		isIt(results{err: nil, written: 2, total: 3}, results{err: err, written: cnt, total: sp.Total()})
+		isIt(results{err: nil, written: 2, total: 3}, results{err: err, written: cnt, total: sp.Total()}, t)
 		cnt, err = sp.Write([]byte("a"))
-		isIt(results{err: nil, written: 1, total: 4}, results{err: err, written: cnt, total: sp.Total()})
+		isIt(results{err: nil, written: 1, total: 4}, results{err: err, written: cnt, total: sp.Total()}, t)
 		cnt, err = sp.Write([]byte("a"))
-		isIt(results{err: nil, written: 1, total: 5}, results{err: err, written: cnt, total: sp.Total()})
+		isIt(results{err: nil, written: 1, total: 5}, results{err: err, written: cnt, total: sp.Total()}, t)
 	})

 	t.Run("ConcurrentWrites", func(t *testing.T) {
@@ -148,7 +232,7 @@ func Test_Speedometer(t *testing.T) {
 		}
 		wg.Wait()
 		isIt(results{err: nil, written: 100, total: 100},
-			results{err: err, written: int(atomic.LoadInt64(&count)), total: sp.Total()})
+			results{err: err, written: int(atomic.LoadInt64(&count)), total: sp.Total()}, t)
 	})

 	t.Run("GottaGoFast", func(t *testing.T) {
@@ -294,7 +378,7 @@ func Test_Speedometer(t *testing.T) {
 		}
 		defer func(server net.Listener) {
 			if cErr := server.Close(); cErr != nil {
-				t.Errorf("Failed to close server: %v", err)
+				t.Errorf("failed to close server: %v", cErr)
 			}
 		}(server)

@@ -304,14 +388,15 @@ func Test_Speedometer(t *testing.T) {
 				aErr error
 			)
 			if conn, aErr = server.Accept(); aErr != nil {
-				t.Errorf("Failed to accept connection: %v", err)
+				t.Errorf("failed to accept connection: %v", aErr)
 			}

-			t.Logf("Accepted connection from %s", conn.RemoteAddr().String())
+			t.Logf("accepted connection from %s", conn.RemoteAddr().String())

 			defer func(conn net.Conn) {
-				if cErr := conn.Close(); cErr != nil {
-					t.Errorf("Failed to close connection: %v", err)
+				if cErr := conn.Close(); cErr != nil &&
+					!strings.Contains(cErr.Error(), "use of closed network connection") {
+					t.Errorf("failed to close connection: %v", cErr)
 				}
 			}(conn)

@@ -327,7 +412,7 @@ func Test_Speedometer(t *testing.T) {
 				sErr error
 			)
 			if speedometer, sErr = NewCappedLimitedSpeedometer(conn, speedLimit, 4096); sErr != nil {
-				t.Errorf("Failed to create speedometer: %v", sErr)
+				t.Errorf("failed to create speedometer: %v", sErr)
 			}

 			buf := make([]byte, 1024)
@@ -344,11 +429,11 @@ func Test_Speedometer(t *testing.T) {
 				case errors.Is(wErr, io.EOF), errors.Is(wErr, ErrLimitReached):
 					return
 				case wErr != nil:
-					t.Errorf("Failed to write: %v", wErr)
+					t.Errorf("failed to write: %v", wErr)
 				case n != len(buf):
-					t.Errorf("Failed to write all bytes: %d", n)
+					t.Errorf("failed to write all bytes: %d", n)
 				default:
-					t.Logf("Wrote %d bytes (rate: %v/bps)", n, speedometer.Rate())
+					t.Logf("wrote %d bytes (rate: %v/bps)", n, speedometer.Rate())
 				}
 			}
 		}()
@@ -359,12 +444,12 @@ func Test_Speedometer(t *testing.T) {
 		)

 		if client, aErr = net.Dial("tcp", "localhost:8080"); aErr != nil {
-			t.Fatalf("Failed to connect to server: %v", err)
+			t.Fatalf("failed to connect to server: %v", aErr)
 		}

 		defer func(client net.Conn) {
 			if clErr := client.Close(); clErr != nil {
-				t.Errorf("Failed to close client: %v", err)
+				t.Errorf("failed to close client: %v", clErr)
 			}
 		}(client)

@@ -372,18 +457,18 @@ func Test_Speedometer(t *testing.T) {
 		startTime := time.Now()
 		n, cpErr := io.Copy(buf, client)
 		if cpErr != nil {
-			t.Errorf("Failed to copy: %v", cpErr)
+			t.Errorf("failed to copy: %v", cpErr)
 		}

 		duration := time.Since(startTime)

 		if buf.Len() == 0 || n == 0 {
-			t.Fatalf("No data received")
+			t.Fatalf("no data received")
 		}

 		rate := measureRate(t, n, duration)

 		if rate > 512.0 {
-			t.Fatalf("Rate exceeded: got %f, expected <= 100.0", rate)
+			t.Fatalf("rate exceeded: got %f, expected <= 512.0", rate)
 		}
 	})
 }
@@ -395,6 +480,8 @@ func (bw badWrites) Write(_ []byte) (int, error) {
 	return 0, io.EOF
 }

+func (bw badWrites) Close() error { return io.ErrNoProgress }
+
 func TestImprobableEdgeCasesForCoverage(t
*testing.T) { if _, e := sp.Write([]byte("yeet")); !errors.Is(e, io.EOF) { t.Errorf("wrong error from underlying writer err passdown: %v", e) } - if e := sp.Close(); e != nil { - t.Fatal("close err not nil") + if e := sp.Close(); e == nil { + t.Fatal("should have received error when closing with bad writer") } if e := sp.Close(); !errors.Is(e, io.ErrClosedPipe) { t.Errorf("wrong error from already closed speedo: %v", e) @@ -443,7 +530,90 @@ func TestImprobableEdgeCasesForCoverage(t *testing.T) { if sp.speedLimit.Delay != time.Duration(100)*time.Millisecond { t.Fatal("speed limit regularization failed") } +} + +type writeCloser struct{} + +func (wc writeCloser) Write(p []byte) (int, error) { + return len(p), nil +} + +func (wc writeCloser) Close() error { + return nil +} + +type readCloser struct{} + +func (rc readCloser) Read(p []byte) (int, error) { + return len(p), nil +} + +func (rc readCloser) Close() error { + return nil +} + +type reader struct{} + +func (r reader) Read(p []byte) (int, error) { + return len(p), nil +} + +func TestMiscellaneousBehaviorForCoverage(t *testing.T) { + sp, err := NewSpeedometer(writeCloser{}) + if err != nil || sp == nil { + t.Fatal("unexpected error") + } + if act, actErr := sp.chkIOType(ioReader); act != nil || actErr == nil { + t.Fatal("should have received error when checking for reader on writecloser") + } + if sp.w == nil { + t.Fatal("unexpected nil writer") + } + if sp.r != nil { + t.Fatal("unexpected reader") + } + if sp.c == nil { + t.Fatal("unexpected nil closer") + } + sp, err = NewReadingSpeedometer(reader{}) + if err != nil || sp == nil { + t.Fatal("unexpected error") + } + if act, actErr := sp.chkIOType(ioWriter); act != nil || actErr == nil { + t.Fatal("should have received error when checking for writer on readcloser") + } + if sp.w != nil { + t.Fatal("unexpected writer") + } + if sp.r == nil { + t.Fatal("unexpected nil reader") + } + if sp.c != nil { + t.Fatal("unexpected closer") + } + sp, err = NewReadingSpeedometer(readCloser{}) + if err != nil || sp == nil { + t.Fatal("unexpected error") + } + if sp.w != nil { + t.Fatal("unexpected writer") + } + if sp.r == nil { + t.Fatal("unexpected nil reader") + } + if sp.c == nil { + t.Fatal("unexpected nil closer") + } +} +func TestMustPanic(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic") + } + }() + sp, _ := NewSpeedometer(writeCloser{}) + _, _ = sp.chkIOType(ioType(55)) } func measureRate(t *testing.T, received int64, duration time.Duration) float64 {