Skip to content

Commit

Permalink
lf.concurrent_map: Cancel pending work items when a work item raise…
Browse files Browse the repository at this point in the history
…s error.

PiperOrigin-RevId: 573092561
  • Loading branch information
daiyip authored and langfun authors committed Oct 13, 2023
1 parent 0a011a6 commit a210224
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions langfun/core/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,14 +438,17 @@ def update_progress_bar(progress: Progress) -> None:
progress_bar.update(1)

if ordered:
for future in pending_futures:
for i, future in enumerate(pending_futures):
job = future_to_job[future]
wait_time = (timeout - job.elapse) if timeout else None
try:
_ = future.result(timeout=wait_time)
if job.error is not None:
if not (
silence_on_errors and isinstance(job.error, silence_on_errors)):
# Cancel remaining futures before raising the error.
for f in pending_futures[i + 1:]:
f.cancel()
raise job.error
except concurrent.futures.TimeoutError:
future.cancel()
Expand All @@ -467,6 +470,10 @@ def update_progress_bar(progress: Progress) -> None:
if not (
silence_on_errors and isinstance(job.error, silence_on_errors)
):
# Cancel pending futures before raising the error.
for future in pending_futures:
if future not in completed_batch:
future.cancel()
raise job.error # pylint: disable=g-doc-exception
yield job.arg, job.result, job.error
progress.update(job)
Expand All @@ -491,9 +498,6 @@ def update_progress_bar(progress: Progress) -> None:
TimeoutError(f'Execution time ({job.elapse}) '
f'exceeds {timeout} seconds.'))

if job.error is not None:
last_error = job.error

yield job.arg, job.result, job.error
progress.update(job)
update_progress_bar(progress)
Expand Down

0 comments on commit a210224

Please sign in to comment.