Skip to content

Commit

Permalink
chore: rework auth.* keys, add ctxstore package
Browse files Browse the repository at this point in the history
Using so-called phantom types we can use the types themselves as keys directly without loosing performance.
You no longer need to remember which type was attached to the thing you passed in context and can look up
all fields access directly.

Part of siderolabs#37

Signed-off-by: Dmitriy Matrenichev <[email protected]>
  • Loading branch information
DmitriyMV committed Jul 15, 2024
1 parent 76263e1 commit 4cfc0e6
Show file tree
Hide file tree
Showing 27 changed files with 315 additions and 148 deletions.
3 changes: 2 additions & 1 deletion cmd/omni/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/auth/user"
"github.com/siderolabs/omni/internal/pkg/config"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
"github.com/siderolabs/omni/internal/pkg/features"
"github.com/siderolabs/omni/internal/pkg/siderolink"
"github.com/siderolabs/omni/internal/version"
Expand Down Expand Up @@ -235,7 +236,7 @@ func runWithState(logger *zap.Logger) func(context.Context, state.State, *virtua
return fmt.Errorf("failed to update features config resources: %w", err)
}

ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, authres.Enabled(authConfig))
ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: authres.Enabled(authConfig)})

handler, err := backend.NewFrontendHandler(rootCmdArgs.frontendDst, logger)
if err != nil {
Expand Down
7 changes: 4 additions & 3 deletions internal/backend/grpc/configs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/config"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)

//go:embed testdata/admin-kubeconfig.yaml
Expand Down Expand Up @@ -193,11 +194,11 @@ func runServer(t *testing.T, st state.State, opts ...grpc.ServerOption) string {
md = metadata.New(nil)
}

ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, true)
ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: true})

msg := message.NewGRPC(md, info.FullMethod)

ctx = context.WithValue(ctx, auth.GRPCMessageContextKey{}, msg)
ctx = ctxstore.WithValue(ctx, auth.GRPCMessageContextKey{Message: msg})

if r := md.Get("role"); len(r) > 0 {
var parsed role.Role
Expand All @@ -207,7 +208,7 @@ func runServer(t *testing.T, st state.State, opts ...grpc.ServerOption) string {
return nil, err
}

ctx = context.WithValue(ctx, auth.RoleContextKey{}, parsed)
ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: parsed})
}

return handler(ctx, req)
Expand Down
10 changes: 6 additions & 4 deletions internal/backend/grpc/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/config"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
"github.com/siderolabs/omni/internal/pkg/siderolink"
)

Expand Down Expand Up @@ -939,9 +940,10 @@ func (s *managementServer) applyClusterAccessPolicy(ctx context.Context, cluster
return nil, err
}

userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role)
if !userRoleExists {
userRole = role.None
userRole := role.None

if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok {
userRole = val.Role
}

newRole, err := role.Max(userRole, clusterRole)
Expand All @@ -953,7 +955,7 @@ func (s *managementServer) applyClusterAccessPolicy(ctx context.Context, cluster
return ctx, nil
}

return context.WithValue(ctx, auth.RoleContextKey{}, newRole), nil
return ctxstore.WithValue(ctx, auth.RoleContextKey{Role: newRole}), nil
}

func handleError(err error) error {
Expand Down
6 changes: 3 additions & 3 deletions internal/backend/grpc/router/talos_backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/siderolabs/omni/internal/backend/dns"
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
"github.com/siderolabs/omni/internal/pkg/grpcutil"
)

Expand Down Expand Up @@ -78,9 +79,8 @@ func (backend *TalosBackend) GetConnection(ctx context.Context, fullMethodName s
// we can't use regular gRPC server interceptors here, as proxy interface is a bit different

// prepare context values for the verifier
ctx = context.WithValue(ctx, auth.EnabledAuthContextKey{}, backend.authEnabled)
msg := message.NewGRPC(md, fullMethodName)
ctx = context.WithValue(ctx, auth.GRPCMessageContextKey{}, msg)
ctx = ctxstore.WithValue(ctx, auth.EnabledAuthContextKey{Enabled: backend.authEnabled})
ctx = ctxstore.WithValue(ctx, auth.GRPCMessageContextKey{Message: message.NewGRPC(md, fullMethodName)})

grpcutil.SetShouldLog(ctx, "talos-backend")

Expand Down
2 changes: 1 addition & 1 deletion internal/backend/k8sproxy/k8sproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

// clusterContextKey is a type for cluster name.
type clusterContextKey struct{}
type clusterContextKey struct{ ClusterName string }

// Handler implements the HTTP reverse proxy for Kubernetes clusters.
//
Expand Down
4 changes: 3 additions & 1 deletion internal/backend/k8sproxy/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"go.uber.org/zap"
"k8s.io/client-go/transport"

"github.com/siderolabs/omni/internal/pkg/ctxstore"
)

const authorizationHeader = "Authorization"
Expand Down Expand Up @@ -107,7 +109,7 @@ func AuthorizeRequest(next http.Handler, keyFunc KeyProvider, clusterUUIDResolve
}

// clone the request before modifying it
req = req.WithContext(context.WithValue(ctx, clusterContextKey{}, clusterName))
req = req.WithContext(ctxstore.WithValue(ctx, clusterContextKey{ClusterName: clusterName}))

// clean all headers which are going to be overridden
req.Header.Del(authorizationHeader)
Expand Down
5 changes: 3 additions & 2 deletions internal/backend/k8sproxy/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"k8s.io/client-go/transport"

"github.com/siderolabs/omni/internal/backend/k8sproxy"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)

var mockClusterUUIDResolver = func(_ context.Context, clusterID resource.ID) (string, error) {
Expand Down Expand Up @@ -276,9 +277,9 @@ func TestAuthorize(t *testing.T) {
assert.Equal(t, tc.expectedImpersonateGroups, receivedReq.Header.Values(transport.ImpersonateGroupHeader))
assert.Nil(t, receivedReq.Header.Values("Authorization"))

v, ok := receivedReq.Context().Value(k8sproxy.ClusterContextKey{}).(string)
v, ok := ctxstore.Value[k8sproxy.ClusterContextKey](receivedReq.Context()) //nolint:contextcheck
assert.True(t, ok)
assert.Equal(t, tc.expectedCluster, v)
assert.Equal(t, tc.expectedCluster, v.ClusterName)
})
}
}
5 changes: 3 additions & 2 deletions internal/backend/k8sproxy/multiplex.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/siderolabs/omni/client/api/common"
"github.com/siderolabs/omni/internal/backend/runtime"
"github.com/siderolabs/omni/internal/backend/runtime/kubernetes"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)

// multiplexer provides an http.RoundTripper which selects the cluster based on the request context.
Expand Down Expand Up @@ -74,12 +75,12 @@ func newMultiplexer() *multiplexer {

// RoundTrip implements http.RoundTripper interface.
func (m *multiplexer) RoundTrip(req *http.Request) (*http.Response, error) {
clusterName, ok := req.Context().Value(clusterContextKey{}).(string)
clusterNameVal, ok := ctxstore.Value[clusterContextKey](req.Context())
if !ok {
return nil, errors.New("cluster name not found in request context")
}

rt, err := m.getRT(req.Context(), clusterName)
rt, err := m.getRT(req.Context(), clusterNameVal.ClusterName)
if err != nil {
return nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions internal/backend/k8sproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"go.uber.org/zap"

"github.com/siderolabs/omni/internal/backend/logging"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)

// proxyHandler implements the HTTP reverse proxy.
Expand Down Expand Up @@ -44,14 +45,14 @@ func newProxyHandler(m *multiplexer, logger *zap.Logger) *proxyHandler {

// director sets the target URL for the reverse proxy.
func (p *proxyHandler) director(req *http.Request) {
clusterName, ok := req.Context().Value(clusterContextKey{}).(string)
clusterNameVal, ok := ctxstore.Value[clusterContextKey](req.Context())
if !ok {
ctxzap.Error(req.Context(), "cluster name not found in request context")

return
}

connector, err := p.multiplexer.getClusterConnector(req.Context(), clusterName)
connector, err := p.multiplexer.getClusterConnector(req.Context(), clusterNameVal.ClusterName)
if err != nil {
ctxzap.Error(req.Context(), "failed to get cluster connector", zap.Error(err))

Expand Down
3 changes: 2 additions & 1 deletion internal/backend/runtime/omni/omni_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/siderolabs/omni/internal/backend/workloadproxy"
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)

// using whitelisted for external API access type.
Expand Down Expand Up @@ -72,7 +73,7 @@ func (suite *OmniRuntimeSuite) SetupTest() {
suite.ctx, suite.ctxCancel = context.WithTimeout(context.Background(), 3*time.Minute)

// disable auth in the context
suite.ctx = context.WithValue(suite.ctx, auth.EnabledAuthContextKey{}, false)
suite.ctx = ctxstore.WithValue(suite.ctx, auth.EnabledAuthContextKey{Enabled: false})

var err error

Expand Down
3 changes: 2 additions & 1 deletion internal/backend/runtime/omni/state_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/config"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)

var (
Expand Down Expand Up @@ -257,7 +258,7 @@ func checkForRole(ctx context.Context, st state.State, access state.Access, clus

if clusterRole != role.None && (!requireAll || (requireAll && matchesAll)) {
// override the role in the context with the computed role for this cluster
ctx = context.WithValue(ctx, auth.RoleContextKey{}, clusterRole)
ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: clusterRole})
}
}

Expand Down
19 changes: 11 additions & 8 deletions internal/backend/runtime/omni/virtual/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/accesspolicy"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)

// State is a virtual state implementation which provides virtual resources.
Expand Down Expand Up @@ -161,16 +162,17 @@ func (v *State) validateKind(kind resource.Kind) error {
}

func (v *State) currentUser(ctx context.Context) (*virtual.CurrentUser, error) {
identity, _ := ctx.Value(auth.IdentityContextKey{}).(string) //nolint:errcheck
identityVal, _ := ctxstore.Value[auth.IdentityContextKey](ctx)

userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role)
if !userRoleExists {
userRole = role.None
userRole := role.None

if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok {
userRole = val.Role
}

user := virtual.NewCurrentUser()

user.TypedSpec().Value.Identity = identity
user.TypedSpec().Value.Identity = identityVal.Identity
user.TypedSpec().Value.Role = string(userRole)

version, err := resource.ParseVersion("1")
Expand All @@ -184,9 +186,10 @@ func (v *State) currentUser(ctx context.Context) (*virtual.CurrentUser, error) {
}

func (v *State) permissions(ctx context.Context) (*virtual.Permissions, error) {
userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role)
if !userRoleExists {
userRole = role.None
userRole := role.None

if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok {
userRole = val.Role
}

permissions := virtual.NewPermissions()
Expand Down
5 changes: 3 additions & 2 deletions internal/backend/workloadproxy/accessvalidator.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth/accesspolicy"
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)

// RoleProvider provides the current actor's role for a cluster.
Expand Down Expand Up @@ -119,10 +120,10 @@ func (p *PGPAccessValidator) ValidateAccess(ctx context.Context, publicKeyID, pu
return parseErr
}

ctx = context.WithValue(ctx, auth.RoleContextKey{}, publicKeyRole)
ctx = ctxstore.WithValue(ctx, auth.RoleContextKey{Role: publicKeyRole})
}

ctx = context.WithValue(ctx, auth.IdentityContextKey{}, publicKey.TypedSpec().Value.GetIdentity().GetEmail())
ctx = ctxstore.WithValue(ctx, auth.IdentityContextKey{Identity: publicKey.TypedSpec().Value.GetIdentity().GetEmail()})

accessRole, err := p.roleProvider.RoleForCluster(ctx, clusterID)
if err != nil {
Expand Down
12 changes: 7 additions & 5 deletions internal/pkg/auth/accesspolicy/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ import (
"github.com/siderolabs/omni/internal/pkg/auth"
"github.com/siderolabs/omni/internal/pkg/auth/actor"
"github.com/siderolabs/omni/internal/pkg/auth/role"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)

// RoleForCluster returns the role of the current user for the given cluster, and whether the role matches all clusters.
func RoleForCluster(ctx context.Context, id resource.ID, st state.State) (role.Role, bool, error) {
userRole, userRoleExists := ctx.Value(auth.RoleContextKey{}).(role.Role)
if !userRoleExists {
userRole = role.None
userRole := role.None

if val, ok := ctxstore.Value[auth.RoleContextKey](ctx); ok {
userRole = val.Role
}

ctx = actor.MarkContextAsInternalActor(ctx)
Expand All @@ -38,12 +40,12 @@ func RoleForCluster(ctx context.Context, id resource.ID, st state.State) (role.R
return role.None, false, err
}

identityStr, identityExists := ctx.Value(auth.IdentityContextKey{}).(string)
identityVal, identityExists := ctxstore.Value[auth.IdentityContextKey](ctx)
if !identityExists {
return userRole, false, nil
}

identity, err := safe.StateGet[*authres.Identity](ctx, st, authres.NewIdentity(resources.DefaultNamespace, identityStr).Metadata())
identity, err := safe.StateGet[*authres.Identity](ctx, st, authres.NewIdentity(resources.DefaultNamespace, identityVal.Identity).Metadata())
if err != nil {
if state.IsNotFoundError(err) {
return userRole, false, nil
Expand Down
12 changes: 9 additions & 3 deletions internal/pkg/auth/actor/actor.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,23 @@
// Package actor implements the context marking for internal/external actors.
package actor

import "context"
import (
"context"

"github.com/siderolabs/omni/internal/pkg/ctxstore"
)

// internalActorContextKey is the key for internal actor context.
type internalActorContextKey struct{}

// MarkContextAsInternalActor returns a new derived context from the given context, marked as an internal actor.
func MarkContextAsInternalActor(ctx context.Context) context.Context {
return context.WithValue(ctx, internalActorContextKey{}, struct{}{})
return ctxstore.WithValue(ctx, internalActorContextKey{})
}

// ContextIsInternalActor returns true if the given context is marked as an internal actor.
func ContextIsInternalActor(ctx context.Context) bool {
return ctx.Value(internalActorContextKey{}) != nil
_, ok := ctxstore.Value[internalActorContextKey](ctx)

return ok
}
Loading

0 comments on commit 4cfc0e6

Please sign in to comment.