diff --git a/iamruntime/authorization.go b/iamruntime/authorization.go index 3c84b6c..2272512 100644 --- a/iamruntime/authorization.go +++ b/iamruntime/authorization.go @@ -38,6 +38,25 @@ func ContextCheckAccess(ctx context.Context, actions []*authorization.AccessRequ return nil } +// ContextCheckAccessTo builds a check access request and executes it on the runtime in the provided context. +// Arguments must be pairs of Resource ID and Role Actions. +func ContextCheckAccessTo(ctx context.Context, resourceIDActionPairs ...string) error { + if len(resourceIDActionPairs)%2 != 0 { + return fmt.Errorf("%w: invalid argument count", ErrResourceIDActionPairsInvalid) + } + + var checkActions []*authorization.AccessRequestAction + + for i := 0; i < len(resourceIDActionPairs); i += 2 { + checkActions = append(checkActions, &authorization.AccessRequestAction{ + ResourceId: resourceIDActionPairs[i], + Action: resourceIDActionPairs[i+1], + }) + } + + return ContextCheckAccess(ctx, checkActions) +} + // ContextCreateRelationships executes a create relationship request on the runtime in the context. // Context must have a runtime value. // The runtime must implement the iam-runtime's AuthorizationClient. diff --git a/iamruntime/authorization_test.go b/iamruntime/authorization_test.go index ace8236..be4c81d 100644 --- a/iamruntime/authorization_test.go +++ b/iamruntime/authorization_test.go @@ -107,6 +107,79 @@ func TestContextCheckAccess(t *testing.T) { } } +func TestContextCheckAccessTo(t *testing.T) { + authsrv := testauth.NewServer(t) + t.Cleanup(authsrv.Stop) + + testCases := []struct { + name string + actions []string + returnAccessResult authorization.CheckAccessResponse_Result + returnAccessError error + expectCalled map[string][]string + expectError error + }{ + { + "permitted", + []string{ + "testten-abc123", "action_one", + "testten-abc123", "action_two", + "testten-def456", "action_one", + }, + authorization.CheckAccessResponse_RESULT_ALLOWED, + nil, + map[string][]string{ + "testten-abc123": {"action_one", "action_two"}, + "testten-def456": {"action_one"}, + }, + nil, + }, + { + "denied", + []string{"testten-abc123", "action_one"}, + authorization.CheckAccessResponse_RESULT_DENIED, + nil, + map[string][]string{"testten-abc123": {"action_one"}}, + ErrAccessDenied, + }, + { + "error", + []string{"testten-abc123", "action_one"}, + 0, + grpc.ErrServerStopped, + map[string][]string{"testten-abc123": {"action_one"}}, + ErrAccessCheckFailed, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + runtime := new(mockruntime.MockRuntime) + + runtime.Mock.On("CheckAccess", tc.expectCalled).Return(tc.returnAccessResult, tc.returnAccessError) + + token, _, err := jwt.NewParser().ParseUnverified(authsrv.TSignSubject(t, "some subject"), jwt.MapClaims{}) + require.NoError(t, err, "unexpected error creating jwt") + + ctx := context.Background() + + ctx = SetContextRuntime(ctx, runtime) + ctx = SetContextToken(ctx, token) + + err = ContextCheckAccessTo(ctx, tc.actions...) + + if tc.expectError != nil { + require.Error(t, err, "expected error to be returned") + assert.ErrorIs(t, err, tc.expectError, "unexpected error returned") + } else { + assert.NoError(t, err, "expected no error to be returned") + } + + runtime.Mock.AssertExpectations(t) + }) + } +} + func TestContextCreateRelationships(t *testing.T) { testCases := []struct { name string @@ -260,6 +333,19 @@ func ExampleContextCheckAccess() { fmt.Println("Token has access to resource!") } +func ExampleContextCheckAccessTo() { + runtime, _ := NewClient("unix:///tmp/runtime.sock") + + ctx := SetContextRuntime(context.TODO(), runtime) + ctx = SetContextToken(ctx, &jwt.Token{Raw: "some token"}) + + if err := ContextCheckAccessTo(ctx, "resctyp-abc123", "resource_get"); err != nil { + panic("failed to check access: " + err.Error()) + } + + fmt.Println("Token has access to resource!") +} + // StorageResource is used in examples. type StorageResource struct { ID string diff --git a/iamruntime/errors.go b/iamruntime/errors.go index 667936d..e7cd913 100644 --- a/iamruntime/errors.go +++ b/iamruntime/errors.go @@ -27,6 +27,9 @@ var ( // AccessError is the root error for all access related errors. AccessError = fmt.Errorf("%w: access", Error) //nolint:revive,stylecheck // not returned directly, but used as a root error. + // ErrResourceIDActionPairsInvalid is returned when ContextCheckAccessTo has an invalid number of arguments. + ErrResourceIDActionPairsInvalid = fmt.Errorf("%w: ContextCheckAccessTo invalid Resource ID, Action argument pairs", AccessError) + // ErrAccessCheckFailed is the error returned when an access request failed to execute. ErrAccessCheckFailed = fmt.Errorf("%w: failed to check access", AccessError) diff --git a/middleware/echo/iamruntimemiddleware/authorization.go b/middleware/echo/iamruntimemiddleware/authorization.go index fe85bf2..91f9b23 100644 --- a/middleware/echo/iamruntimemiddleware/authorization.go +++ b/middleware/echo/iamruntimemiddleware/authorization.go @@ -23,6 +23,27 @@ func setRuntimeContext(r Runtime, c echo.Context) error { // If any error is returned, the error is converted to an echo error with a proper status code. func CheckAccess(c echo.Context, actions []*authorization.AccessRequestAction, opts ...grpc.CallOption) error { if err := iamruntime.ContextCheckAccess(c.Request().Context(), actions, opts...); err != nil { + switch { + case errors.Is(err, iamruntime.ErrTokenNotFound): + return echo.ErrBadRequest.WithInternal(err) + case errors.Is(err, iamruntime.ErrRuntimeNotFound), + errors.Is(err, iamruntime.ErrAccessCheckFailed), + errors.Is(err, iamruntime.ErrResourceIDActionPairsInvalid): + return echo.ErrInternalServerError.WithInternal(err) + case errors.Is(err, iamruntime.ErrAccessDenied): + return echo.ErrForbidden.WithInternal(err) + default: + return echo.ErrInternalServerError.WithInternal(fmt.Errorf("unknown error: %w", err)) + } + } + + return nil +} + +// CheckAccessTo builds a check access request and executes it on the runtime in the provided context. +// Arguments must be pairs of Resource ID and Role Actions. +func CheckAccessTo(c echo.Context, resourceIDActionPairs ...string) error { + if err := iamruntime.ContextCheckAccessTo(c.Request().Context(), resourceIDActionPairs...); err != nil { switch { case errors.Is(err, iamruntime.ErrTokenNotFound): return echo.ErrBadRequest.WithInternal(err) diff --git a/middleware/echo/iamruntimemiddleware/authorization_test.go b/middleware/echo/iamruntimemiddleware/authorization_test.go index 705e39f..44f55d9 100644 --- a/middleware/echo/iamruntimemiddleware/authorization_test.go +++ b/middleware/echo/iamruntimemiddleware/authorization_test.go @@ -151,6 +151,119 @@ func TestCheckAccess(t *testing.T) { } } +func TestCheckAccessTo(t *testing.T) { + authsrv := testauth.NewServer(t) + t.Cleanup(authsrv.Stop) + + testCases := []struct { + name string + actions []string + returnAccessResult authorization.CheckAccessResponse_Result + returnAccessError error + expectCalled map[string][]string + expectStatus int + expectBody map[string]any + }{ + { + "permitted", + []string{ + "testten-abc123", "action_one", + "testten-abc123", "action_two", + "testten-def456", "action_one", + }, + authorization.CheckAccessResponse_RESULT_ALLOWED, + nil, + map[string][]string{ + "testten-abc123": {"action_one", "action_two"}, + "testten-def456": {"action_one"}, + }, + http.StatusOK, + map[string]any{ + "success": true, + }, + }, + { + "denied", + []string{"testten-abc123", "action_one"}, + authorization.CheckAccessResponse_RESULT_DENIED, + nil, + map[string][]string{"testten-abc123": {"action_one"}}, + http.StatusForbidden, + map[string]any{ + "message": "Forbidden", + "error": "code=403, message=Forbidden, internal=iam-runtime error: access: denied", + }, + }, + { + "error", + []string{"testten-abc123", "action_one"}, + 0, + grpc.ErrServerStopped, + map[string][]string{"testten-abc123": {"action_one"}}, + http.StatusInternalServerError, + map[string]any{ + "message": "Internal Server Error", + "error": "code=500, message=Internal Server Error, internal=iam-runtime error: access: failed to check access: grpc: the server has been stopped", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + runtime := new(mockruntime.MockRuntime) + + runtime.Mock.On("ValidateCredential", "some subject").Return(&authentication.ValidateCredentialResponse{ + Result: authentication.ValidateCredentialResponse_RESULT_VALID, + }, nil) + + runtime.Mock.On("CheckAccess", tc.expectCalled).Return(tc.returnAccessResult, tc.returnAccessError) + + config := NewConfig().WithRuntime(runtime) + + middleware, err := config.ToMiddleware() + require.NoError(t, err, "unexpected error building middleware") + + engine := echo.New() + + engine.Debug = true + + engine.Use(middleware) + + engine.GET("/test", func(c echo.Context) error { + if err := CheckAccessTo(c, tc.actions...); err != nil { + return err + } + + return c.JSON(http.StatusOK, echo.Map{ + "success": true, + }) + }) + + ctx := context.Background() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test", nil) + require.NoError(t, err) + + req.Header.Add("Authorization", "Bearer "+authsrv.TSignSubject(t, "some subject")) + + resp := httptest.NewRecorder() + + engine.ServeHTTP(resp, req) + + runtime.Mock.AssertExpectations(t) + + assert.Equal(t, tc.expectStatus, resp.Code, "unexpected status code returned") + + var body map[string]any + + err = json.Unmarshal(resp.Body.Bytes(), &body) + require.NoError(t, err, "unexpected error decoding body") + + assert.Equal(t, tc.expectBody, body, "unexpected body returned") + }) + } +} + func TestCreateRelationships(t *testing.T) { testCases := []struct { name string @@ -339,6 +452,24 @@ func ExampleCheckAccess() { _ = http.ListenAndServe(":8080", engine) } +func ExampleCheckAccessTo() { + middleware, _ := NewConfig().ToMiddleware() + + engine := echo.New() + + engine.Use(middleware) + + engine.GET("/resources/:resource_id", func(c echo.Context) error { + if err := CheckAccessTo(c, c.Param("resource_id"), "resource_get"); err != nil { + return err + } + + return c.String(http.StatusOK, "user has access to resource") + }) + + _ = http.ListenAndServe(":8080", engine) +} + // StorageResource is used in examples. type StorageResource struct { ID string