Skip to content

Commit

Permalink
NIFI-13016 Add groups claim mapping from OIDC token for Registry (#9566)
Browse files Browse the repository at this point in the history
Signed-off-by: David Handermann <[email protected]>
  • Loading branch information
hazmat345 authored Jan 24, 2025
1 parent 5fd8cec commit 0c34095
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 18 deletions.
1 change: 1 addition & 0 deletions nifi-registry/nifi-registry-assembly/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@
<nifi.registry.security.user.oidc.client.id />
<nifi.registry.security.user.oidc.client.secret />
<nifi.registry.security.user.oidc.preferred.jwsalgorithm />
<nifi.registry.security.user.oidc.claim.groups>groups</nifi.registry.security.user.oidc.claim.groups>

<!-- nifi.registry.properties: revision management properties -->
<nifi.registry.revisions.enabled>false</nifi.registry.revisions.enabled>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ If this value is `none`, NiFi will attempt to validate unsecured/plain tokens. O
JSON Web Key (JWK) provided through the jwks_uri in the metadata found at the discovery URL
|`nifi.registry.security.user.oidc.additional.scopes` | Comma separated scopes that are sent to OpenID Connect Provider in addition to `openid` and `email`.
|`nifi.registry.security.user.oidc.claim.identifying.user` | Claim that identifies the authenticated user. The default value is `email`. Claim names may need to be requested using the `nifi.registry.security.user.oidc.additional.scopes` property
|`nifi.registry.security.user.oidc.claim.groups` | Name of the ID token claim that contains an array of group names of which the
user is a member. Application groups must be supplied from a User Group Provider with matching names in order for the
authorization process to use ID token claim groups. The default value is `groups`.
|==================================================================================================================================================

[[authorization]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
import java.io.IOException;
import java.io.StringWriter;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;

public class StandardManagedAuthorizer implements ManagedAuthorizer {

Expand Down Expand Up @@ -95,19 +98,29 @@ public AuthorizationResult authorize(AuthorizationRequest request) throws Author

final UserAndGroups userAndGroups = userGroupProvider.getUserAndGroups(request.getIdentity());

final User user = userAndGroups.getUser();
if (user == null) {
return AuthorizationResult.denied(String.format("Unknown user with identity '%s'.", request.getIdentity()));
}
// combine groups from incoming request with groups from UserAndGroups because the request may contain groups from
// an external identity provider and the membership may not be maintained within any of the UserGroupProviders
final Set<Group> userGroups = new HashSet<>();
userGroups.addAll(userAndGroups.getGroups() == null ? Collections.emptySet() : userAndGroups.getGroups());
userGroups.addAll(getGroups(request.getGroups()));

final Set<Group> userGroups = userAndGroups.getGroups();
if (policy.getUsers().contains(user.getIdentifier()) || containsGroup(userGroups, policy)) {
if (containsUser(userAndGroups.getUser(), policy) || containsGroup(userGroups, policy)) {
return AuthorizationResult.approved();
}

return AuthorizationResult.denied(request.getExplanationSupplier().get());
}

private Set<Group> getGroups(final Set<String> groupNames) {
if (groupNames == null || groupNames.isEmpty()) {
return Collections.emptySet();
}

return userGroupProvider.getGroups().stream()
.filter(group -> groupNames.contains(group.getName()))
.collect(Collectors.toSet());
}

/**
* Determines if the policy contains one of the user's groups.
*
Expand All @@ -129,6 +142,20 @@ private boolean containsGroup(final Set<Group> userGroups, final AccessPolicy po
return false;
}

/**
* Determines if the policy contains the user's identifier.
*
* @param user the user
* @param policy the policy
* @return true if the user is non-null and the user's identifies is contained in the policy's users
*/
private boolean containsUser(final User user, final AccessPolicy policy) {
if (user == null || policy.getUsers().isEmpty()) {
return false;
}
return policy.getUsers().contains(user.getIdentifier());
}

@Override
public String getFingerprint() throws AuthorizationAccessException {
XMLStreamWriter writer = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ public class NiFiRegistryProperties extends ApplicationProperties {
public static final String SECURITY_USER_OIDC_PREFERRED_JWSALGORITHM = "nifi.registry.security.user.oidc.preferred.jwsalgorithm";
public static final String SECURITY_USER_OIDC_ADDITIONAL_SCOPES = "nifi.registry.security.user.oidc.additional.scopes";
public static final String SECURITY_USER_OIDC_CLAIM_IDENTIFYING_USER = "nifi.registry.security.user.oidc.claim.identifying.user";
public static final String SECURITY_USER_OIDC_CLAIM_GROUPS = "nifi.registry.security.user.oidc.claim.groups";

// Revision Management Properties
public static final String REVISIONS_ENABLED = "nifi.registry.revisions.enabled";
Expand Down Expand Up @@ -481,6 +482,16 @@ public List<String> getOidcAdditionalScopes() {
public String getOidcClaimIdentifyingUser() {
return getProperty(SECURITY_USER_OIDC_CLAIM_IDENTIFYING_USER, "email").trim();
}
/**
* Returns the claim to be used to extract user groups from the OIDC payload.
* Claim must be requested by adding the scope for it.
* Default is 'groups'.
*
* @return The claim to be used to extract user groups.
*/
public String getOidcClaimGroups() {
return getProperty(SECURITY_USER_OIDC_CLAIM_GROUPS, "groups").trim();
}

/**
* Returns the network interface list to use for HTTPS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ nifi.registry.security.user.oidc.read.timeout=${nifi.registry.security.user.oidc
nifi.registry.security.user.oidc.client.id=${nifi.registry.security.user.oidc.client.id}
nifi.registry.security.user.oidc.client.secret=${nifi.registry.security.user.oidc.client.secret}
nifi.registry.security.user.oidc.preferred.jwsalgorithm=${nifi.registry.security.user.oidc.preferred.jwsalgorithm}
nifi.registry.security.user.oidc.claim.groups=${nifi.registry.security.user.oidc.claim.groups}

# revision management #
# This feature should remain disabled until a future NiFi release that supports the revision API changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.apache.nifi.registry.security.authentication;

import java.io.Serializable;
import java.util.Collections;
import java.util.Set;

/**
* Authentication response for a user login attempt.
Expand All @@ -27,6 +29,7 @@ public class AuthenticationResponse implements Serializable {
private final String username;
private final long expiration;
private final String issuer;
private final Set<String> groups;

/**
* Creates an authentication response. The username and how long the authentication is valid in milliseconds
Expand All @@ -37,10 +40,24 @@ public class AuthenticationResponse implements Serializable {
* @param issuer The issuer of the token
*/
public AuthenticationResponse(final String identity, final String username, final long expiration, final String issuer) {
this(identity, username, expiration, issuer, Collections.emptySet());
}

/**
* Creates an authentication response. The username and how long the authentication is valid in milliseconds
*
* @param identity The user identity
* @param username The username
* @param expiration The expiration in milliseconds
* @param issuer The issuer of the token
* @param groups The user groups
*/
public AuthenticationResponse(final String identity, final String username, final long expiration, final String issuer, final Set<String> groups) {
this.identity = identity;
this.username = username;
this.expiration = expiration;
this.issuer = issuer;
this.groups = groups;
}

public String getIdentity() {
Expand All @@ -64,6 +81,10 @@ public long getExpiration() {
return expiration;
}

public Set<String> getGroups() {
return groups;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.Collections;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class IdentityAuthenticationProvider implements AuthenticationProvider {

Expand Down Expand Up @@ -94,7 +95,7 @@ protected AuthenticationSuccessToken buildAuthenticatedToken(
return new AuthenticationSuccessToken(new NiFiUserDetails(
new StandardNiFiUser.Builder()
.identity(mappedIdentity)
.groups(getUserGroups(mappedIdentity))
.groups(getUserGroups(mappedIdentity, response))
.clientAddress(requestToken.getClientAddress())
.build()));
}
Expand All @@ -112,6 +113,12 @@ protected Set<String> getUserGroups(final String identity) {
return getUserGroups(authorizer, identity);
}

protected Set<String> getUserGroups(final String identity, AuthenticationResponse response) {
return Stream
.concat(getUserGroups(authorizer, identity).stream(), response.getGroups().stream())
.collect(Collectors.toSet());
}

private static Set<String> getUserGroups(final Authorizer authorizer, final String userIdentity) {
if (authorizer instanceof ManagedAuthorizer) {
final ManagedAuthorizer managedAuthorizer = (ManagedAuthorizer) authorizer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.nifi.registry.web.security.authentication.jwt;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwtException;
import org.apache.nifi.registry.properties.NiFiRegistryProperties;
import org.apache.nifi.registry.security.authentication.AuthenticationRequest;
Expand All @@ -34,6 +36,7 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.Set;
import java.util.concurrent.TimeUnit;

@Component
Expand Down Expand Up @@ -61,16 +64,19 @@ public AuthenticationResponse authenticate(AuthenticationRequest authenticationR
}

final Object credentials = authenticationRequest.getCredentials();
String jwtAuthToken = credentials != null && credentials instanceof String ? (String) credentials : null;

if (credentials == null) {
logger.info("JWT not found in authenticationRequest credentials, returning null.");
return null;
}

try {
final String jwtPrincipal = jwtService.getUserIdentityFromToken(jwtAuthToken);
return new AuthenticationResponse(jwtPrincipal, jwtPrincipal, expiration, issuer);
String jwtAuthToken = credentials.toString();
final Jws<Claims> jws = jwtService.parseAndValidateToken(jwtAuthToken);

final String jwtPrincipal = jwtService.getUserIdentityFromToken(jws);
final Set<String> groups = jwtService.getUserGroupsFromToken(jws);

return new AuthenticationResponse(jwtPrincipal, jwtPrincipal, expiration, issuer, groups);
} catch (JwtException e) {
throw new InvalidAuthenticationException(e.getMessage(), e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@
import org.springframework.stereotype.Service;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;

@Service
Expand All @@ -48,6 +54,7 @@ public class JwtService {
private static final MacAlgorithm SIGNATURE_ALGORITHM = Jwts.SIG.HS256;
private static final String KEY_ID_CLAIM = "kid";
private static final String USERNAME_CLAIM = "preferred_username";
private static final String GROUPS_CLAIM = "groups";

private final KeyService keyService;

Expand All @@ -56,7 +63,7 @@ public JwtService(final KeyService keyService) {
this.keyService = keyService;
}

public String getUserIdentityFromToken(final String base64EncodedToken) throws JwtException {
public Jws<Claims> parseAndValidateToken(final String base64EncodedToken) throws JwtException {
// The library representations of the JWT should be kept internal to this service.
try {
final Jws<Claims> jws = parseTokenFromBase64EncodedString(base64EncodedToken);
Expand All @@ -74,14 +81,24 @@ public String getUserIdentityFromToken(final String base64EncodedToken) throws J
if (StringUtils.isEmpty(jws.getPayload().getIssuer())) {
throw new JwtException("No issuer available in token");
}
return jws.getPayload().getSubject();

return jws;
} catch (JwtException e) {
final String errorMessage = "There was an error validating the JWT";
logger.error(errorMessage, e);
throw e;
throw new JwtException("There was an error validating the JWT", e);
}
}

public String getUserIdentityFromToken(final Jws<Claims> jws) throws JwtException {
return jws.getPayload().getSubject();
}

public Set<String> getUserGroupsFromToken(final Jws<Claims> jws) throws JwtException {
@SuppressWarnings("unchecked")
final List<String> groupsString = jws.getPayload().get(GROUPS_CLAIM, ArrayList.class);

return new HashSet<>(groupsString != null ? groupsString : Collections.emptyList());
}

private Jws<Claims> parseTokenFromBase64EncodedString(final String base64EncodedToken) throws JwtException {
try {
return Jwts.parser().setSigningKeyResolver(new SigningKeyResolverAdapter() {
Expand Down Expand Up @@ -125,11 +142,15 @@ public String generateSignedToken(final AuthenticationResponse authenticationRes
authenticationResponse.getUsername(),
authenticationResponse.getIssuer(),
authenticationResponse.getIssuer(),
authenticationResponse.getExpiration());
authenticationResponse.getExpiration(),
null);
}

public String generateSignedToken(String identity, String preferredUsername, String issuer, String audience, long expirationMillis) throws JwtException {
return this.generateSignedToken(identity, preferredUsername, issuer, audience, expirationMillis, null);
}

public String generateSignedToken(String identity, String preferredUsername, String issuer, String audience, long expirationMillis, Collection<String> groups) throws JwtException {
if (identity == null || StringUtils.isEmpty(identity)) {
String errorMessage = "Cannot generate a JWT for a token with an empty identity";
errorMessage = issuer != null ? errorMessage + " issued by " + issuer + "." : ".";
Expand All @@ -155,6 +176,7 @@ public String generateSignedToken(String identity, String preferredUsername, Str
.audience().add(audience).and()
.claim(USERNAME_CLAIM, preferredUsername)
.claim(KEY_ID_CLAIM, key.getId())
.claim(GROUPS_CLAIM, groups != null ? groups : Collections.EMPTY_LIST)
.issuedAt(now.getTime())
.expiration(expiration.getTime())
.signWith(Keys.hmacShaKeyFor(keyBytes), SIGNATURE_ALGORITHM).compact();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,10 @@ private String convertOIDCTokenToNiFiToken(OIDCTokenResponse response) throws Ba
String identityClaim = properties.getOidcClaimIdentifyingUser();
String identity = claimsSet.getStringClaim(identityClaim);

// Attempt to extract groups from the configured claim; default is 'groups'
final String groupsClaim = properties.getOidcClaimGroups();
final List<String> groups = claimsSet.getStringListClaim(groupsClaim);

// If default identity not available, attempt secondary identity extraction
if (StringUtils.isBlank(identity)) {
// Provide clear message to admin that desired claim is missing and present available claims
Expand All @@ -425,7 +429,7 @@ private String convertOIDCTokenToNiFiToken(OIDCTokenResponse response) throws Ba
final String issuer = claimsSet.getIssuer().getValue();

// convert into a nifi jwt for retrieval later
return jwtService.generateSignedToken(identity, identity, issuer, issuer, expiresIn);
return jwtService.generateSignedToken(identity, identity, issuer, issuer, expiresIn, groups);
}

private String retrieveIdentityFromUserInfoEndpoint(OIDCTokens oidcTokens) throws IOException {
Expand Down

0 comments on commit 0c34095

Please sign in to comment.