Skip to content

Commit

Permalink
Add context to storage's Create endpoints (#2935)
Browse files Browse the repository at this point in the history
* Initial commit

Signed-off-by: PumpkinSeed <[email protected]>

* Finish the syntex fixes

Signed-off-by: PumpkinSeed <[email protected]>

* Add fixes after running the tests

Signed-off-by: PumpkinSeed <[email protected]>

* Change background context to request context

Signed-off-by: PumpkinSeed <[email protected]>

---------

Signed-off-by: PumpkinSeed <[email protected]>
  • Loading branch information
PumpkinSeed authored Jan 25, 2024
1 parent 7ca42d7 commit 2377b0a
Show file tree
Hide file tree
Showing 28 changed files with 214 additions and 186 deletions.
4 changes: 2 additions & 2 deletions server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (d dexAPI) CreateClient(ctx context.Context, req *api.CreateClientReq) (*ap
Name: req.Client.Name,
LogoURL: req.Client.LogoUrl,
}
if err := d.s.CreateClient(c); err != nil {
if err := d.s.CreateClient(ctx, c); err != nil {
if err == storage.ErrAlreadyExists {
return &api.CreateClientResp{AlreadyExists: true}, nil
}
Expand Down Expand Up @@ -177,7 +177,7 @@ func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq)
Username: req.Password.Username,
UserID: req.Password.UserId,
}
if err := d.s.CreatePassword(p); err != nil {
if err := d.s.CreatePassword(ctx, p); err != nil {
if err == storage.ErrAlreadyExists {
return &api.CreatePasswordResp{AlreadyExists: true}, nil
}
Expand Down
4 changes: 2 additions & 2 deletions server/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func TestRefreshToken(t *testing.T) {
ConnectorData: []byte(`{"some":"data"}`),
}

if err := s.CreateRefresh(r); err != nil {
if err := s.CreateRefresh(ctx, r); err != nil {
t.Fatalf("create refresh token: %v", err)
}

Expand All @@ -280,7 +280,7 @@ func TestRefreshToken(t *testing.T) {
}
session.Refresh[tokenRef.ClientID] = &tokenRef

if err := s.CreateOfflineSessions(session); err != nil {
if err := s.CreateOfflineSessions(ctx, session); err != nil {
t.Fatalf("create offline session: %v", err)
}

Expand Down
8 changes: 5 additions & 3 deletions server/deviceflowhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
pollIntervalSeconds := 5

switch r.Method {
Expand Down Expand Up @@ -106,7 +107,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
Expiry: expireTime,
}

if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
if err := s.storage.CreateDeviceRequest(ctx, deviceReq); err != nil {
s.logger.Errorf("Failed to store device request; %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
Expand All @@ -125,7 +126,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
},
}

if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
if err := s.storage.CreateDeviceToken(ctx, deviceToken); err != nil {
s.logger.Errorf("Failed to store device token %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
Expand Down Expand Up @@ -280,6 +281,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
switch r.Method {
case http.MethodGet:
userCode := r.FormValue("state")
Expand Down Expand Up @@ -336,7 +338,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
return
}

resp, err := s.exchangeAuthCode(w, authCode, client)
resp, err := s.exchangeAuthCode(ctx, w, authCode, client)
if err != nil {
s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
Expand Down
14 changes: 7 additions & 7 deletions server/deviceflowhandlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,15 +366,15 @@ func TestDeviceCallback(t *testing.T) {
})
defer httpServer.Close()

if err := s.storage.CreateAuthCode(tc.testAuthCode); err != nil {
if err := s.storage.CreateAuthCode(ctx, tc.testAuthCode); err != nil {
t.Fatalf("failed to create auth code: %v", err)
}

if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
if err := s.storage.CreateDeviceRequest(ctx, tc.testDeviceRequest); err != nil {
t.Fatalf("failed to create device request: %v", err)
}

if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil {
if err := s.storage.CreateDeviceToken(ctx, tc.testDeviceToken); err != nil {
t.Fatalf("failed to create device token: %v", err)
}

Expand All @@ -383,7 +383,7 @@ func TestDeviceCallback(t *testing.T) {
Secret: "",
RedirectURIs: []string{deviceCallbackURI},
}
if err := s.storage.CreateClient(client); err != nil {
if err := s.storage.CreateClient(ctx, client); err != nil {
t.Fatalf("failed to create client: %v", err)
}

Expand Down Expand Up @@ -660,11 +660,11 @@ func TestDeviceTokenResponse(t *testing.T) {
})
defer httpServer.Close()

if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
if err := s.storage.CreateDeviceRequest(ctx, tc.testDeviceRequest); err != nil {
t.Fatalf("Failed to store device token %v", err)
}

if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil {
if err := s.storage.CreateDeviceToken(ctx, tc.testDeviceToken); err != nil {
t.Fatalf("Failed to store device token %v", err)
}

Expand Down Expand Up @@ -794,7 +794,7 @@ func TestVerifyCodeResponse(t *testing.T) {
})
defer httpServer.Close()

if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
if err := s.storage.CreateDeviceRequest(ctx, tc.testDeviceRequest); err != nil {
t.Fatalf("Failed to store device token %v", err)
}

Expand Down
38 changes: 23 additions & 15 deletions server/handlers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
Expand Down Expand Up @@ -187,6 +188,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
authReq, err := s.parseAuthorizationRequest(r)
if err != nil {
s.logger.Errorf("Failed to parse authorization request: %v", err)
Expand Down Expand Up @@ -229,7 +231,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {

// Actually create the auth request
authReq.Expiry = s.now().Add(s.authRequestsValidFor)
if err := s.storage.CreateAuthRequest(*authReq); err != nil {
if err := s.storage.CreateAuthRequest(ctx, *authReq); err != nil {
s.logger.Errorf("Failed to create authorization request: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to connect to the database.")
return
Expand Down Expand Up @@ -305,6 +307,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
authID := r.URL.Query().Get("state")
if authID == "" {
s.renderError(r, w, http.StatusBadRequest, "User session error.")
Expand Down Expand Up @@ -360,7 +363,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
password := r.FormValue("password")
scopes := parseScopes(authReq.Scopes)

identity, ok, err := pwConn.Login(r.Context(), scopes, username, password)
identity, ok, err := pwConn.Login(ctx, scopes, username, password)
if err != nil {
s.logger.Errorf("Failed to login user: %v", err)
s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err))
Expand All @@ -372,7 +375,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
}
return
}
redirectURL, canSkipApproval, err := s.finalizeLogin(identity, authReq, conn.Connector)
redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, identity, authReq, conn.Connector)
if err != nil {
s.logger.Errorf("Failed to finalize login: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.")
Expand All @@ -397,6 +400,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var authID string
switch r.Method {
case http.MethodGet: // OAuth2 callback
Expand Down Expand Up @@ -471,7 +475,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
return
}

redirectURL, canSkipApproval, err := s.finalizeLogin(identity, authReq, conn.Connector)
redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, identity, authReq, conn.Connector)
if err != nil {
s.logger.Errorf("Failed to finalize login: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.")
Expand All @@ -494,7 +498,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)

// finalizeLogin associates the user's identity with the current AuthRequest, then returns
// the approval page's path.
func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, bool, error) {
func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, bool, error) {
claims := storage.Claims{
UserID: identity.UserID,
Username: identity.Username,
Expand Down Expand Up @@ -566,7 +570,7 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.Auth

// Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
return "", false, err
}
Expand Down Expand Up @@ -649,6 +653,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) {
ctx := r.Context()
if s.now().After(authReq.Expiry) {
s.renderError(r, w, http.StatusBadRequest, "User session has expired.")
return
Expand Down Expand Up @@ -701,7 +706,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
ConnectorData: authReq.ConnectorData,
PKCE: authReq.PKCE,
}
if err := s.storage.CreateAuthCode(code); err != nil {
if err := s.storage.CreateAuthCode(ctx, code); err != nil {
s.logger.Errorf("Failed to create auth code: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return
Expand Down Expand Up @@ -876,6 +881,7 @@ func (s *Server) calculateCodeChallenge(codeVerifier, codeChallengeMethod string

// handle an access token request https://tools.ietf.org/html/rfc6749#section-4.1.3
func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) {
ctx := r.Context()
code := r.PostFormValue("code")
redirectURI := r.PostFormValue("redirect_uri")

Expand Down Expand Up @@ -926,15 +932,15 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
return
}

tokenResponse, err := s.exchangeAuthCode(w, authCode, client)
tokenResponse, err := s.exchangeAuthCode(ctx, w, authCode, client)
if err != nil {
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
s.writeAccessToken(w, tokenResponse)
}

func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) {
func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) {
accessToken, _, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
Expand Down Expand Up @@ -1002,7 +1008,7 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo
return nil, err
}

if err := s.storage.CreateRefresh(refresh); err != nil {
if err := s.storage.CreateRefresh(ctx, refresh); err != nil {
s.logger.Errorf("failed to create refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return nil, err
Expand Down Expand Up @@ -1047,7 +1053,7 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo

// Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
Expand Down Expand Up @@ -1080,6 +1086,7 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo
}

func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
const prefix = "Bearer "

auth := r.Header.Get("authorization")
Expand All @@ -1091,7 +1098,7 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
rawIDToken := auth[len(prefix):]

verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true})
idToken, err := verifier.Verify(r.Context(), rawIDToken)
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusForbidden)
return
Expand All @@ -1108,6 +1115,7 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, client storage.Client) {
ctx := r.Context()
// Parse the fields
if err := r.ParseForm(); err != nil {
s.tokenErrHelper(w, errInvalidRequest, "Couldn't parse data", http.StatusBadRequest)
Expand Down Expand Up @@ -1177,7 +1185,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
// Login
username := q.Get("username")
password := q.Get("password")
identity, ok, err := passwordConnector.Login(r.Context(), parseScopes(scopes), username, password)
identity, ok, err := passwordConnector.Login(ctx, parseScopes(scopes), username, password)
if err != nil {
s.logger.Errorf("Failed to login user: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "Could not login user", http.StatusBadRequest)
Expand Down Expand Up @@ -1252,7 +1260,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
return
}

if err := s.storage.CreateRefresh(refresh); err != nil {
if err := s.storage.CreateRefresh(ctx, refresh); err != nil {
s.logger.Errorf("failed to create refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
Expand Down Expand Up @@ -1298,7 +1306,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli

// Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
Expand Down
Loading

0 comments on commit 2377b0a

Please sign in to comment.