diff --git a/zmq/_asyncio_selector.py b/zmq/_asyncio_selector.py index eef0f7a5c..3118fc0d3 100644 --- a/zmq/_asyncio_selector.py +++ b/zmq/_asyncio_selector.py @@ -90,6 +90,7 @@ def __init__(self, real_loop: asyncio.AbstractEventLoop) -> None: None ) # type: Optional[Tuple[List[_FileDescriptorLike], List[_FileDescriptorLike]]] self._closing_selector = False + self._closed = False self._thread = threading.Thread( name="Tornado selector", daemon=True, @@ -121,6 +122,8 @@ def __del__(self) -> None: self._waker_w.close() def close(self) -> None: + if self._closed: + return with self._select_cond: self._closing_selector = True self._select_cond.notify() @@ -129,6 +132,7 @@ def close(self) -> None: _selector_loops.discard(self) self._waker_r.close() self._waker_w.close() + self._closed = True def _wake_selector(self) -> None: try: diff --git a/zmq/asyncio.py b/zmq/asyncio.py index d6e44e82b..76ecc28db 100644 --- a/zmq/asyncio.py +++ b/zmq/asyncio.py @@ -18,9 +18,7 @@ from zmq import _future # registry of asyncio loop : selector thread -_selectors: WeakKeyDictionary[ - asyncio.AbstractEventLoop, "_zmq._asyncio_selector.SelectorThread" -] = WeakKeyDictionary() +_selectors: WeakKeyDictionary = WeakKeyDictionary() def _get_selector_windows( @@ -53,7 +51,18 @@ def _get_selector_windows( # stacklevel 5 matches most likely zmq.asyncio.Context().socket() stacklevel=5, ) + selector = _selectors[io_loop] = SelectorThread(io_loop) + + # patch loop.close to also close the selector thread + loop_close = io_loop.close + + def _close_selector_and_loop(): + _selectors.pop(io_loop, None) + selector.close() + loop_close() + + io_loop.close = _close_selector_and_loop return selector else: return io_loop @@ -120,7 +129,8 @@ def _clear_io_state(self): called once at close """ - self._selector.remove_reader(self._fd) + if not self.io_loop.is_closed(): + self._selector.remove_reader(self._fd) Poller._socket_class = Socket diff --git a/zmq/tests/test_asyncio.py b/zmq/tests/test_asyncio.py index ba6bac911..e88890e2b 100644 --- a/zmq/tests/test_asyncio.py +++ b/zmq/tests/test_asyncio.py @@ -57,8 +57,12 @@ def setUp(self): super(TestAsyncIOSocket, self).setUp() def tearDown(self): - self.loop.close() super().tearDown() + self.loop.close() + # verify cleanup of references to selectors + assert zaio._selectors == {} + if 'zmq._asyncio_selector' in sys.modules: + assert zmq._asyncio_selector._selector_loops == set() def test_socket_class(self): s = self.context.socket(zmq.PUSH) @@ -432,7 +436,7 @@ def shortDescription(self): return doc def setUp(self): - self.loop = zaio.ZMQEventLoop() + self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) super().setUp() diff --git a/zmq/tests/test_retry_eintr.py b/zmq/tests/test_retry_eintr.py index be841f948..e635f3292 100644 --- a/zmq/tests/test_retry_eintr.py +++ b/zmq/tests/test_retry_eintr.py @@ -17,7 +17,7 @@ class TestEINTRSysCall(BaseZMQTestCase): - """ Base class for EINTR tests. """ + """Base class for EINTR tests.""" # delay for initial signal delivery signal_delay = 0.1