From 54ac9dbed95f5fdecc44deaad8b18ba7c1f379fa Mon Sep 17 00:00:00 2001 From: Egor Isichenko Date: Sun, 5 Jan 2025 16:15:54 +0300 Subject: [PATCH] implement sts_token_buffer_time attribute for transport_options to update token earlier than expiration time --- kombu/transport/SQS.py | 66 +++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index 0c8d1ee4e..797bfe458 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -76,7 +76,8 @@ }, } 'sts_role_arn': 'arn:aws:iam:::role/STSTest', # optional - 'sts_token_timeout': 900 # optional + 'sts_token_timeout': 900, # optional + 'sts_token_buffer_time': 0 # optional } Note that FIFO and standard queues must be named accordingly (the name of @@ -91,6 +92,9 @@ sts_token_timeout. sts_role_arn is the assumed IAM role ARN we are trying to access with. sts_token_timeout is the token timeout, defaults (and minimum) to 900 seconds. After the mentioned period, a new token will be created. +sts_token_buffer_time (seconds) is the time by which you want to refresh your token +earlier than its actual expiration time, defaults to 0 (no time buffer will be added), +should be less than sts_token_timeout. @@ -136,7 +140,7 @@ import socket import string import uuid -from datetime import datetime +from datetime import datetime, timedelta from queue import Empty from botocore.client import Config @@ -765,34 +769,38 @@ def sqs(self, queue=None): ) return c + def _refresh_sqs_client(self, queue, q): + sts_creds = self.generate_sts_session_token_with_buffer( + self.transport_options.get('sts_role_arn'), + self.transport_options.get('sts_token_timeout', 900), + self.transport_options.get('sts_token_buffer_time', 0), + ) + self.sts_expiration = sts_creds['Expiration'] + self._predefined_queue_clients[queue] = self.new_sqs_client( + region=q.get('region', self.region), + access_key_id=sts_creds['AccessKeyId'], + secret_access_key=sts_creds['SecretAccessKey'], + session_token=sts_creds['SessionToken'], + ) + return self._predefined_queue_clients[queue] + def _handle_sts_session(self, queue, q): - if not hasattr(self, 'sts_expiration'): # STS token - token init - sts_creds = self.generate_sts_session_token( - self.transport_options.get('sts_role_arn'), - self.transport_options.get('sts_token_timeout', 900)) - self.sts_expiration = sts_creds['Expiration'] - c = self._predefined_queue_clients[queue] = self.new_sqs_client( - region=q.get('region', self.region), - access_key_id=sts_creds['AccessKeyId'], - secret_access_key=sts_creds['SecretAccessKey'], - session_token=sts_creds['SessionToken'], - ) - return c - # STS token - refresh if expired - elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow(): - sts_creds = self.generate_sts_session_token( - self.transport_options.get('sts_role_arn'), - self.transport_options.get('sts_token_timeout', 900)) - self.sts_expiration = sts_creds['Expiration'] - c = self._predefined_queue_clients[queue] = self.new_sqs_client( - region=q.get('region', self.region), - access_key_id=sts_creds['AccessKeyId'], - secret_access_key=sts_creds['SecretAccessKey'], - session_token=sts_creds['SessionToken'], - ) - return c - else: # STS token - ruse existing - return self._predefined_queue_clients[queue] + """ + Refreshes the SQS client with a new token on STS token initialization + or expiration. Otherwise, using cached client. + """ + if ( + not hasattr(self, 'sts_expiration') or + self.sts_expiration.replace(tzinfo=None) < datetime.utcnow() + ): + return self._refresh_sqs_client(queue, q) + return self._predefined_queue_clients[queue] + + def generate_sts_session_token_with_buffer(self, role_arn, token_expiry_seconds, token_buffer_seconds=0): + credentials = self.generate_sts_session_token(role_arn, token_expiry_seconds) + if token_buffer_seconds and token_buffer_seconds < token_expiry_seconds: + credentials["Expiration"] -= timedelta(seconds=token_buffer_seconds) + return credentials def generate_sts_session_token(self, role_arn, token_expiry_seconds): sts_client = boto3.client('sts')