From a210224412b1e82a271b1154646e1354f9a83237 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Thu, 12 Oct 2023 21:13:26 -0700 Subject: [PATCH] `lf.concurrent_map`: Cancel pending work items when a work item raises an error. PiperOrigin-RevId: 573092561 --- langfun/core/concurrent.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/langfun/core/concurrent.py b/langfun/core/concurrent.py index 28559946..79e207fa 100644 --- a/langfun/core/concurrent.py +++ b/langfun/core/concurrent.py @@ -438,7 +438,7 @@ def update_progress_bar(progress: Progress) -> None: progress_bar.update(1) if ordered: - for future in pending_futures: + for i, future in enumerate(pending_futures): job = future_to_job[future] wait_time = (timeout - job.elapse) if timeout else None try: @@ -446,6 +446,9 @@ def update_progress_bar(progress: Progress) -> None: if job.error is not None: if not ( silence_on_errors and isinstance(job.error, silence_on_errors)): + # Cancel remaining futures before raising the error. + for f in pending_futures[i + 1:]: + f.cancel() raise job.error except concurrent.futures.TimeoutError: future.cancel() @@ -467,6 +470,10 @@ def update_progress_bar(progress: Progress) -> None: if not ( silence_on_errors and isinstance(job.error, silence_on_errors) ): + # Cancel pending futures before raising the error. + for future in pending_futures: + if future not in completed_batch: + future.cancel() raise job.error # pylint: disable=g-doc-exception yield job.arg, job.result, job.error progress.update(job) @@ -491,9 +498,6 @@ def update_progress_bar(progress: Progress) -> None: TimeoutError(f'Execution time ({job.elapse}) ' f'exceeds {timeout} seconds.')) - if job.error is not None: - last_error = job.error - yield job.arg, job.result, job.error progress.update(job) update_progress_bar(progress)