Skip to content

Commit

Permalink
Attempt to automatically recover from ChannelInvalidStateError
Browse files Browse the repository at this point in the history
  • Loading branch information
sphuber committed Dec 15, 2023
1 parent 8a7bcf6 commit 96cf69b
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/kiwipy/rmq/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __init__(
self._broadcast_queue = None # type: typing.Optional[aio_pika.Queue]
self._broadcast_consumer_tag = None

@utils.auto_reopen_channel
async def add_rpc_subscriber(self, subscriber, identifier=None):
# Create an RPC queue
rpc_queue = await self._channel.declare_queue(exclusive=True, arguments=self._rmq_queue_arguments)
Expand All @@ -129,6 +130,7 @@ async def add_rpc_subscriber(self, subscriber, identifier=None):
self._rpc_subscribers[identifier] = rpc_queue
return identifier

@utils.auto_reopen_channel
async def remove_rpc_subscriber(self, identifier):
try:
rpc_queue = self._rpc_subscribers.pop(identifier)
Expand All @@ -138,6 +140,7 @@ async def remove_rpc_subscriber(self, identifier):
await rpc_queue.cancel(identifier)
await rpc_queue.unbind(self._exchange, routing_key=f'{defaults.RPC_TOPIC}.{identifier}')

@utils.auto_reopen_channel
async def add_broadcast_subscriber(self, subscriber, identifier=None):
identifier = identifier or shortuuid.uuid()
if identifier in self._broadcast_subscribers:
Expand All @@ -149,6 +152,7 @@ async def add_broadcast_subscriber(self, subscriber, identifier=None):
self._broadcast_consumer_tag = await self._broadcast_queue.consume(self._on_broadcast)
return identifier

@utils.auto_reopen_channel
async def remove_broadcast_subscriber(self, identifier):
try:
del self._broadcast_subscribers[identifier]
Expand Down Expand Up @@ -177,6 +181,7 @@ async def connect(self):

await self._create_broadcast_queue()

@utils.auto_reopen_channel
async def _create_broadcast_queue(self):
"""
Create and bind the broadcast queue
Expand Down Expand Up @@ -593,9 +598,9 @@ async def async_connect(
"""
connection_params = connection_params or {}
if isinstance(connection_params, dict):
connection = await connection_factory(**connection_params)
connection = await connection_factory(**connection_params, reconnect_interval=1, fail_fast=False)
else:
connection = await connection_factory(connection_params)
connection = await connection_factory(connection_params, reconnect_interval=1, fail_fast=False)

communicator = RmqCommunicator(
connection=connection,
Expand Down
2 changes: 2 additions & 0 deletions src/kiwipy/rmq/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def action_message(self, message):
message.send(self)
return message.future

@utils.auto_reopen_channel
async def publish(self, message, routing_key, mandatory=True):
"""
Send a fire-and-forget message i.e. no response expected.
Expand All @@ -208,6 +209,7 @@ async def publish(self, message, routing_key, mandatory=True):
result = await self._exchange.publish(message, routing_key=routing_key, mandatory=mandatory)
return result

@utils.auto_reopen_channel
async def publish_expect_response(self, message, routing_key, mandatory=True):
# If there is no correlation id we have to set on so that we know what the response will be to
if not message.correlation_id:
Expand Down
34 changes: 34 additions & 0 deletions src/kiwipy/rmq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import collections.abc
import functools
import inspect
import logging
import os
import socket
import traceback

import aio_pika

from kiwipy import exceptions

__all__ = ()
Expand All @@ -21,6 +24,37 @@
PENDING_KEY = 'pending'


def auto_reopen_channel(wrapped):
"""Call the wrapped method and automatically recover from closed channels and connections.
This decorator should be used on methods that attempt to use the open channel. It calls the method catching
:class:`aio_pika.ChannelInvalidStateError`, which is thrown when the channel was closed. RabbitMQ will close
channels if a call is made over it that errors, in order to protect other channels that may be using the same
connection. The decorator will attempt to reopen the channel. If the connection has also closed, it will wait
for it to be restored throught the robust connection mechanism.
"""
logger = logging.getLogger(auto_reopen_channel.__name__)

async def wrapper(self, *args, **kwargs):
from aio_pika.exceptions import ChannelInvalidStateError

while True:
try:
return await wrapped(self, *args, **kwargs)
except ChannelInvalidStateError as exception:
# This is thrown when the ``Channel`` was closed, so attempt to reopen it.
logger.exception('Channel was closed: <%s>. Attempting to reopen it.', exception)
try:
await self._channel.reopen()
except RuntimeError as exc:
logger.exception(
'Caught `RuntimeError`: %s . Maybe connection closed, waiting for it to be restored', exc
)
await asyncio.sleep(2)

return wrapper


def get_host_info():
return {'hostname': socket.gethostname(), 'pid': os.getpid()}

Expand Down

0 comments on commit 96cf69b

Please sign in to comment.