diff --git a/server/api.go b/server/api.go index d8ca18316d..c0eacb8f80 100644 --- a/server/api.go +++ b/server/api.go @@ -85,7 +85,7 @@ func (d dexAPI) CreateClient(ctx context.Context, req *api.CreateClientReq) (*ap Name: req.Client.Name, LogoURL: req.Client.LogoUrl, } - if err := d.s.CreateClient(c); err != nil { + if err := d.s.CreateClient(ctx, c); err != nil { if err == storage.ErrAlreadyExists { return &api.CreateClientResp{AlreadyExists: true}, nil } @@ -177,7 +177,7 @@ func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq) Username: req.Password.Username, UserID: req.Password.UserId, } - if err := d.s.CreatePassword(p); err != nil { + if err := d.s.CreatePassword(ctx, p); err != nil { if err == storage.ErrAlreadyExists { return &api.CreatePasswordResp{AlreadyExists: true}, nil } diff --git a/server/api_test.go b/server/api_test.go index bc0dcf1128..b023beb558 100644 --- a/server/api_test.go +++ b/server/api_test.go @@ -262,7 +262,7 @@ func TestRefreshToken(t *testing.T) { ConnectorData: []byte(`{"some":"data"}`), } - if err := s.CreateRefresh(r); err != nil { + if err := s.CreateRefresh(ctx, r); err != nil { t.Fatalf("create refresh token: %v", err) } @@ -280,7 +280,7 @@ func TestRefreshToken(t *testing.T) { } session.Refresh[tokenRef.ClientID] = &tokenRef - if err := s.CreateOfflineSessions(session); err != nil { + if err := s.CreateOfflineSessions(ctx, session); err != nil { t.Fatalf("create offline session: %v", err) } diff --git a/server/deviceflowhandlers.go b/server/deviceflowhandlers.go index 95fed3b3c3..5683e9441a 100644 --- a/server/deviceflowhandlers.go +++ b/server/deviceflowhandlers.go @@ -58,6 +58,7 @@ func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() pollIntervalSeconds := 5 switch r.Method { @@ -106,7 +107,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { Expiry: expireTime, } - if err := s.storage.CreateDeviceRequest(deviceReq); err != nil { + if err := s.storage.CreateDeviceRequest(ctx, deviceReq); err != nil { s.logger.Errorf("Failed to store device request; %v", err) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) return @@ -125,7 +126,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { }, } - if err := s.storage.CreateDeviceToken(deviceToken); err != nil { + if err := s.storage.CreateDeviceToken(ctx, deviceToken); err != nil { s.logger.Errorf("Failed to store device token %v", err) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) return @@ -280,6 +281,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() switch r.Method { case http.MethodGet: userCode := r.FormValue("state") @@ -336,7 +338,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { return } - resp, err := s.exchangeAuthCode(w, authCode, client) + resp, err := s.exchangeAuthCode(ctx, w, authCode, client) if err != nil { s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err) s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.") diff --git a/server/deviceflowhandlers_test.go b/server/deviceflowhandlers_test.go index 9abe4a6229..151c75082d 100644 --- a/server/deviceflowhandlers_test.go +++ b/server/deviceflowhandlers_test.go @@ -366,15 +366,15 @@ func TestDeviceCallback(t *testing.T) { }) defer httpServer.Close() - if err := s.storage.CreateAuthCode(tc.testAuthCode); err != nil { + if err := s.storage.CreateAuthCode(ctx, tc.testAuthCode); err != nil { t.Fatalf("failed to create auth code: %v", err) } - if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil { + if err := s.storage.CreateDeviceRequest(ctx, tc.testDeviceRequest); err != nil { t.Fatalf("failed to create device request: %v", err) } - if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil { + if err := s.storage.CreateDeviceToken(ctx, tc.testDeviceToken); err != nil { t.Fatalf("failed to create device token: %v", err) } @@ -383,7 +383,7 @@ func TestDeviceCallback(t *testing.T) { Secret: "", RedirectURIs: []string{deviceCallbackURI}, } - if err := s.storage.CreateClient(client); err != nil { + if err := s.storage.CreateClient(ctx, client); err != nil { t.Fatalf("failed to create client: %v", err) } @@ -660,11 +660,11 @@ func TestDeviceTokenResponse(t *testing.T) { }) defer httpServer.Close() - if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil { + if err := s.storage.CreateDeviceRequest(ctx, tc.testDeviceRequest); err != nil { t.Fatalf("Failed to store device token %v", err) } - if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil { + if err := s.storage.CreateDeviceToken(ctx, tc.testDeviceToken); err != nil { t.Fatalf("Failed to store device token %v", err) } @@ -794,7 +794,7 @@ func TestVerifyCodeResponse(t *testing.T) { }) defer httpServer.Close() - if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil { + if err := s.storage.CreateDeviceRequest(ctx, tc.testDeviceRequest); err != nil { t.Fatalf("Failed to store device token %v", err) } diff --git a/server/handlers.go b/server/handlers.go index 08a60d48da..58366a8848 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -1,6 +1,7 @@ package server import ( + "context" "crypto/hmac" "crypto/sha256" "crypto/subtle" @@ -187,6 +188,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() authReq, err := s.parseAuthorizationRequest(r) if err != nil { s.logger.Errorf("Failed to parse authorization request: %v", err) @@ -229,7 +231,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { // Actually create the auth request authReq.Expiry = s.now().Add(s.authRequestsValidFor) - if err := s.storage.CreateAuthRequest(*authReq); err != nil { + if err := s.storage.CreateAuthRequest(ctx, *authReq); err != nil { s.logger.Errorf("Failed to create authorization request: %v", err) s.renderError(r, w, http.StatusInternalServerError, "Failed to connect to the database.") return @@ -305,6 +307,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() authID := r.URL.Query().Get("state") if authID == "" { s.renderError(r, w, http.StatusBadRequest, "User session error.") @@ -360,7 +363,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { password := r.FormValue("password") scopes := parseScopes(authReq.Scopes) - identity, ok, err := pwConn.Login(r.Context(), scopes, username, password) + identity, ok, err := pwConn.Login(ctx, scopes, username, password) if err != nil { s.logger.Errorf("Failed to login user: %v", err) s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err)) @@ -372,7 +375,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { } return } - redirectURL, canSkipApproval, err := s.finalizeLogin(identity, authReq, conn.Connector) + redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, identity, authReq, conn.Connector) if err != nil { s.logger.Errorf("Failed to finalize login: %v", err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") @@ -397,6 +400,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() var authID string switch r.Method { case http.MethodGet: // OAuth2 callback @@ -471,7 +475,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) return } - redirectURL, canSkipApproval, err := s.finalizeLogin(identity, authReq, conn.Connector) + redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, identity, authReq, conn.Connector) if err != nil { s.logger.Errorf("Failed to finalize login: %v", err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") @@ -494,7 +498,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) // finalizeLogin associates the user's identity with the current AuthRequest, then returns // the approval page's path. -func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, bool, error) { +func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, bool, error) { claims := storage.Claims{ UserID: identity.UserID, Username: identity.Username, @@ -566,7 +570,7 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.Auth // Create a new OfflineSession object for the user and add a reference object for // the newly received refreshtoken. - if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil { + if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil { s.logger.Errorf("failed to create offline session: %v", err) return "", false, err } @@ -649,6 +653,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { } func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) { + ctx := r.Context() if s.now().After(authReq.Expiry) { s.renderError(r, w, http.StatusBadRequest, "User session has expired.") return @@ -701,7 +706,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe ConnectorData: authReq.ConnectorData, PKCE: authReq.PKCE, } - if err := s.storage.CreateAuthCode(code); err != nil { + if err := s.storage.CreateAuthCode(ctx, code); err != nil { s.logger.Errorf("Failed to create auth code: %v", err) s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") return @@ -876,6 +881,7 @@ func (s *Server) calculateCodeChallenge(codeVerifier, codeChallengeMethod string // handle an access token request https://tools.ietf.org/html/rfc6749#section-4.1.3 func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) { + ctx := r.Context() code := r.PostFormValue("code") redirectURI := r.PostFormValue("redirect_uri") @@ -926,7 +932,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s return } - tokenResponse, err := s.exchangeAuthCode(w, authCode, client) + tokenResponse, err := s.exchangeAuthCode(ctx, w, authCode, client) if err != nil { s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return @@ -934,7 +940,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s s.writeAccessToken(w, tokenResponse) } -func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) { +func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) { accessToken, _, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID) if err != nil { s.logger.Errorf("failed to create new access token: %v", err) @@ -1002,7 +1008,7 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo return nil, err } - if err := s.storage.CreateRefresh(refresh); err != nil { + if err := s.storage.CreateRefresh(ctx, refresh); err != nil { s.logger.Errorf("failed to create refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return nil, err @@ -1047,7 +1053,7 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo // Create a new OfflineSession object for the user and add a reference object for // the newly received refreshtoken. - if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil { + if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil { s.logger.Errorf("failed to create offline session: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true @@ -1080,6 +1086,7 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo } func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() const prefix = "Bearer " auth := r.Header.Get("authorization") @@ -1091,7 +1098,7 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { rawIDToken := auth[len(prefix):] verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true}) - idToken, err := verifier.Verify(r.Context(), rawIDToken) + idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusForbidden) return @@ -1108,6 +1115,7 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { } func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, client storage.Client) { + ctx := r.Context() // Parse the fields if err := r.ParseForm(); err != nil { s.tokenErrHelper(w, errInvalidRequest, "Couldn't parse data", http.StatusBadRequest) @@ -1177,7 +1185,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli // Login username := q.Get("username") password := q.Get("password") - identity, ok, err := passwordConnector.Login(r.Context(), parseScopes(scopes), username, password) + identity, ok, err := passwordConnector.Login(ctx, parseScopes(scopes), username, password) if err != nil { s.logger.Errorf("Failed to login user: %v", err) s.tokenErrHelper(w, errInvalidRequest, "Could not login user", http.StatusBadRequest) @@ -1252,7 +1260,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli return } - if err := s.storage.CreateRefresh(refresh); err != nil { + if err := s.storage.CreateRefresh(ctx, refresh); err != nil { s.logger.Errorf("failed to create refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return @@ -1298,7 +1306,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli // Create a new OfflineSession object for the user and add a reference object for // the newly received refreshtoken. - if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil { + if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil { s.logger.Errorf("failed to create offline session: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true diff --git a/server/handlers_test.go b/server/handlers_test.go index 212b25fe72..f6ada3634d 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -213,7 +213,7 @@ func TestHandleAuthCode(t *testing.T) { Secret: "testclientsecret", RedirectURIs: []string{redirectURL}, } - err = s.storage.CreateClient(client) + err = s.storage.CreateClient(ctx, client) require.NoError(t, err) oauth2Client.config = &oauth2.Config{ @@ -233,6 +233,7 @@ func TestHandleAuthCode(t *testing.T) { } func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) { + ctx := context.Background() c := storage.Client{ ID: "test", Secret: "barfoo", @@ -241,7 +242,7 @@ func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) { LogoURL: "https://goo.gl/JIyzIC", } - err := s.CreateClient(c) + err := s.CreateClient(ctx, c) require.NoError(t, err) c1 := storage.Connector{ @@ -254,7 +255,7 @@ func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) { }`), } - err = s.CreateConnector(c1) + err = s.CreateConnector(ctx, c1) require.NoError(t, err) c2 := storage.Connector{ @@ -263,7 +264,7 @@ func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) { Name: "mockURLID", } - err = s.CreateConnector(c2) + err = s.CreateConnector(ctx, c2) require.NoError(t, err) } @@ -467,13 +468,13 @@ func TestHandlePasswordLoginWithSkipApproval(t *testing.T) { ResourceVersion: "1", Config: []byte("{\"username\": \"foo\", \"password\": \"password\"}"), } - if err := s.storage.CreateConnector(sc); err != nil { + if err := s.storage.CreateConnector(ctx, sc); err != nil { t.Fatalf("create connector: %v", err) } if _, err := s.OpenConnector(sc); err != nil { t.Fatalf("open connector: %v", err) } - if err := s.storage.CreateAuthRequest(tc.authReq); err != nil { + if err := s.storage.CreateAuthRequest(ctx, tc.authReq); err != nil { t.Fatalf("failed to create AuthRequest: %v", err) } @@ -614,7 +615,7 @@ func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) { }) defer httpServer.Close() - if err := s.storage.CreateAuthRequest(tc.authReq); err != nil { + if err := s.storage.CreateAuthRequest(ctx, tc.authReq); err != nil { t.Fatalf("failed to create AuthRequest: %v", err) } rr := httptest.NewRecorder() @@ -712,7 +713,7 @@ func TestHandleTokenExchange(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() httpServer, s := newTestServer(ctx, t, func(c *Config) { - c.Storage.CreateClient(storage.Client{ + c.Storage.CreateClient(ctx, storage.Client{ ID: "client_1", Secret: "secret_1", }) diff --git a/server/refreshhandlers_test.go b/server/refreshhandlers_test.go index c64c50b330..6b0925c2bd 100644 --- a/server/refreshhandlers_test.go +++ b/server/refreshhandlers_test.go @@ -18,6 +18,7 @@ import ( ) func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bool) { + ctx := context.Background() c := storage.Client{ ID: "test", Secret: "barfoo", @@ -26,7 +27,7 @@ func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bo LogoURL: "https://goo.gl/JIyzIC", } - err := s.CreateClient(c) + err := s.CreateClient(ctx, c) require.NoError(t, err) c1 := storage.Connector{ @@ -36,7 +37,7 @@ func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bo Config: nil, } - err = s.CreateConnector(c1) + err = s.CreateConnector(ctx, c1) require.NoError(t, err) refresh := storage.RefreshToken{ @@ -64,7 +65,7 @@ func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bo refresh.ObsoleteToken = "bar" } - err = s.CreateRefresh(refresh) + err = s.CreateRefresh(ctx, refresh) require.NoError(t, err) offlineSessions := storage.OfflineSessions{ @@ -74,7 +75,7 @@ func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bo ConnectorData: nil, } - err = s.CreateOfflineSessions(offlineSessions) + err = s.CreateOfflineSessions(ctx, offlineSessions) require.NoError(t, err) } diff --git a/server/server_test.go b/server/server_test.go index dd21d737e0..f9bfa4a3ba 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -119,7 +119,7 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi Name: "Mock", ResourceVersion: "1", } - if err := config.Storage.CreateConnector(connector); err != nil { + if err := config.Storage.CreateConnector(ctx, connector); err != nil { t.Fatalf("create connector: %v", err) } @@ -172,10 +172,10 @@ func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateCo Name: "Mock", ResourceVersion: "1", } - if err := config.Storage.CreateConnector(connector); err != nil { + if err := config.Storage.CreateConnector(ctx, connector); err != nil { t.Fatalf("create connector: %v", err) } - if err := config.Storage.CreateConnector(connector2); err != nil { + if err := config.Storage.CreateConnector(ctx, connector2); err != nil { t.Fatalf("create connector: %v", err) } @@ -837,11 +837,11 @@ func TestOAuth2CodeFlow(t *testing.T) { Secret: clientSecret, RedirectURIs: []string{redirectURL}, } - if err := s.storage.CreateClient(client); err != nil { + if err := s.storage.CreateClient(ctx, client); err != nil { t.Fatalf("failed to create client: %v", err) } - if err := s.storage.CreateRefresh(storage.RefreshToken{ + if err := s.storage.CreateRefresh(ctx, storage.RefreshToken{ ID: "existedrefrestoken", ClientID: "unexcistedclientid", }); err != nil { @@ -955,7 +955,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) { Secret: "testclientsecret", RedirectURIs: []string{redirectURL}, } - if err := s.storage.CreateClient(client); err != nil { + if err := s.storage.CreateClient(ctx, client); err != nil { t.Fatalf("failed to create client: %v", err) } @@ -1113,7 +1113,7 @@ func TestCrossClientScopes(t *testing.T) { Secret: "testclientsecret", RedirectURIs: []string{redirectURL}, } - if err := s.storage.CreateClient(client); err != nil { + if err := s.storage.CreateClient(ctx, client); err != nil { t.Fatalf("failed to create client: %v", err) } @@ -1123,7 +1123,7 @@ func TestCrossClientScopes(t *testing.T) { TrustedPeers: []string{"testclient"}, } - if err := s.storage.CreateClient(peer); err != nil { + if err := s.storage.CreateClient(ctx, peer); err != nil { t.Fatalf("failed to create client: %v", err) } @@ -1236,7 +1236,7 @@ func TestCrossClientScopesWithAzpInAudienceByDefault(t *testing.T) { Secret: "testclientsecret", RedirectURIs: []string{redirectURL}, } - if err := s.storage.CreateClient(client); err != nil { + if err := s.storage.CreateClient(ctx, client); err != nil { t.Fatalf("failed to create client: %v", err) } @@ -1246,7 +1246,7 @@ func TestCrossClientScopesWithAzpInAudienceByDefault(t *testing.T) { TrustedPeers: []string{"testclient"}, } - if err := s.storage.CreateClient(peer); err != nil { + if err := s.storage.CreateClient(ctx, peer); err != nil { t.Fatalf("failed to create client: %v", err) } @@ -1276,6 +1276,7 @@ func TestCrossClientScopesWithAzpInAudienceByDefault(t *testing.T) { } func TestPasswordDB(t *testing.T) { + ctx := context.Background() s := memory.New(logger) conn := newPasswordDB(s) @@ -1286,7 +1287,7 @@ func TestPasswordDB(t *testing.T) { t.Fatal(err) } - s.CreatePassword(storage.Password{ + s.CreatePassword(ctx, storage.Password{ Email: "jane@example.com", Username: "jane", UserID: "foobar", @@ -1534,7 +1535,7 @@ func TestRefreshTokenFlow(t *testing.T) { Secret: "testclientsecret", RedirectURIs: []string{redirectURL}, } - if err := s.storage.CreateClient(client); err != nil { + if err := s.storage.CreateClient(ctx, client); err != nil { t.Fatalf("failed to create client: %v", err) } @@ -1633,11 +1634,11 @@ func TestOAuth2DeviceFlow(t *testing.T) { RedirectURIs: []string{deviceCallbackURI}, Public: true, } - if err := s.storage.CreateClient(client); err != nil { + if err := s.storage.CreateClient(ctx, client); err != nil { t.Fatalf("failed to create client: %v", err) } - if err := s.storage.CreateRefresh(storage.RefreshToken{ + if err := s.storage.CreateRefresh(ctx, storage.RefreshToken{ ID: "existedrefrestoken", ClientID: "unexcistedclientid", }); err != nil { diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 2f791cdee4..6c2cb2e476 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -2,6 +2,7 @@ package conformance import ( + "context" "reflect" "sort" "testing" @@ -80,6 +81,7 @@ func mustBeErrAlreadyExists(t *testing.T, kind string, err error) { } func testAuthRequestCRUD(t *testing.T, s storage.Storage) { + ctx := context.Background() codeChallenge := storage.PKCE{ CodeChallenge: "code_challenge_test", CodeChallengeMethod: "plain", @@ -111,12 +113,12 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { identity := storage.Claims{Email: "foobar"} - if err := s.CreateAuthRequest(a1); err != nil { + if err := s.CreateAuthRequest(ctx, a1); err != nil { t.Fatalf("failed creating auth request: %v", err) } // Attempt to create same AuthRequest twice. - err := s.CreateAuthRequest(a1) + err := s.CreateAuthRequest(ctx, a1) mustBeErrAlreadyExists(t, "auth request", err) a2 := storage.AuthRequest{ @@ -142,7 +144,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { HMACKey: []byte("hmac_key"), } - if err := s.CreateAuthRequest(a2); err != nil { + if err := s.CreateAuthRequest(ctx, a2); err != nil { t.Fatalf("failed creating auth request: %v", err) } @@ -179,6 +181,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { } func testAuthCodeCRUD(t *testing.T, s storage.Storage) { + ctx := context.Background() a1 := storage.AuthCode{ ID: storage.NewID(), ClientID: "client1", @@ -201,7 +204,7 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) { }, } - if err := s.CreateAuthCode(a1); err != nil { + if err := s.CreateAuthCode(ctx, a1); err != nil { t.Fatalf("failed creating auth code: %v", err) } @@ -224,10 +227,10 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) { } // Attempt to create same AuthCode twice. - err := s.CreateAuthCode(a1) + err := s.CreateAuthCode(ctx, a1) mustBeErrAlreadyExists(t, "auth code", err) - if err := s.CreateAuthCode(a2); err != nil { + if err := s.CreateAuthCode(ctx, a2); err != nil { t.Fatalf("failed creating auth code: %v", err) } @@ -256,6 +259,7 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) { } func testClientCRUD(t *testing.T, s storage.Storage) { + ctx := context.Background() id1 := storage.NewID() c1 := storage.Client{ ID: id1, @@ -267,12 +271,12 @@ func testClientCRUD(t *testing.T, s storage.Storage) { err := s.DeleteClient(id1) mustBeErrNotFound(t, "client", err) - if err := s.CreateClient(c1); err != nil { + if err := s.CreateClient(ctx, c1); err != nil { t.Fatalf("create client: %v", err) } // Attempt to create same Client twice. - err = s.CreateClient(c1) + err = s.CreateClient(ctx, c1) mustBeErrAlreadyExists(t, "client", err) id2 := storage.NewID() @@ -284,7 +288,7 @@ func testClientCRUD(t *testing.T, s storage.Storage) { LogoURL: "https://goo.gl/JIyzIC", } - if err := s.CreateClient(c2); err != nil { + if err := s.CreateClient(ctx, c2); err != nil { t.Fatalf("create client: %v", err) } @@ -325,6 +329,7 @@ func testClientCRUD(t *testing.T, s storage.Storage) { } func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { + ctx := context.Background() id := storage.NewID() refresh := storage.RefreshToken{ ID: id, @@ -345,12 +350,12 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { }, ConnectorData: []byte(`{"some":"data"}`), } - if err := s.CreateRefresh(refresh); err != nil { + if err := s.CreateRefresh(ctx, refresh); err != nil { t.Fatalf("create refresh token: %v", err) } // Attempt to create same Refresh Token twice. - err := s.CreateRefresh(refresh) + err := s.CreateRefresh(ctx, refresh) mustBeErrAlreadyExists(t, "refresh token", err) getAndCompare := func(id string, want storage.RefreshToken) { @@ -401,7 +406,7 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { ConnectorData: []byte(`{"some":"data"}`), } - if err := s.CreateRefresh(refresh2); err != nil { + if err := s.CreateRefresh(ctx, refresh2); err != nil { t.Fatalf("create second refresh token: %v", err) } @@ -443,6 +448,7 @@ func (n byEmail) Less(i, j int) bool { return n[i].Email < n[j].Email } func (n byEmail) Swap(i, j int) { n[i], n[j] = n[j], n[i] } func testPasswordCRUD(t *testing.T, s storage.Storage) { + ctx := context.Background() // Use bcrypt.MinCost to keep the tests short. passwordHash1, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost) if err != nil { @@ -455,12 +461,12 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { Username: "jane", UserID: "foobar", } - if err := s.CreatePassword(password1); err != nil { + if err := s.CreatePassword(ctx, password1); err != nil { t.Fatalf("create password token: %v", err) } // Attempt to create same Password twice. - err = s.CreatePassword(password1) + err = s.CreatePassword(ctx, password1) mustBeErrAlreadyExists(t, "password", err) passwordHash2, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.MinCost) @@ -474,7 +480,7 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { Username: "john", UserID: "barfoo", } - if err := s.CreatePassword(password2); err != nil { + if err := s.CreatePassword(ctx, password2); err != nil { t.Fatalf("create password token: %v", err) } @@ -533,6 +539,7 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { } func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { + ctx := context.Background() userID1 := storage.NewID() session1 := storage.OfflineSessions{ UserID: userID1, @@ -543,12 +550,12 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { // Creating an OfflineSession with an empty Refresh list to ensure that // an empty map is translated as expected by the storage. - if err := s.CreateOfflineSessions(session1); err != nil { + if err := s.CreateOfflineSessions(ctx, session1); err != nil { t.Fatalf("create offline session with UserID = %s: %v", session1.UserID, err) } // Attempt to create same OfflineSession twice. - err := s.CreateOfflineSessions(session1) + err := s.CreateOfflineSessions(ctx, session1) mustBeErrAlreadyExists(t, "offline session", err) userID2 := storage.NewID() @@ -559,7 +566,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { ConnectorData: []byte(`{"some":"data"}`), } - if err := s.CreateOfflineSessions(session2); err != nil { + if err := s.CreateOfflineSessions(ctx, session2); err != nil { t.Fatalf("create offline session with UserID = %s: %v", session2.UserID, err) } @@ -607,6 +614,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { } func testConnectorCRUD(t *testing.T, s storage.Storage) { + ctx := context.Background() id1 := storage.NewID() config1 := []byte(`{"issuer": "https://accounts.google.com"}`) c1 := storage.Connector{ @@ -616,12 +624,12 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) { Config: config1, } - if err := s.CreateConnector(c1); err != nil { + if err := s.CreateConnector(ctx, c1); err != nil { t.Fatalf("create connector with ID = %s: %v", c1.ID, err) } // Attempt to create same Connector twice. - err := s.CreateConnector(c1) + err := s.CreateConnector(ctx, c1) mustBeErrAlreadyExists(t, "connector", err) id2 := storage.NewID() @@ -633,7 +641,7 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) { Config: config2, } - if err := s.CreateConnector(c2); err != nil { + if err := s.CreateConnector(ctx, c2); err != nil { t.Fatalf("create connector with ID = %s: %v", c2.ID, err) } @@ -744,6 +752,7 @@ func testKeysCRUD(t *testing.T, s storage.Storage) { } func testGC(t *testing.T, s storage.Storage) { + ctx := context.Background() est, err := time.LoadLocation("America/New_York") if err != nil { t.Fatal(err) @@ -772,7 +781,7 @@ func testGC(t *testing.T, s storage.Storage) { }, } - if err := s.CreateAuthCode(c); err != nil { + if err := s.CreateAuthCode(ctx, c); err != nil { t.Fatalf("failed creating auth code: %v", err) } @@ -823,7 +832,7 @@ func testGC(t *testing.T, s storage.Storage) { HMACKey: []byte("hmac_key"), } - if err := s.CreateAuthRequest(a); err != nil { + if err := s.CreateAuthRequest(ctx, a); err != nil { t.Fatalf("failed creating auth request: %v", err) } @@ -860,7 +869,7 @@ func testGC(t *testing.T, s storage.Storage) { Expiry: expiry, } - if err := s.CreateDeviceRequest(d); err != nil { + if err := s.CreateDeviceRequest(ctx, d); err != nil { t.Fatalf("failed creating device request: %v", err) } @@ -900,7 +909,7 @@ func testGC(t *testing.T, s storage.Storage) { }, } - if err := s.CreateDeviceToken(dt); err != nil { + if err := s.CreateDeviceToken(ctx, dt); err != nil { t.Fatalf("failed creating device token: %v", err) } @@ -931,6 +940,7 @@ func testGC(t *testing.T, s storage.Storage) { // testTimezones tests that backends either fully support timezones or // do the correct standardization. func testTimezones(t *testing.T, s storage.Storage) { + ctx := context.Background() est, err := time.LoadLocation("America/New_York") if err != nil { t.Fatal(err) @@ -956,7 +966,7 @@ func testTimezones(t *testing.T, s storage.Storage) { Groups: []string{"a", "b"}, }, } - if err := s.CreateAuthCode(c); err != nil { + if err := s.CreateAuthCode(ctx, c); err != nil { t.Fatalf("failed creating auth code: %v", err) } got, err := s.GetAuthCode(c.ID) @@ -975,6 +985,7 @@ func testTimezones(t *testing.T, s storage.Storage) { } func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { + ctx := context.Background() d1 := storage.DeviceRequest{ UserCode: storage.NewUserCode(), DeviceCode: storage.NewID(), @@ -984,12 +995,12 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { Expiry: neverExpire.Round(time.Second), } - if err := s.CreateDeviceRequest(d1); err != nil { + if err := s.CreateDeviceRequest(ctx, d1); err != nil { t.Fatalf("failed creating device request: %v", err) } // Attempt to create same DeviceRequest twice. - err := s.CreateDeviceRequest(d1) + err := s.CreateDeviceRequest(ctx, d1) mustBeErrAlreadyExists(t, "device request", err) got, err := s.GetDeviceRequest(d1.UserCode) @@ -1004,6 +1015,7 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { } func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { + ctx := context.Background() codeChallenge := storage.PKCE{ CodeChallenge: "code_challenge_test", CodeChallengeMethod: "plain", @@ -1020,12 +1032,12 @@ func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { PKCE: codeChallenge, } - if err := s.CreateDeviceToken(d1); err != nil { + if err := s.CreateDeviceToken(ctx, d1); err != nil { t.Fatalf("failed creating device token: %v", err) } // Attempt to create same Device Token twice. - err := s.CreateDeviceToken(d1) + err := s.CreateDeviceToken(ctx, d1) mustBeErrAlreadyExists(t, "device token", err) // Update the device token, simulate a redemption diff --git a/storage/conformance/transactions.go b/storage/conformance/transactions.go index 1d4011a423..69ed5517ad 100644 --- a/storage/conformance/transactions.go +++ b/storage/conformance/transactions.go @@ -1,6 +1,7 @@ package conformance import ( + "context" "testing" "time" @@ -26,6 +27,7 @@ func RunTransactionTests(t *testing.T, newStorage func() storage.Storage) { } func testClientConcurrentUpdate(t *testing.T, s storage.Storage) { + ctx := context.Background() c := storage.Client{ ID: storage.NewID(), Secret: "foobar", @@ -34,7 +36,7 @@ func testClientConcurrentUpdate(t *testing.T, s storage.Storage) { LogoURL: "https://goo.gl/JIyzIC", } - if err := s.CreateClient(c); err != nil { + if err := s.CreateClient(ctx, c); err != nil { t.Fatalf("create client: %v", err) } @@ -55,6 +57,7 @@ func testClientConcurrentUpdate(t *testing.T, s storage.Storage) { } func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) { + ctx := context.Background() a := storage.AuthRequest{ ID: storage.NewID(), ClientID: "foobar", @@ -78,7 +81,7 @@ func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) { HMACKey: []byte("hmac_key"), } - if err := s.CreateAuthRequest(a); err != nil { + if err := s.CreateAuthRequest(ctx, a); err != nil { t.Fatalf("failed creating auth request: %v", err) } @@ -99,6 +102,7 @@ func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) { } func testPasswordConcurrentUpdate(t *testing.T, s storage.Storage) { + ctx := context.Background() // Use bcrypt.MinCost to keep the tests short. passwordHash, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost) if err != nil { @@ -111,7 +115,7 @@ func testPasswordConcurrentUpdate(t *testing.T, s storage.Storage) { Username: "jane", UserID: "foobar", } - if err := s.CreatePassword(password); err != nil { + if err := s.CreatePassword(ctx, password); err != nil { t.Fatalf("create password token: %v", err) } diff --git a/storage/ent/client/authcode.go b/storage/ent/client/authcode.go index b6b263bff8..8ac1231484 100644 --- a/storage/ent/client/authcode.go +++ b/storage/ent/client/authcode.go @@ -7,7 +7,7 @@ import ( ) // CreateAuthCode saves provided auth code into the database. -func (d *Database) CreateAuthCode(code storage.AuthCode) error { +func (d *Database) CreateAuthCode(ctx context.Context, code storage.AuthCode) error { _, err := d.client.AuthCode.Create(). SetID(code.ID). SetClientID(code.ClientID). @@ -26,7 +26,7 @@ func (d *Database) CreateAuthCode(code storage.AuthCode) error { SetExpiry(code.Expiry.UTC()). SetConnectorID(code.ConnectorID). SetConnectorData(code.ConnectorData). - Save(context.TODO()) + Save(ctx) if err != nil { return convertDBError("create auth code: %w", err) } diff --git a/storage/ent/client/authrequest.go b/storage/ent/client/authrequest.go index d68fd438a1..42db702d68 100644 --- a/storage/ent/client/authrequest.go +++ b/storage/ent/client/authrequest.go @@ -8,7 +8,7 @@ import ( ) // CreateAuthRequest saves provided auth request into the database. -func (d *Database) CreateAuthRequest(authRequest storage.AuthRequest) error { +func (d *Database) CreateAuthRequest(ctx context.Context, authRequest storage.AuthRequest) error { _, err := d.client.AuthRequest.Create(). SetID(authRequest.ID). SetClientID(authRequest.ClientID). @@ -32,7 +32,7 @@ func (d *Database) CreateAuthRequest(authRequest storage.AuthRequest) error { SetConnectorID(authRequest.ConnectorID). SetConnectorData(authRequest.ConnectorData). SetHmacKey(authRequest.HMACKey). - Save(context.TODO()) + Save(ctx) if err != nil { return convertDBError("create auth request: %w", err) } diff --git a/storage/ent/client/client.go b/storage/ent/client/client.go index 07434bd60b..4cb02c0c83 100644 --- a/storage/ent/client/client.go +++ b/storage/ent/client/client.go @@ -7,7 +7,7 @@ import ( ) // CreateClient saves provided oauth2 client settings into the database. -func (d *Database) CreateClient(client storage.Client) error { +func (d *Database) CreateClient(ctx context.Context, client storage.Client) error { _, err := d.client.OAuth2Client.Create(). SetID(client.ID). SetName(client.Name). @@ -16,7 +16,7 @@ func (d *Database) CreateClient(client storage.Client) error { SetLogoURL(client.LogoURL). SetRedirectUris(client.RedirectURIs). SetTrustedPeers(client.TrustedPeers). - Save(context.TODO()) + Save(ctx) if err != nil { return convertDBError("create oauth2 client: %w", err) } diff --git a/storage/ent/client/connector.go b/storage/ent/client/connector.go index bfec4418dd..1534e52241 100644 --- a/storage/ent/client/connector.go +++ b/storage/ent/client/connector.go @@ -7,14 +7,14 @@ import ( ) // CreateConnector saves a connector into the database. -func (d *Database) CreateConnector(connector storage.Connector) error { +func (d *Database) CreateConnector(ctx context.Context, connector storage.Connector) error { _, err := d.client.Connector.Create(). SetID(connector.ID). SetName(connector.Name). SetType(connector.Type). SetResourceVersion(connector.ResourceVersion). SetConfig(connector.Config). - Save(context.TODO()) + Save(ctx) if err != nil { return convertDBError("create connector: %w", err) } diff --git a/storage/ent/client/devicerequest.go b/storage/ent/client/devicerequest.go index 6e9c25001d..d8d371c9ba 100644 --- a/storage/ent/client/devicerequest.go +++ b/storage/ent/client/devicerequest.go @@ -8,7 +8,7 @@ import ( ) // CreateDeviceRequest saves provided device request into the database. -func (d *Database) CreateDeviceRequest(request storage.DeviceRequest) error { +func (d *Database) CreateDeviceRequest(ctx context.Context, request storage.DeviceRequest) error { _, err := d.client.DeviceRequest.Create(). SetClientID(request.ClientID). SetClientSecret(request.ClientSecret). @@ -17,7 +17,7 @@ func (d *Database) CreateDeviceRequest(request storage.DeviceRequest) error { SetDeviceCode(request.DeviceCode). // Save utc time into database because ent doesn't support comparing dates with different timezones SetExpiry(request.Expiry.UTC()). - Save(context.TODO()) + Save(ctx) if err != nil { return convertDBError("create device request: %w", err) } diff --git a/storage/ent/client/devicetoken.go b/storage/ent/client/devicetoken.go index 99cf077d02..18d483b98a 100644 --- a/storage/ent/client/devicetoken.go +++ b/storage/ent/client/devicetoken.go @@ -8,7 +8,7 @@ import ( ) // CreateDeviceToken saves provided token into the database. -func (d *Database) CreateDeviceToken(token storage.DeviceToken) error { +func (d *Database) CreateDeviceToken(ctx context.Context, token storage.DeviceToken) error { _, err := d.client.DeviceToken.Create(). SetDeviceCode(token.DeviceCode). SetToken([]byte(token.Token)). @@ -19,7 +19,7 @@ func (d *Database) CreateDeviceToken(token storage.DeviceToken) error { SetStatus(token.Status). SetCodeChallenge(token.PKCE.CodeChallenge). SetCodeChallengeMethod(token.PKCE.CodeChallengeMethod). - Save(context.TODO()) + Save(ctx) if err != nil { return convertDBError("create device token: %w", err) } diff --git a/storage/ent/client/offlinesession.go b/storage/ent/client/offlinesession.go index 9f54ea1d3c..22469eced9 100644 --- a/storage/ent/client/offlinesession.go +++ b/storage/ent/client/offlinesession.go @@ -9,7 +9,7 @@ import ( ) // CreateOfflineSessions saves provided offline session into the database. -func (d *Database) CreateOfflineSessions(session storage.OfflineSessions) error { +func (d *Database) CreateOfflineSessions(ctx context.Context, session storage.OfflineSessions) error { encodedRefresh, err := json.Marshal(session.Refresh) if err != nil { return fmt.Errorf("encode refresh offline session: %w", err) @@ -22,7 +22,7 @@ func (d *Database) CreateOfflineSessions(session storage.OfflineSessions) error SetConnID(session.ConnID). SetConnectorData(session.ConnectorData). SetRefresh(encodedRefresh). - Save(context.TODO()) + Save(ctx) if err != nil { return convertDBError("create offline session: %w", err) } diff --git a/storage/ent/client/password.go b/storage/ent/client/password.go index daaae30cea..3e4aace8ae 100644 --- a/storage/ent/client/password.go +++ b/storage/ent/client/password.go @@ -9,13 +9,13 @@ import ( ) // CreatePassword saves provided password into the database. -func (d *Database) CreatePassword(password storage.Password) error { +func (d *Database) CreatePassword(ctx context.Context, password storage.Password) error { _, err := d.client.Password.Create(). SetEmail(password.Email). SetHash(password.Hash). SetUsername(password.Username). SetUserID(password.UserID). - Save(context.TODO()) + Save(ctx) if err != nil { return convertDBError("create password: %w", err) } diff --git a/storage/ent/client/refreshtoken.go b/storage/ent/client/refreshtoken.go index eca048f463..6861b07916 100644 --- a/storage/ent/client/refreshtoken.go +++ b/storage/ent/client/refreshtoken.go @@ -7,7 +7,7 @@ import ( ) // CreateRefresh saves provided refresh token into the database. -func (d *Database) CreateRefresh(refresh storage.RefreshToken) error { +func (d *Database) CreateRefresh(ctx context.Context, refresh storage.RefreshToken) error { _, err := d.client.RefreshToken.Create(). SetID(refresh.ID). SetClientID(refresh.ClientID). @@ -26,7 +26,7 @@ func (d *Database) CreateRefresh(refresh storage.RefreshToken) error { // Save utc time into database because ent doesn't support comparing dates with different timezones SetLastUsed(refresh.LastUsed.UTC()). SetCreatedAt(refresh.CreatedAt.UTC()). - Save(context.TODO()) + Save(ctx) if err != nil { return convertDBError("create refresh token: %w", err) } diff --git a/storage/etcd/etcd.go b/storage/etcd/etcd.go index 0343ade96d..e4b24b4a4a 100644 --- a/storage/etcd/etcd.go +++ b/storage/etcd/etcd.go @@ -29,6 +29,8 @@ const ( defaultStorageTimeout = 5 * time.Second ) +var _ storage.Storage = (*conn)(nil) + type conn struct { db *clientv3.Client logger log.Logger @@ -107,9 +109,7 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error return result, delErr } -func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) - defer cancel() +func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) error { return c.txnCreate(ctx, keyID(authRequestPrefix, a.ID), fromStorageAuthRequest(a)) } @@ -147,9 +147,7 @@ func (c *conn) DeleteAuthRequest(id string) error { return c.deleteKey(ctx, keyID(authRequestPrefix, id)) } -func (c *conn) CreateAuthCode(a storage.AuthCode) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) - defer cancel() +func (c *conn) CreateAuthCode(ctx context.Context, a storage.AuthCode) error { return c.txnCreate(ctx, keyID(authCodePrefix, a.ID), fromStorageAuthCode(a)) } @@ -170,9 +168,7 @@ func (c *conn) DeleteAuthCode(id string) error { return c.deleteKey(ctx, keyID(authCodePrefix, id)) } -func (c *conn) CreateRefresh(r storage.RefreshToken) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) - defer cancel() +func (c *conn) CreateRefresh(ctx context.Context, r storage.RefreshToken) error { return c.txnCreate(ctx, keyID(refreshTokenPrefix, r.ID), fromStorageRefreshToken(r)) } @@ -227,9 +223,7 @@ func (c *conn) ListRefreshTokens() (tokens []storage.RefreshToken, err error) { return tokens, nil } -func (c *conn) CreateClient(cli storage.Client) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) - defer cancel() +func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error { return c.txnCreate(ctx, keyID(clientPrefix, cli.ID), cli) } @@ -281,9 +275,7 @@ func (c *conn) ListClients() (clients []storage.Client, err error) { return clients, nil } -func (c *conn) CreatePassword(p storage.Password) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) - defer cancel() +func (c *conn) CreatePassword(ctx context.Context, p storage.Password) error { return c.txnCreate(ctx, passwordPrefix+strings.ToLower(p.Email), p) } @@ -335,9 +327,7 @@ func (c *conn) ListPasswords() (passwords []storage.Password, err error) { return passwords, nil } -func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) - defer cancel() +func (c *conn) CreateOfflineSessions(ctx context.Context, s storage.OfflineSessions) error { return c.txnCreate(ctx, keySession(s.UserID, s.ConnID), fromStorageOfflineSessions(s)) } @@ -375,9 +365,7 @@ func (c *conn) DeleteOfflineSessions(userID string, connID string) error { return c.deleteKey(ctx, keySession(userID, connID)) } -func (c *conn) CreateConnector(connector storage.Connector) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) - defer cancel() +func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector) error { return c.txnCreate(ctx, keyID(connectorPrefix, connector.ID), connector) } @@ -568,9 +556,7 @@ func keySession(userID, connID string) string { return offlineSessionPrefix + strings.ToLower(userID+"|"+connID) } -func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) - defer cancel() +func (c *conn) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest) error { return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d)) } @@ -599,9 +585,7 @@ func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest return requests, nil } -func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) - defer cancel() +func (c *conn) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) error { return c.txnCreate(ctx, keyID(deviceTokenPrefix, t.DeviceCode), fromStorageDeviceToken(t)) } diff --git a/storage/health.go b/storage/health.go index 1b6e22c662..8cdefddf32 100644 --- a/storage/health.go +++ b/storage/health.go @@ -9,7 +9,7 @@ import ( // NewCustomHealthCheckFunc returns a new health check function. func NewCustomHealthCheckFunc(s Storage, now func() time.Time) func(context.Context) (details interface{}, err error) { - return func(_ context.Context) (details interface{}, err error) { + return func(ctx context.Context) (details interface{}, err error) { a := AuthRequest{ ID: NewID(), ClientID: NewID(), @@ -19,7 +19,7 @@ func NewCustomHealthCheckFunc(s Storage, now func() time.Time) func(context.Cont HMACKey: NewHMACKey(crypto.SHA256), } - if err := s.CreateAuthRequest(a); err != nil { + if err := s.CreateAuthRequest(ctx, a); err != nil { return nil, fmt.Errorf("create auth request: %v", err) } diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index 0979f14ac0..c08362b8eb 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -40,6 +40,8 @@ const ( resourceDeviceToken = "devicetokens" ) +var _ storage.Storage = (*client)(nil) + const ( gcResultLimit = 500 ) @@ -232,31 +234,31 @@ func (cli *client) Close() error { return nil } -func (cli *client) CreateAuthRequest(a storage.AuthRequest) error { +func (cli *client) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) error { return cli.post(resourceAuthRequest, cli.fromStorageAuthRequest(a)) } -func (cli *client) CreateClient(c storage.Client) error { +func (cli *client) CreateClient(ctx context.Context, c storage.Client) error { return cli.post(resourceClient, cli.fromStorageClient(c)) } -func (cli *client) CreateAuthCode(c storage.AuthCode) error { +func (cli *client) CreateAuthCode(ctx context.Context, c storage.AuthCode) error { return cli.post(resourceAuthCode, cli.fromStorageAuthCode(c)) } -func (cli *client) CreatePassword(p storage.Password) error { +func (cli *client) CreatePassword(ctx context.Context, p storage.Password) error { return cli.post(resourcePassword, cli.fromStoragePassword(p)) } -func (cli *client) CreateRefresh(r storage.RefreshToken) error { +func (cli *client) CreateRefresh(ctx context.Context, r storage.RefreshToken) error { return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r)) } -func (cli *client) CreateOfflineSessions(o storage.OfflineSessions) error { +func (cli *client) CreateOfflineSessions(ctx context.Context, o storage.OfflineSessions) error { return cli.post(resourceOfflineSessions, cli.fromStorageOfflineSessions(o)) } -func (cli *client) CreateConnector(c storage.Connector) error { +func (cli *client) CreateConnector(ctx context.Context, c storage.Connector) error { return cli.post(resourceConnector, cli.fromStorageConnector(c)) } @@ -681,7 +683,7 @@ func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err e return result, delErr } -func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error { +func (cli *client) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest) error { return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d)) } @@ -693,7 +695,7 @@ func (cli *client) GetDeviceRequest(userCode string) (storage.DeviceRequest, err return toStorageDeviceRequest(req), nil } -func (cli *client) CreateDeviceToken(t storage.DeviceToken) error { +func (cli *client) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) error { return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t)) } diff --git a/storage/kubernetes/storage_test.go b/storage/kubernetes/storage_test.go index 07475fcd8f..b4b42688e2 100644 --- a/storage/kubernetes/storage_test.go +++ b/storage/kubernetes/storage_test.go @@ -302,6 +302,7 @@ func TestRetryOnConflict(t *testing.T) { } func TestRefreshTokenLock(t *testing.T) { + ctx := context.Background() if os.Getenv(kubeconfigPathVariableName) == "" { t.Skipf("variable %q not set, skipping kubernetes storage tests\n", kubeconfigPathVariableName) } @@ -345,7 +346,7 @@ func TestRefreshTokenLock(t *testing.T) { ConnectorData: []byte(`{"some":"data"}`), } - err = kubeClient.CreateRefresh(r) + err = kubeClient.CreateRefresh(ctx, r) require.NoError(t, err) t.Run("Timeout lock error", func(t *testing.T) { diff --git a/storage/memory/memory.go b/storage/memory/memory.go index a940665714..8e080c9faa 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -2,6 +2,7 @@ package memory import ( + "context" "strings" "sync" "time" @@ -10,6 +11,8 @@ import ( "github.com/dexidp/dex/storage" ) +var _ storage.Storage = (*memStorage)(nil) + // New returns an in memory storage. func New(logger log.Logger) storage.Storage { return &memStorage{ @@ -98,7 +101,7 @@ func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err return result, nil } -func (s *memStorage) CreateClient(c storage.Client) (err error) { +func (s *memStorage) CreateClient(ctx context.Context, c storage.Client) (err error) { s.tx(func() { if _, ok := s.clients[c.ID]; ok { err = storage.ErrAlreadyExists @@ -109,7 +112,7 @@ func (s *memStorage) CreateClient(c storage.Client) (err error) { return } -func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) { +func (s *memStorage) CreateAuthCode(ctx context.Context, c storage.AuthCode) (err error) { s.tx(func() { if _, ok := s.authCodes[c.ID]; ok { err = storage.ErrAlreadyExists @@ -120,7 +123,7 @@ func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) { return } -func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) { +func (s *memStorage) CreateRefresh(ctx context.Context, r storage.RefreshToken) (err error) { s.tx(func() { if _, ok := s.refreshTokens[r.ID]; ok { err = storage.ErrAlreadyExists @@ -131,7 +134,7 @@ func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) { return } -func (s *memStorage) CreateAuthRequest(a storage.AuthRequest) (err error) { +func (s *memStorage) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) (err error) { s.tx(func() { if _, ok := s.authReqs[a.ID]; ok { err = storage.ErrAlreadyExists @@ -142,7 +145,7 @@ func (s *memStorage) CreateAuthRequest(a storage.AuthRequest) (err error) { return } -func (s *memStorage) CreatePassword(p storage.Password) (err error) { +func (s *memStorage) CreatePassword(ctx context.Context, p storage.Password) (err error) { lowerEmail := strings.ToLower(p.Email) s.tx(func() { if _, ok := s.passwords[lowerEmail]; ok { @@ -154,7 +157,7 @@ func (s *memStorage) CreatePassword(p storage.Password) (err error) { return } -func (s *memStorage) CreateOfflineSessions(o storage.OfflineSessions) (err error) { +func (s *memStorage) CreateOfflineSessions(ctx context.Context, o storage.OfflineSessions) (err error) { id := offlineSessionID{ userID: o.UserID, connID: o.ConnID, @@ -169,7 +172,7 @@ func (s *memStorage) CreateOfflineSessions(o storage.OfflineSessions) (err error return } -func (s *memStorage) CreateConnector(connector storage.Connector) (err error) { +func (s *memStorage) CreateConnector(ctx context.Context, connector storage.Connector) (err error) { s.tx(func() { if _, ok := s.connectors[connector.ID]; ok { err = storage.ErrAlreadyExists @@ -481,7 +484,7 @@ func (s *memStorage) UpdateConnector(id string, updater func(c storage.Connector return } -func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) { +func (s *memStorage) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest) (err error) { s.tx(func() { if _, ok := s.deviceRequests[d.UserCode]; ok { err = storage.ErrAlreadyExists @@ -503,7 +506,7 @@ func (s *memStorage) GetDeviceRequest(userCode string) (req storage.DeviceReques return } -func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) { +func (s *memStorage) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) (err error) { s.tx(func() { if _, ok := s.deviceTokens[t.DeviceCode]; ok { err = storage.ErrAlreadyExists diff --git a/storage/memory/static_test.go b/storage/memory/static_test.go index 8513e0ee89..4be23a1e6a 100644 --- a/storage/memory/static_test.go +++ b/storage/memory/static_test.go @@ -1,6 +1,7 @@ package memory import ( + "context" "fmt" "os" "strings" @@ -12,6 +13,7 @@ import ( ) func TestStaticClients(t *testing.T) { + ctx := context.Background() logger := &logrus.Logger{ Out: os.Stderr, Formatter: &logrus.TextFormatter{DisableColors: true}, @@ -23,7 +25,7 @@ func TestStaticClients(t *testing.T) { c2 := storage.Client{ID: "bar", Secret: "bar_secret"} c3 := storage.Client{ID: "spam", Secret: "spam_secret"} - backing.CreateClient(c1) + backing.CreateClient(ctx, c1) s := storage.WithStaticClients(backing, []storage.Client{c2}) tests := []struct { @@ -82,7 +84,7 @@ func TestStaticClients(t *testing.T) { { name: "create client", action: func() error { - return s.CreateClient(c3) + return s.CreateClient(ctx, c3) }, }, } @@ -99,6 +101,7 @@ func TestStaticClients(t *testing.T) { } func TestStaticPasswords(t *testing.T) { + ctx := context.Background() logger := &logrus.Logger{ Out: os.Stderr, Formatter: &logrus.TextFormatter{DisableColors: true}, @@ -111,7 +114,7 @@ func TestStaticPasswords(t *testing.T) { p3 := storage.Password{Email: "spam@example.com", Username: "spam_secret"} p4 := storage.Password{Email: "Spam@example.com", Username: "Spam_secret"} - backing.CreatePassword(p1) + backing.CreatePassword(ctx, p1) s := storage.WithStaticPasswords(backing, []storage.Password{p2}, logger) tests := []struct { @@ -164,10 +167,10 @@ func TestStaticPasswords(t *testing.T) { { name: "create passwords", action: func() error { - if err := s.CreatePassword(p4); err != nil { + if err := s.CreatePassword(ctx, p4); err != nil { return err } - return s.CreatePassword(p3) + return s.CreatePassword(ctx, p3) }, wantErr: true, }, @@ -211,6 +214,7 @@ func TestStaticPasswords(t *testing.T) { } func TestStaticConnectors(t *testing.T) { + ctx := context.Background() logger := &logrus.Logger{ Out: os.Stderr, Formatter: &logrus.TextFormatter{DisableColors: true}, @@ -226,7 +230,7 @@ func TestStaticConnectors(t *testing.T) { c2 := storage.Connector{ID: storage.NewID(), Type: "ldap", Name: "ldap", ResourceVersion: "1", Config: config2} c3 := storage.Connector{ID: storage.NewID(), Type: "saml", Name: "saml", ResourceVersion: "1", Config: config3} - backing.CreateConnector(c1) + backing.CreateConnector(ctx, c1) s := storage.WithStaticConnectors(backing, []storage.Connector{c2}) tests := []struct { @@ -285,7 +289,7 @@ func TestStaticConnectors(t *testing.T) { { name: "create connector", action: func() error { - return s.CreateConnector(c3) + return s.CreateConnector(ctx, c3) }, }, } diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 7f8666db05..1249243ced 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -1,6 +1,7 @@ package sql import ( + "context" "database/sql" "database/sql/driver" "encoding/json" @@ -83,6 +84,8 @@ type scanner interface { Scan(dest ...interface{}) error } +var _ storage.Storage = (*conn)(nil) + func (c *conn) GarbageCollect(now time.Time) (storage.GCResult, error) { result := storage.GCResult{} @@ -121,7 +124,7 @@ func (c *conn) GarbageCollect(now time.Time) (storage.GCResult, error) { return result, err } -func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { +func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) error { _, err := c.Exec(` insert into auth_request ( id, client_id, response_types, scopes, redirect_uri, nonce, state, @@ -229,7 +232,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { return a, nil } -func (c *conn) CreateAuthCode(a storage.AuthCode) error { +func (c *conn) CreateAuthCode(ctx context.Context, a storage.AuthCode) error { _, err := c.Exec(` insert into auth_code ( id, client_id, scopes, nonce, redirect_uri, @@ -280,7 +283,7 @@ func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) { return a, nil } -func (c *conn) CreateRefresh(r storage.RefreshToken) error { +func (c *conn) CreateRefresh(ctx context.Context, r storage.RefreshToken) error { _, err := c.Exec(` insert into refresh_token ( id, client_id, scopes, nonce, @@ -521,7 +524,7 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage }) } -func (c *conn) CreateClient(cli storage.Client) error { +func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error { _, err := c.Exec(` insert into client ( id, secret, redirect_uris, trusted_peers, public, name, logo_url @@ -591,7 +594,7 @@ func scanClient(s scanner) (cli storage.Client, err error) { return cli, nil } -func (c *conn) CreatePassword(p storage.Password) error { +func (c *conn) CreatePassword(ctx context.Context, p storage.Password) error { p.Email = strings.ToLower(p.Email) _, err := c.Exec(` insert into password ( @@ -688,7 +691,7 @@ func scanPassword(s scanner) (p storage.Password, err error) { return p, nil } -func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error { +func (c *conn) CreateOfflineSessions(ctx context.Context, s storage.OfflineSessions) error { _, err := c.Exec(` insert into offline_session ( user_id, conn_id, refresh, connector_data @@ -761,7 +764,7 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) { return o, nil } -func (c *conn) CreateConnector(connector storage.Connector) error { +func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector) error { _, err := c.Exec(` insert into connector ( id, type, name, resource_version, config @@ -907,7 +910,7 @@ func (c *conn) delete(table, field, id string) error { return nil } -func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error { +func (c *conn) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest) error { _, err := c.Exec(` insert into device_request ( user_code, device_code, client_id, client_secret, scopes, expiry @@ -926,7 +929,7 @@ func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error { return nil } -func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { +func (c *conn) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) error { _, err := c.Exec(` insert into device_token ( device_code, status, token, expiry, last_request, poll_interval, code_challenge, code_challenge_method @@ -1001,7 +1004,7 @@ func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.Dev _, err = tx.Exec(` update device_token set - status = $1, + status = $1, token = $2, last_request = $3, poll_interval = $4, diff --git a/storage/static.go b/storage/static.go index 806b61f9cd..e8902b9b59 100644 --- a/storage/static.go +++ b/storage/static.go @@ -1,6 +1,7 @@ package storage import ( + "context" "errors" "strings" @@ -60,11 +61,11 @@ func (s staticClientsStorage) ListClients() ([]Client, error) { return append(clients[:n], s.clients...), nil } -func (s staticClientsStorage) CreateClient(c Client) error { +func (s staticClientsStorage) CreateClient(ctx context.Context, c Client) error { if s.isStatic(c.ID) { return errors.New("static clients: read-only cannot create client") } - return s.Storage.CreateClient(c) + return s.Storage.CreateClient(ctx, c) } func (s staticClientsStorage) DeleteClient(id string) error { @@ -140,11 +141,11 @@ func (s staticPasswordsStorage) ListPasswords() ([]Password, error) { return append(passwords[:n], s.passwords...), nil } -func (s staticPasswordsStorage) CreatePassword(p Password) error { +func (s staticPasswordsStorage) CreatePassword(ctx context.Context, p Password) error { if s.isStatic(p.Email) { return errors.New("static passwords: read-only cannot create password") } - return s.Storage.CreatePassword(p) + return s.Storage.CreatePassword(ctx, p) } func (s staticPasswordsStorage) DeletePassword(email string) error { @@ -210,11 +211,11 @@ func (s staticConnectorsStorage) ListConnectors() ([]Connector, error) { return append(connectors[:n], s.connectors...), nil } -func (s staticConnectorsStorage) CreateConnector(c Connector) error { +func (s staticConnectorsStorage) CreateConnector(ctx context.Context, c Connector) error { if s.isStatic(c.ID) { return errors.New("static connectors: read-only cannot create connector") } - return s.Storage.CreateConnector(c) + return s.Storage.CreateConnector(ctx, c) } func (s staticConnectorsStorage) DeleteConnector(id string) error { diff --git a/storage/storage.go b/storage/storage.go index 743d2ecb4d..214c2d49f7 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,6 +1,7 @@ package storage import ( + "context" "crypto" "crypto/rand" "encoding/base32" @@ -76,15 +77,15 @@ type Storage interface { Close() error // TODO(ericchiang): Let the storages set the IDs of these objects. - CreateAuthRequest(a AuthRequest) error - CreateClient(c Client) error - CreateAuthCode(c AuthCode) error - CreateRefresh(r RefreshToken) error - CreatePassword(p Password) error - CreateOfflineSessions(s OfflineSessions) error - CreateConnector(c Connector) error - CreateDeviceRequest(d DeviceRequest) error - CreateDeviceToken(d DeviceToken) error + CreateAuthRequest(ctx context.Context, a AuthRequest) error + CreateClient(ctx context.Context, c Client) error + CreateAuthCode(ctx context.Context, c AuthCode) error + CreateRefresh(ctx context.Context, r RefreshToken) error + CreatePassword(ctx context.Context, p Password) error + CreateOfflineSessions(ctx context.Context, s OfflineSessions) error + CreateConnector(ctx context.Context, c Connector) error + CreateDeviceRequest(ctx context.Context, d DeviceRequest) error + CreateDeviceToken(ctx context.Context, d DeviceToken) error // TODO(ericchiang): return (T, bool, error) so we can indicate not found // requests that way instead of using ErrNotFound.