Skip to content

Commit

Permalink
Merge pull request #391 from geosolutions-it/token_expire
Browse files Browse the repository at this point in the history
Improve the way how we check the OIDC Access Token Expiration and Val…
  • Loading branch information
MV88 authored Nov 21, 2024
2 parents 80e19a9 + 0cf366e commit b3da4c3
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 53 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@
<spring-security-oauth2.version>2.0.17.RELEASE</spring-security-oauth2.version>
<jasypt.version>1.9.3</jasypt.version>
<keycloak-spring-security-adapter.version>18.0.0</keycloak-spring-security-adapter.version>
<spring-security-jwt.version>1.0.11.RELEASE</spring-security-jwt.version>
<spring-security-jwt.version>1.1.1.RELEASE</spring-security-jwt.version>
<java-jwt.version>3.18.3</java-jwt.version>
<wiremock-standalone.version>2.1.12</wiremock-standalone.version>
<hamcrest-core.version>1.3</hamcrest-core.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import static it.geosolutions.geostore.services.rest.security.oauth2.OAuth2Utils.*;

import com.fasterxml.jackson.databind.ObjectMapper;
import it.geosolutions.geostore.core.model.User;
import it.geosolutions.geostore.core.security.password.SecurityUtils;
import it.geosolutions.geostore.services.UserService;
Expand All @@ -52,6 +53,9 @@
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.jwt.Jwt;
import org.springframework.security.jwt.JwtHelper;
import org.springframework.security.jwt.crypto.sign.InvalidSignatureException;
import org.springframework.security.oauth2.client.OAuth2ClientContext;
import org.springframework.security.oauth2.client.OAuth2RestTemplate;
import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException;
Expand Down Expand Up @@ -94,66 +98,131 @@ public SessionToken refresh(String refreshToken, String accessToken) {
String errorMessage = "";
String warningMessage = "";
HttpServletRequest request = getRequest();
if (accessToken == null || accessToken.isEmpty())

// Ensure accessToken is available
if (accessToken == null || accessToken.isEmpty()) {
accessToken = OAuth2Utils.tokenFromParamsOrBearer(ACCESS_TOKEN_PARAM, request);
if (accessToken == null || accessToken.isEmpty())
}
if (accessToken == null || accessToken.isEmpty()) {
throw new NotFoundWebEx("Either the accessToken or the refresh token are missing");
}

OAuth2AccessToken currentToken = retrieveAccessToken(accessToken);
OAuth2AccessToken currentToken = retrieveAccessToken(accessToken, null);

// Determine refreshTokenToUse
String refreshTokenToUse =
currentToken.getRefreshToken() != null
&& currentToken.getRefreshToken().getValue() != null
&& !currentToken.getRefreshToken().getValue().isEmpty()
? currentToken.getRefreshToken().getValue()
: refreshToken;
if (refreshTokenToUse == null || refreshTokenToUse.isEmpty())
Optional.ofNullable(currentToken.getRefreshToken())
.map(OAuth2RefreshToken::getValue)
.filter(value -> !value.isEmpty())
.orElse(refreshToken);

if (refreshTokenToUse == null || refreshTokenToUse.isEmpty()) {
refreshTokenToUse = getParameterValue(REFRESH_TOKEN_PARAM, request);
}

SessionToken sessionToken = null;
OAuth2Configuration configuration = configuration();

if (configuration != null && configuration.isEnabled()) {
if (LOGGER.isDebugEnabled()) LOGGER.info("Going to refresh the token.");
LOGGER.info("Attempting to refresh the token.");
try {
sessionToken = doRefresh(refreshTokenToUse, accessToken, configuration);
if (sessionToken != null) {
currentToken =
retrieveAccessToken(
sessionToken.getAccessToken(), sessionToken.getExpires());
}
} catch (UserRedirectRequiredException e) {
// Log the warning and set the warning message in the session token
warningMessage = "A redirect is required to get the user's approval.";
LOGGER.warn(warningMessage);
} catch (NullPointerException npe) {
// Log the error and set the error message in the session token
errorMessage = "Current configuration wasn't correctly initialized.";
LOGGER.error("Current configuration wasn't correctly initialized.", npe);
LOGGER.warn(warningMessage, e);
} catch (Exception e) {
// Log the error and set the error message in the session token
errorMessage = "An error occurred during token refresh: " + e.getMessage();
LOGGER.error(errorMessage);
LOGGER.error(errorMessage, e);
}
} else {
LOGGER.warn("Configuration is null or disabled; skipping token refresh.");
}
if (sessionToken == null && !isTokenExpired(currentToken)) {
if (warningMessage.isEmpty())
warningMessage =
"Refresh Session Token was NULL for some reason... Seeding it with previous Access Token!";
sessionToken =
sessionToken(accessToken, refreshTokenToUse, currentToken.getExpiration());

if (sessionToken == null) {
if (isTokenExpired(currentToken) /* || !isTokenValid(currentToken) */) {
errorMessage = "Token is invalid or expired, and refresh failed.";
LOGGER.error(errorMessage);
handleRefreshFailure(accessToken, refreshTokenToUse, configuration);
return null;
} else {
if (warningMessage.isEmpty()) warningMessage = "Using existing access token.";
sessionToken =
sessionToken(accessToken, refreshTokenToUse, currentToken.getExpiration());
}
}

if (sessionToken != null) {
if (!warningMessage.isEmpty()) sessionToken.setWarning(warningMessage);
if (!errorMessage.isEmpty()) sessionToken.setError(errorMessage);
request.setAttribute(
OAuth2AuthenticationDetails.ACCESS_TOKEN_VALUE, sessionToken.getAccessToken());
request.setAttribute(
OAuth2AuthenticationDetails.ACCESS_TOKEN_TYPE, sessionToken.getTokenType());
if (!warningMessage.isEmpty()) {
sessionToken.setWarning(warningMessage);
}
if (!errorMessage.isEmpty()) {
sessionToken.setError(errorMessage);
}

request.setAttribute(
OAuth2AuthenticationDetails.ACCESS_TOKEN_VALUE, sessionToken.getAccessToken());
request.setAttribute(
OAuth2AuthenticationDetails.ACCESS_TOKEN_TYPE, sessionToken.getTokenType());

return sessionToken;
}

private boolean isTokenExpired(OAuth2AccessToken token) {
return token != null
&& !token.getValue().isEmpty()
&& (token.getExpiration() == null
|| (token.getExpiration() != null
&& token.getExpiration().before(new Date())));
if (token == null || token.getValue().isEmpty()) {
return true;
}

Date expiration = token.getExpiration();

if (expiration == null) {
expiration = getExpirationDateFromToken(token.getValue());
if (expiration == null) {
return true;
}
}

// Allow clock skew if necessary
return expiration.before(new Date());
}

private Date getExpirationDateFromToken(String token) {
try {
Jwt decodedToken = JwtHelper.decode(token);
String claimsJson = decodedToken.getClaims();

ObjectMapper mapper = new ObjectMapper();
Map<String, Object> claims = mapper.readValue(claimsJson, Map.class);

Object exp = claims.get("exp");
if (exp != null) {
long expLong;
if (exp instanceof Integer) {
expLong = ((Integer) exp).longValue();
} else if (exp instanceof Long) {
expLong = (Long) exp;
} else if (exp instanceof String) {
expLong = Long.parseLong((String) exp);
} else {
throw new IllegalArgumentException("Cannot parse 'exp' claim from token");
}

// The 'exp' claim is usually in seconds since epoch
Date expiration = new Date(expLong * 1000);
return expiration;
} else {
return null;
}
} catch (InvalidSignatureException e) {
LOGGER.error("Invalid JWT signature: {}", e.getMessage());
return null;
} catch (Exception e) {
LOGGER.error("Failed to parse JWT token: {}", e.getMessage());
return null;
}
}

/**
Expand Down Expand Up @@ -203,7 +272,8 @@ protected SessionToken doRefresh(

if (response.getStatusCode().is2xxSuccessful()) {
OAuth2AccessToken newToken = response.getBody();
if (newToken != null && !isTokenExpired(newToken)) {
if (newToken != null
&& !isTokenExpired(newToken) /* && isTokenValid(newToken) */) {
OAuth2RefreshToken newRefreshToken = newToken.getRefreshToken();
OAuth2RefreshToken refreshTokenToUse =
(newRefreshToken != null && newRefreshToken.getValue() != null)
Expand Down Expand Up @@ -290,9 +360,11 @@ public void handleRefreshFailure(
doLogout(null);

try {
String redirectUrl =
"../../openid/" + configuration.getProvider().toLowerCase() + "/login";
getResponse().sendRedirect(redirectUrl);
if (configuration != null && configuration.getProvider() != null) {
String redirectUrl =
"../../openid/" + configuration.getProvider().toLowerCase() + "/login";
getResponse().sendRedirect(redirectUrl);
}
} catch (IOException e) {
LOGGER.error("Error while sending redirect to login service: ", e);
throw new RuntimeException("Failed to redirect to login", e);
Expand Down Expand Up @@ -364,7 +436,7 @@ protected TokenDetails getTokenDetails(Authentication authentication) {
return OAuth2Utils.getTokenDetails(authentication);
}

protected OAuth2AccessToken retrieveAccessToken(String accessToken) {
protected OAuth2AccessToken retrieveAccessToken(String accessToken, Long expires) {
Authentication authentication = cache() != null ? cache().get(accessToken) : null;
OAuth2AccessToken result = null;
if (authentication != null) {
Expand All @@ -378,7 +450,12 @@ protected OAuth2AccessToken retrieveAccessToken(String accessToken) {
if (context != null) result = context.getAccessToken();
}
}
if (result == null) result = new DefaultOAuth2AccessToken(accessToken);
if (result == null) {
result = new DefaultOAuth2AccessToken(accessToken);
if (expires != null && expires > 0) {
((DefaultOAuth2AccessToken) result).setExpiration(new Date(expires));
}
}
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,7 @@ void testRefreshWithInvalidRefreshToken() {
"Refresh token should remain unchanged");
assertNotNull(sessionToken.getWarning(), "Warning message should be set");
assertTrue(
sessionToken
.getWarning()
.contains(
"Refresh Session Token was NULL for some reason... Seeding it with previous Access Token!"),
sessionToken.getWarning().contains("Using existing access token."),
"Expected error message in SessionToken");
}

Expand Down Expand Up @@ -235,10 +232,7 @@ void testRefreshWithServerError() {
"Refresh token should remain unchanged after server error");
assertNotNull(sessionToken.getWarning(), "Warning message should be set");
assertTrue(
sessionToken
.getWarning()
.contains(
"Refresh Session Token was NULL for some reason... Seeding it with previous Access Token!"),
sessionToken.getWarning().contains("Using existing access token."),
"Expected error message in SessionToken");
verify(restTemplate, times(3))
.exchange(
Expand Down Expand Up @@ -279,7 +273,7 @@ void testRefreshWithNullResponse() {
"Refresh token should remain unchanged");
assertNotNull(sessionToken.getWarning(), "Warning message should be set");
assertTrue(
sessionToken.getWarning().contains("Seeding it with previous Access Token!"),
sessionToken.getWarning().contains("Using existing access token."),
"Expected warning message in SessionToken");
}

Expand Down Expand Up @@ -558,7 +552,7 @@ protected TokenDetails getTokenDetails(Authentication authentication) {
}

@Override
protected OAuth2AccessToken retrieveAccessToken(String accessToken) {
protected OAuth2AccessToken retrieveAccessToken(String accessToken, Long expires) {
return currentAccessToken;
}

Expand Down

0 comments on commit b3da4c3

Please sign in to comment.