From 9117914b2a55851fafd8c457a7b3b03314a9bc66 Mon Sep 17 00:00:00 2001 From: Han Manjong Date: Fri, 10 Jan 2025 18:34:53 +0900 Subject: [PATCH 1/4] chore(sqs): write the test case for multiple predefined queues with aws sts session --- t/unit/transport/test_SQS.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/test_SQS.py index b82be5aa1..551756a34 100644 --- a/t/unit/transport/test_SQS.py +++ b/t/unit/transport/test_SQS.py @@ -996,6 +996,35 @@ def test_sts_session_not_expired(self): # Assert mock_generate_sts_session_token.assert_not_called() + def test_sts_session_with_multiple_predefined_queues(self): + connection = Connection(transport=SQS.Transport, transport_options={ + 'predefined_queues': example_predefined_queues, + 'sts_role_arn': 'test::arn' + }) + channel = connection.channel() + sqs = SQS_Channel_sqs.__get__(channel, SQS.Channel) + + mock_generate_sts_session_token = Mock() + mock_new_sqs_client = Mock() + channel.new_sqs_client = mock_new_sqs_client + mock_generate_sts_session_token.return_value = { + 'Expiration': datetime.utcnow() + timedelta(days=1), + 'SessionToken': 123, + 'AccessKeyId': 123, + 'SecretAccessKey': 123 + } + + channel.generate_sts_session_token = mock_generate_sts_session_token + + # Act + sqs(queue='queue-1') + sqs(queue='queue-2') + + # Assert + mock_generate_sts_session_token.assert_called() + mock_new_sqs_client.assert_called() + + def test_message_attribute(self): message = 'my test message' self.producer.publish(message, message_attributes={ From 2768f4564cea01e3b816e8ea6c6d1338100592a0 Mon Sep 17 00:00:00 2001 From: Han Manjong Date: Fri, 10 Jan 2025 18:40:08 +0900 Subject: [PATCH 2/4] fix(sqs): don't crash on multiple predefined queues with aws sts session --- kombu/transport/SQS.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index 0c8d1ee4e..6aa36e0f2 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -792,6 +792,18 @@ def _handle_sts_session(self, queue, q): ) return c else: # STS token - ruse existing + if not hasattr(self._predefined_queue_clients, queue): + sts_creds = self.generate_sts_session_token( + self.transport_options.get('sts_role_arn'), + self.transport_options.get('sts_token_timeout', 900)) + self.sts_expiration = sts_creds['Expiration'] + c = self._predefined_queue_clients[queue] = self.new_sqs_client( + region=q.get('region', self.region), + access_key_id=sts_creds['AccessKeyId'], + secret_access_key=sts_creds['SecretAccessKey'], + session_token=sts_creds['SessionToken'], + ) + return c return self._predefined_queue_clients[queue] def generate_sts_session_token(self, role_arn, token_expiry_seconds): From c3022ae22f61a0c939f48ebf51cf758eda0f24e6 Mon Sep 17 00:00:00 2001 From: Han Manjong Date: Fri, 10 Jan 2025 18:45:11 +0900 Subject: [PATCH 3/4] refactor(sqs): make _new_predefined_queue_client_with_sts_session() --- kombu/transport/SQS.py | 50 ++++++++++++++---------------------------- 1 file changed, 17 insertions(+), 33 deletions(-) diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index 6aa36e0f2..2b522bebc 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -766,46 +766,30 @@ def sqs(self, queue=None): return c def _handle_sts_session(self, queue, q): + region = q.get('region', self.region) if not hasattr(self, 'sts_expiration'): # STS token - token init - sts_creds = self.generate_sts_session_token( - self.transport_options.get('sts_role_arn'), - self.transport_options.get('sts_token_timeout', 900)) - self.sts_expiration = sts_creds['Expiration'] - c = self._predefined_queue_clients[queue] = self.new_sqs_client( - region=q.get('region', self.region), - access_key_id=sts_creds['AccessKeyId'], - secret_access_key=sts_creds['SecretAccessKey'], - session_token=sts_creds['SessionToken'], - ) - return c + return self._new_predefined_queue_client_with_sts_session(queue, region) # STS token - refresh if expired elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow(): - sts_creds = self.generate_sts_session_token( - self.transport_options.get('sts_role_arn'), - self.transport_options.get('sts_token_timeout', 900)) - self.sts_expiration = sts_creds['Expiration'] - c = self._predefined_queue_clients[queue] = self.new_sqs_client( - region=q.get('region', self.region), - access_key_id=sts_creds['AccessKeyId'], - secret_access_key=sts_creds['SecretAccessKey'], - session_token=sts_creds['SessionToken'], - ) - return c + return self._new_predefined_queue_client_with_sts_session(queue, region) else: # STS token - ruse existing if not hasattr(self._predefined_queue_clients, queue): - sts_creds = self.generate_sts_session_token( - self.transport_options.get('sts_role_arn'), - self.transport_options.get('sts_token_timeout', 900)) - self.sts_expiration = sts_creds['Expiration'] - c = self._predefined_queue_clients[queue] = self.new_sqs_client( - region=q.get('region', self.region), - access_key_id=sts_creds['AccessKeyId'], - secret_access_key=sts_creds['SecretAccessKey'], - session_token=sts_creds['SessionToken'], - ) - return c + return self._new_predefined_queue_client_with_sts_session(queue, region) return self._predefined_queue_clients[queue] + def _new_predefined_queue_client_with_sts_session(self, queue, region): + sts_creds = self.generate_sts_session_token( + self.transport_options.get('sts_role_arn'), + self.transport_options.get('sts_token_timeout', 900)) + self.sts_expiration = sts_creds['Expiration'] + c = self._predefined_queue_clients[queue] = self.new_sqs_client( + region=region, + access_key_id=sts_creds['AccessKeyId'], + secret_access_key=sts_creds['SecretAccessKey'], + session_token=sts_creds['SessionToken'], + ) + return c + def generate_sts_session_token(self, role_arn, token_expiry_seconds): sts_client = boto3.client('sts') sts_policy = sts_client.assume_role( From 9d90b5cd3b3ab266f3443452deef1550464b72db Mon Sep 17 00:00:00 2001 From: Han Manjong Date: Fri, 10 Jan 2025 19:13:40 +0900 Subject: [PATCH 4/4] fix(sqs): lint --- t/unit/transport/test_SQS.py | 1 - 1 file changed, 1 deletion(-) diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/test_SQS.py index 551756a34..b6a1d6ae2 100644 --- a/t/unit/transport/test_SQS.py +++ b/t/unit/transport/test_SQS.py @@ -1024,7 +1024,6 @@ def test_sts_session_with_multiple_predefined_queues(self): mock_generate_sts_session_token.assert_called() mock_new_sqs_client.assert_called() - def test_message_attribute(self): message = 'my test message' self.producer.publish(message, message_attributes={