From 7ecfa2e284ee7c66e0372ee2a83fb029d71e4b0a Mon Sep 17 00:00:00 2001 From: Umputun Date: Mon, 9 Dec 2024 14:54:17 -0600 Subject: [PATCH] add BasicAuthWithBcryptHashAndPrompt middleware --- basic_auth.go | 27 +++++++++++++++ basic_auth_test.go | 84 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) diff --git a/basic_auth.go b/basic_auth.go index 4c4aa99..6e7b6cd 100644 --- a/basic_auth.go +++ b/basic_auth.go @@ -113,6 +113,33 @@ func BasicAuthWithPrompt(user, passwd string) func(http.Handler) http.Handler { } } +// BasicAuthWithBcryptHashAndPrompt middleware requires basic auth and matches user & bcrypt hashed password +// If the user is not authorized, it will prompt for basic auth +func BasicAuthWithBcryptHashAndPrompt(user, hashedPassword string) func(http.Handler) http.Handler { + checkFn := func(reqUser, reqPasswd string) bool { + if reqUser != user { + return false + } + err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(reqPasswd)) + return err == nil + } + + return func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + // extract basic auth from request + u, p, ok := r.BasicAuth() + if ok && checkFn(u, p) { + h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey(baContextKey), true))) + return + } + // not authorized, prompt for basic auth + w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + } + return http.HandlerFunc(fn) + } +} + // GenerateBcryptHash generates a bcrypt hash from a password func GenerateBcryptHash(password string) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) diff --git a/basic_auth_test.go b/basic_auth_test.go index 58b699c..c7eb259 100644 --- a/basic_auth_test.go +++ b/basic_auth_test.go @@ -360,3 +360,87 @@ func TestArgon2InvalidInputs(t *testing.T) { assert.Equal(t, http.StatusForbidden, resp.StatusCode) }) } + +func TestBasicAuthWithBcryptHashAndPrompt(t *testing.T) { + hashedPassword, err := bcrypt.GenerateFromPassword([]byte("good"), bcrypt.MinCost) + require.NoError(t, err) + t.Logf("hashed password: %s", string(hashedPassword)) + + mw := BasicAuthWithBcryptHashAndPrompt("dev", string(hashedPassword)) + + ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("request %s", r.URL) + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("blah")) + require.NoError(t, err) + assert.True(t, IsAuthorized(r.Context())) + }))) + defer ts.Close() + + u := fmt.Sprintf("%s%s", ts.URL, "/something") + client := http.Client{Timeout: 5 * time.Second} + + tests := []struct { + name string + username string + password string + expectedStatus int + checkPrompt bool + }{ + { + name: "no auth provided", + username: "", + password: "", + expectedStatus: http.StatusUnauthorized, + checkPrompt: true, + }, + { + name: "correct credentials", + username: "dev", + password: "good", + expectedStatus: http.StatusOK, + checkPrompt: false, + }, + { + name: "wrong username", + username: "wrong", + password: "good", + expectedStatus: http.StatusUnauthorized, + checkPrompt: true, + }, + { + name: "wrong password", + username: "dev", + password: "bad", + expectedStatus: http.StatusUnauthorized, + checkPrompt: true, + }, + { + name: "empty password", + username: "dev", + password: "", + expectedStatus: http.StatusUnauthorized, + checkPrompt: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest("GET", u, http.NoBody) + require.NoError(t, err) + + if tc.username != "" || tc.password != "" { + req.SetBasicAuth(tc.username, tc.password) + } + + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, tc.expectedStatus, resp.StatusCode) + + if tc.checkPrompt { + assert.Equal(t, `Basic realm="restricted", charset="UTF-8"`, resp.Header.Get("WWW-Authenticate"), + "should include WWW-Authenticate header") + } + }) + } +}