Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sts_token_buffer_time parameter to transport options #2216

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 37 additions & 29 deletions kombu/transport/SQS.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@
},
}
'sts_role_arn': 'arn:aws:iam::<xxx>:role/STSTest', # optional
'sts_token_timeout': 900 # optional
'sts_token_timeout': 900, # optional
'sts_token_buffer_time': 0 # optional
}

Note that FIFO and standard queues must be named accordingly (the name of
Expand All @@ -91,6 +92,9 @@
sts_token_timeout. sts_role_arn is the assumed IAM role ARN we are trying
to access with. sts_token_timeout is the token timeout, defaults (and minimum)
to 900 seconds. After the mentioned period, a new token will be created.
sts_token_buffer_time (seconds) is the time by which you want to refresh your token
earlier than its actual expiration time, defaults to 0 (no time buffer will be added),
should be less than sts_token_timeout.



Expand Down Expand Up @@ -136,7 +140,7 @@
import socket
import string
import uuid
from datetime import datetime
from datetime import datetime, timedelta
from queue import Empty

from botocore.client import Config
Expand Down Expand Up @@ -765,34 +769,38 @@ def sqs(self, queue=None):
)
return c

def _refresh_sqs_client(self, queue, q):
sts_creds = self.generate_sts_session_token_with_buffer(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900),
self.transport_options.get('sts_token_buffer_time', 0),
)
self.sts_expiration = sts_creds['Expiration']
self._predefined_queue_clients[queue] = self.new_sqs_client(
region=q.get('region', self.region),
access_key_id=sts_creds['AccessKeyId'],
secret_access_key=sts_creds['SecretAccessKey'],
session_token=sts_creds['SessionToken'],
)
return self._predefined_queue_clients[queue]

def _handle_sts_session(self, queue, q):
if not hasattr(self, 'sts_expiration'): # STS token - token init
sts_creds = self.generate_sts_session_token(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900))
self.sts_expiration = sts_creds['Expiration']
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
region=q.get('region', self.region),
access_key_id=sts_creds['AccessKeyId'],
secret_access_key=sts_creds['SecretAccessKey'],
session_token=sts_creds['SessionToken'],
)
return c
# STS token - refresh if expired
elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow():
sts_creds = self.generate_sts_session_token(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900))
self.sts_expiration = sts_creds['Expiration']
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
region=q.get('region', self.region),
access_key_id=sts_creds['AccessKeyId'],
secret_access_key=sts_creds['SecretAccessKey'],
session_token=sts_creds['SessionToken'],
)
return c
else: # STS token - ruse existing
return self._predefined_queue_clients[queue]
"""
Refreshes the SQS client with a new token on STS token initialization
or expiration. Otherwise, using cached client.
"""
if (
not hasattr(self, 'sts_expiration') or
self.sts_expiration.replace(tzinfo=None) < datetime.utcnow()
):
return self._refresh_sqs_client(queue, q)
return self._predefined_queue_clients[queue]

def generate_sts_session_token_with_buffer(self, role_arn, token_expiry_seconds, token_buffer_seconds=0):
credentials = self.generate_sts_session_token(role_arn, token_expiry_seconds)
if token_buffer_seconds and token_buffer_seconds < token_expiry_seconds:
credentials["Expiration"] -= timedelta(seconds=token_buffer_seconds)
return credentials

def generate_sts_session_token(self, role_arn, token_expiry_seconds):
sts_client = boto3.client('sts')
Expand Down