From 9b199a2675080c149ebed30663cbce55acab7cb6 Mon Sep 17 00:00:00 2001 From: joerger Date: Mon, 27 Nov 2023 18:09:20 -0800 Subject: [PATCH] Fix lint and race condition in test. --- event-handler/cli.go | 3 +- event-handler/teleport_event_test.go | 4 +- event-handler/teleport_events_watcher.go | 5 +- event-handler/teleport_events_watcher_test.go | 64 ++++++++++++------- 4 files changed, 46 insertions(+), 30 deletions(-) diff --git a/event-handler/cli.go b/event-handler/cli.go index fa5557643..2ee3871b8 100644 --- a/event-handler/cli.go +++ b/event-handler/cli.go @@ -22,10 +22,9 @@ import ( "time" "github.com/alecthomas/kong" - "github.com/gravitational/trace" - "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/stringset" + "github.com/gravitational/trace" "github.com/gravitational/teleport-plugins/event-handler/lib" ) diff --git a/event-handler/teleport_event_test.go b/event-handler/teleport_event_test.go index bfe768e8d..085eca442 100644 --- a/event-handler/teleport_event_test.go +++ b/event-handler/teleport_event_test.go @@ -21,6 +21,8 @@ import ( "encoding/hex" "testing" + auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1" + "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/trace" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -28,8 +30,6 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/gravitational/teleport-plugins/event-handler/lib" - auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1" - "github.com/gravitational/teleport/api/types/events" ) func TestNew(t *testing.T) { diff --git a/event-handler/teleport_events_watcher.go b/event-handler/teleport_events_watcher.go index 12972a808..0c1444632 100644 --- a/event-handler/teleport_events_watcher.go +++ b/event-handler/teleport_events_watcher.go @@ -21,9 +21,6 @@ import ( "fmt" "time" - "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" - "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1" @@ -32,6 +29,8 @@ import ( "github.com/gravitational/teleport/integrations/lib" "github.com/gravitational/teleport/integrations/lib/credentials" "github.com/gravitational/teleport/integrations/lib/logger" + "github.com/gravitational/trace" + log "github.com/sirupsen/logrus" ) const ( diff --git a/event-handler/teleport_events_watcher_test.go b/event-handler/teleport_events_watcher_test.go index 8ab2f5f28..b3242ac7c 100644 --- a/event-handler/teleport_events_watcher_test.go +++ b/event-handler/teleport_events_watcher_test.go @@ -18,28 +18,46 @@ package main import ( "strconv" + "sync" "testing" "time" - "github.com/gravitational/trace" - "github.com/stretchr/testify/require" - "golang.org/x/net/context" - "github.com/gravitational/teleport/api/client/proto" auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + "golang.org/x/net/context" ) // mockTeleportEventWatcher is Teleport client mock type mockTeleportEventWatcher struct { + mu sync.Mutex // events is the mock list of events events []events.AuditEvent // mockSearchErr is an error to return mockSearchErr error } +func (c *mockTeleportEventWatcher) setEvents(events []events.AuditEvent) { + c.mu.Lock() + defer c.mu.Unlock() + + c.events = events +} + +func (c *mockTeleportEventWatcher) setSearchEventsError(err error) { + c.mu.Lock() + defer c.mu.Unlock() + + c.mockSearchErr = err +} + func (c *mockTeleportEventWatcher) SearchEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]events.AuditEvent, string, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.mockSearchErr != nil { return nil, "", c.mockSearchErr } @@ -151,7 +169,7 @@ func TestEvents(t *testing.T) { case err := <-chErr: t.Fatalf("Received unexpected error from error channel: %v", err) return - case <-time.After(time.Millisecond): + case <-time.After(100 * time.Millisecond): t.Fatalf("No events received within deadline") } } @@ -160,25 +178,25 @@ func TestEvents(t *testing.T) { select { case _, ok := <-chEvt: require.False(t, ok, "Events channel should be closed") - case <-time.After(time.Millisecond): + case <-time.After(100 * time.Millisecond): t.Fatalf("No events received within deadline") } select { case _, ok := <-chErr: require.False(t, ok, "Error channel should be closed") - case <-time.After(time.Millisecond): + case <-time.After(100 * time.Millisecond): t.Fatalf("No events received within deadline") } // Events goroutine should return next page errors mockErr := trace.Errorf("error") - mockEventWatcher.mockSearchErr = mockErr + mockEventWatcher.setSearchEventsError(mockErr) select { case err := <-chErr: require.Error(t, mockErr, err) - case <-time.After(time.Millisecond): + case <-time.After(100 * time.Millisecond): t.Fatalf("No events received within deadline") } @@ -186,14 +204,14 @@ func TestEvents(t *testing.T) { select { case _, ok := <-chEvt: require.False(t, ok, "Events channel should be closed") - case <-time.After(time.Millisecond): + case <-time.After(100 * time.Millisecond): t.Fatalf("No events received within deadline") } select { case _, ok := <-chErr: require.False(t, ok, "Error channel should be closed") - case <-time.After(time.Millisecond): + case <-time.After(100 * time.Millisecond): t.Fatalf("No events received within deadline") } } @@ -221,7 +239,7 @@ func TestUpdatePage(t *testing.T) { chEvt, chErr := client.Events(ctx) // Add an incomplete page of 3 events and collect them. - mockEventWatcher.events = testAuditEvents[:3] + mockEventWatcher.setEvents(testAuditEvents[:3]) var i int for ; i < 3; i++ { select { @@ -234,7 +252,7 @@ func TestUpdatePage(t *testing.T) { case err := <-chErr: t.Fatalf("Received unexpected error from error channel: %v", err) return - case <-time.After(time.Millisecond): + case <-time.After(100 * time.Millisecond): t.Fatalf("No events received within deadline") } } @@ -245,11 +263,11 @@ func TestUpdatePage(t *testing.T) { t.Fatalf("Events channel should be open") case <-chErr: t.Fatalf("Events channel should be open") - case <-time.After(time.Millisecond): + case <-time.After(100 * time.Millisecond): } // Update the event watcher with the full page of events an collect. - mockEventWatcher.events = testAuditEvents[:5] + mockEventWatcher.setEvents(testAuditEvents[:5]) for ; i < 5; i++ { select { case event, ok := <-chEvt: @@ -261,7 +279,7 @@ func TestUpdatePage(t *testing.T) { case err := <-chErr: t.Fatalf("Received unexpected error from error channel: %v", err) return - case <-time.After(time.Millisecond): + case <-time.After(100 * time.Millisecond): t.Fatalf("No events received within deadline") } } @@ -272,11 +290,11 @@ func TestUpdatePage(t *testing.T) { t.Fatalf("Events channel should be open") case <-chErr: t.Fatalf("Events channel should be open") - case <-time.After(time.Millisecond): + case <-time.After(100 * time.Millisecond): } // Add another partial page and collect the events - mockEventWatcher.events = testAuditEvents[:7] + mockEventWatcher.setEvents(testAuditEvents[:7]) for ; i < 7; i++ { select { case event, ok := <-chEvt: @@ -288,19 +306,19 @@ func TestUpdatePage(t *testing.T) { case err := <-chErr: t.Fatalf("Received unexpected error from error channel: %v", err) return - case <-time.After(time.Millisecond): + case <-time.After(100 * time.Millisecond): t.Fatalf("No events received within deadline") } } // Events goroutine should return update page errors mockErr := trace.Errorf("error") - mockEventWatcher.mockSearchErr = mockErr + mockEventWatcher.setSearchEventsError(mockErr) select { case err := <-chErr: require.Error(t, mockErr, err) - case <-time.After(time.Millisecond): + case <-time.After(100 * time.Millisecond): t.Fatalf("No events received within deadline") } @@ -308,14 +326,14 @@ func TestUpdatePage(t *testing.T) { select { case _, ok := <-chEvt: require.False(t, ok, "Events channel should be closed") - case <-time.After(time.Millisecond): + case <-time.After(100 * time.Millisecond): t.Fatalf("No events received within deadline") } select { case _, ok := <-chErr: require.False(t, ok, "Error channel should be closed") - case <-time.After(time.Millisecond): + case <-time.After(100 * time.Millisecond): t.Fatalf("No events received within deadline") } }