Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rp): Add UnauthorizedHandler #503

Merged
merged 5 commits into from
Jan 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 46 additions & 16 deletions pkg/client/rp/relying_party.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,28 @@

// IDTokenVerifier returns the verifier used for oidc id_token verification
IDTokenVerifier() *IDTokenVerifier
// ErrorHandler returns the handler used for callback errors

// ErrorHandler returns the handler used for callback errors
ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)

// Logger from the context, or a fallback if set.
Logger(context.Context) (logger *slog.Logger, ok bool)
}

type HasUnauthorizedHandler interface {
// UnauthorizedHandler returns the handler used for unauthorized errors
UnauthorizedHandler() func(w http.ResponseWriter, r *http.Request, desc string, state string)
}

type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string)
type UnauthorizedHandler func(w http.ResponseWriter, r *http.Request, desc string, state string)

var DefaultErrorHandler ErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) {
http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError)
}
var DefaultUnauthorizedHandler UnauthorizedHandler = func(w http.ResponseWriter, r *http.Request, desc string, state string) {
http.Error(w, desc, http.StatusUnauthorized)

Check warning on line 89 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L88-L89

Added lines #L88 - L89 were not covered by tests
}

type relyingParty struct {
issuer string
Expand All @@ -91,11 +100,12 @@
httpClient *http.Client
cookieHandler *httphelper.CookieHandler

errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
idTokenVerifier *IDTokenVerifier
verifierOpts []VerifierOption
signer jose.Signer
logger *slog.Logger
errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
unauthorizedHandler func(http.ResponseWriter, *http.Request, string, string)
idTokenVerifier *IDTokenVerifier
verifierOpts []VerifierOption
signer jose.Signer
logger *slog.Logger
}

func (rp *relyingParty) OAuthConfig() *oauth2.Config {
Expand Down Expand Up @@ -156,6 +166,10 @@
return rp.errorHandler
}

func (rp *relyingParty) UnauthorizedHandler() func(http.ResponseWriter, *http.Request, string, string) {
return rp.unauthorizedHandler

Check warning on line 170 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L169-L170

Added lines #L169 - L170 were not covered by tests
}

func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) {
logger, ok = logging.FromContext(ctx)
if ok {
Expand All @@ -169,9 +183,10 @@
// it will use the AuthURL and TokenURL set in config
func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingParty, error) {
rp := &relyingParty{
oauthConfig: config,
httpClient: httphelper.DefaultHTTPClient,
oauth2Only: true,
oauthConfig: config,
httpClient: httphelper.DefaultHTTPClient,
oauth2Only: true,
unauthorizedHandler: DefaultUnauthorizedHandler,

Check warning on line 189 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L186-L189

Added lines #L186 - L189 were not covered by tests
}

for _, optFunc := range options {
Expand Down Expand Up @@ -268,6 +283,13 @@
}
}

func WithUnauthorizedHandler(unauthorizedHandler UnauthorizedHandler) Option {
return func(rp *relyingParty) error {
rp.unauthorizedHandler = unauthorizedHandler
return nil
}

Check warning on line 290 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L286-L290

Added lines #L286 - L290 were not covered by tests
}

func WithVerifierOpts(opts ...VerifierOption) Option {
return func(rp *relyingParty) error {
rp.verifierOpts = opts
Expand Down Expand Up @@ -355,13 +377,13 @@

state := stateFn()
if err := trySetStateCookie(w, state, rp); err != nil {
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
unauthorizedError(w, r, "failed to create state cookie: "+err.Error(), state, rp)

Check warning on line 380 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L380

Added line #L380 was not covered by tests
return
}
if rp.IsPKCE() {
codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp)
if err != nil {
http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized)
unauthorizedError(w, r, "failed to create code challenge: "+err.Error(), state, rp)

Check warning on line 386 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L386

Added line #L386 was not covered by tests
return
}
opts = append(opts, WithCodeChallenge(codeChallenge))
Expand Down Expand Up @@ -448,7 +470,7 @@
return func(w http.ResponseWriter, r *http.Request) {
state, err := tryReadStateCookie(w, r, rp)
if err != nil {
http.Error(w, "failed to get state: "+err.Error(), http.StatusUnauthorized)
unauthorizedError(w, r, "failed to get state: "+err.Error(), state, rp)

Check warning on line 473 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L473

Added line #L473 was not covered by tests
return
}
params := r.URL.Query()
Expand All @@ -464,7 +486,7 @@
if rp.IsPKCE() {
codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode)
if err != nil {
http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized)
unauthorizedError(w, r, "failed to get code verifier: "+err.Error(), state, rp)

Check warning on line 489 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L489

Added line #L489 was not covered by tests
return
}
codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier))
Expand All @@ -473,14 +495,14 @@
if rp.Signer() != nil {
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, rp.Signer())
if err != nil {
http.Error(w, "failed to build assertion: "+err.Error(), http.StatusUnauthorized)
unauthorizedError(w, r, "failed to build assertion: "+err.Error(), state, rp)

Check warning on line 498 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L498

Added line #L498 was not covered by tests
return
}
codeOpts = append(codeOpts, WithClientAssertionJWT(assertion))
}
tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...)
if err != nil {
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
unauthorizedError(w, r, "failed to exchange token: "+err.Error(), state, rp)

Check warning on line 505 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L505

Added line #L505 was not covered by tests
return
}
callback(w, r, tokens, state, rp)
Expand All @@ -500,7 +522,7 @@
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
if err != nil {
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
unauthorizedError(w, r, "userinfo failed: "+err.Error(), state, rp)

Check warning on line 525 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L525

Added line #L525 was not covered by tests
return
}
f(w, r, tokens, state, rp, info)
Expand Down Expand Up @@ -727,3 +749,11 @@
}
return ErrRelyingPartyNotSupportRevokeCaller
}

func unauthorizedError(w http.ResponseWriter, r *http.Request, desc string, state string, rp RelyingParty) {
if rp, ok := rp.(HasUnauthorizedHandler); ok {
rp.UnauthorizedHandler()(w, r, desc, state)
return
}
http.Error(w, desc, http.StatusUnauthorized)

Check warning on line 758 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L753-L758

Added lines #L753 - L758 were not covered by tests
}
Loading