Skip to content

Commit

Permalink
feat(rp): Add UnauthorizedHandler (#503)
Browse files Browse the repository at this point in the history
* RP: Add UnauthorizedHandler

Signed-off-by: Jan-Otto Kröpke <[email protected]>

* remove race condition

Signed-off-by: Jan-Otto Kröpke <[email protected]>

* Use optional interface

Signed-off-by: Jan-Otto Kröpke <[email protected]>

---------

Signed-off-by: Jan-Otto Kröpke <[email protected]>
  • Loading branch information
jkroepke authored Jan 9, 2024
1 parent 5dcf6de commit 984e31a
Showing 1 changed file with 46 additions and 16 deletions.
62 changes: 46 additions & 16 deletions pkg/client/rp/relying_party.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,28 @@ type RelyingParty interface {

// IDTokenVerifier returns the verifier used for oidc id_token verification
IDTokenVerifier() *IDTokenVerifier
// ErrorHandler returns the handler used for callback errors

// ErrorHandler returns the handler used for callback errors
ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)

// Logger from the context, or a fallback if set.
Logger(context.Context) (logger *slog.Logger, ok bool)
}

type HasUnauthorizedHandler interface {
// UnauthorizedHandler returns the handler used for unauthorized errors
UnauthorizedHandler() func(w http.ResponseWriter, r *http.Request, desc string, state string)
}

type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string)
type UnauthorizedHandler func(w http.ResponseWriter, r *http.Request, desc string, state string)

var DefaultErrorHandler ErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) {
http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError)
}
var DefaultUnauthorizedHandler UnauthorizedHandler = func(w http.ResponseWriter, r *http.Request, desc string, state string) {
http.Error(w, desc, http.StatusUnauthorized)
}

type relyingParty struct {
issuer string
Expand All @@ -91,11 +100,12 @@ type relyingParty struct {
httpClient *http.Client
cookieHandler *httphelper.CookieHandler

errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
idTokenVerifier *IDTokenVerifier
verifierOpts []VerifierOption
signer jose.Signer
logger *slog.Logger
errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
unauthorizedHandler func(http.ResponseWriter, *http.Request, string, string)
idTokenVerifier *IDTokenVerifier
verifierOpts []VerifierOption
signer jose.Signer
logger *slog.Logger
}

func (rp *relyingParty) OAuthConfig() *oauth2.Config {
Expand Down Expand Up @@ -156,6 +166,10 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request,
return rp.errorHandler
}

func (rp *relyingParty) UnauthorizedHandler() func(http.ResponseWriter, *http.Request, string, string) {
return rp.unauthorizedHandler
}

func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) {
logger, ok = logging.FromContext(ctx)
if ok {
Expand All @@ -169,9 +183,10 @@ func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok boo
// it will use the AuthURL and TokenURL set in config
func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingParty, error) {
rp := &relyingParty{
oauthConfig: config,
httpClient: httphelper.DefaultHTTPClient,
oauth2Only: true,
oauthConfig: config,
httpClient: httphelper.DefaultHTTPClient,
oauth2Only: true,
unauthorizedHandler: DefaultUnauthorizedHandler,
}

for _, optFunc := range options {
Expand Down Expand Up @@ -268,6 +283,13 @@ func WithErrorHandler(errorHandler ErrorHandler) Option {
}
}

func WithUnauthorizedHandler(unauthorizedHandler UnauthorizedHandler) Option {
return func(rp *relyingParty) error {
rp.unauthorizedHandler = unauthorizedHandler
return nil
}
}

func WithVerifierOpts(opts ...VerifierOption) Option {
return func(rp *relyingParty) error {
rp.verifierOpts = opts
Expand Down Expand Up @@ -355,13 +377,13 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParam

state := stateFn()
if err := trySetStateCookie(w, state, rp); err != nil {
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
unauthorizedError(w, r, "failed to create state cookie: "+err.Error(), state, rp)
return
}
if rp.IsPKCE() {
codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp)
if err != nil {
http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized)
unauthorizedError(w, r, "failed to create code challenge: "+err.Error(), state, rp)
return
}
opts = append(opts, WithCodeChallenge(codeChallenge))
Expand Down Expand Up @@ -448,7 +470,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
return func(w http.ResponseWriter, r *http.Request) {
state, err := tryReadStateCookie(w, r, rp)
if err != nil {
http.Error(w, "failed to get state: "+err.Error(), http.StatusUnauthorized)
unauthorizedError(w, r, "failed to get state: "+err.Error(), state, rp)
return
}
params := r.URL.Query()
Expand All @@ -464,7 +486,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
if rp.IsPKCE() {
codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode)
if err != nil {
http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized)
unauthorizedError(w, r, "failed to get code verifier: "+err.Error(), state, rp)
return
}
codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier))
Expand All @@ -473,14 +495,14 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
if rp.Signer() != nil {
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, rp.Signer())
if err != nil {
http.Error(w, "failed to build assertion: "+err.Error(), http.StatusUnauthorized)
unauthorizedError(w, r, "failed to build assertion: "+err.Error(), state, rp)
return
}
codeOpts = append(codeOpts, WithClientAssertionJWT(assertion))
}
tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...)
if err != nil {
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
unauthorizedError(w, r, "failed to exchange token: "+err.Error(), state, rp)
return
}
callback(w, r, tokens, state, rp)
Expand All @@ -500,7 +522,7 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
if err != nil {
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
unauthorizedError(w, r, "userinfo failed: "+err.Error(), state, rp)
return
}
f(w, r, tokens, state, rp, info)
Expand Down Expand Up @@ -727,3 +749,11 @@ func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHi
}
return ErrRelyingPartyNotSupportRevokeCaller
}

func unauthorizedError(w http.ResponseWriter, r *http.Request, desc string, state string, rp RelyingParty) {
if rp, ok := rp.(HasUnauthorizedHandler); ok {
rp.UnauthorizedHandler()(w, r, desc, state)
return
}
http.Error(w, desc, http.StatusUnauthorized)
}

0 comments on commit 984e31a

Please sign in to comment.