Track query retry so that we can count the time impact of the failed rpc calls #401

Open · wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions langfun/core/__init__.py
@@ -72,6 +72,7 @@

# Concurrently execute a function with parallel inputs, inheriting the current
# context's defaults and overrides.
+from langfun.core.concurrent import RetryEntry
from langfun.core.concurrent import concurrent_execute
from langfun.core.concurrent import concurrent_map
from langfun.core.concurrent import with_context_access
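With this export, `RetryEntry` becomes importable directly from `langfun.core` alongside the existing concurrency helpers. A minimal sketch of what the export enables (field values are illustrative):

from langfun.core import RetryEntry

# RetryEntry is a pyglove Object; its fields are defined in concurrent.py below.
entry = RetryEntry(call_interval=1.2, wait_interval=0.0, error=None)
print(entry.call_interval)  # -> 1.2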
248 changes: 184 additions & 64 deletions langfun/core/concurrent.py
@@ -15,14 +15,15 @@

import abc
import collections
+from collections.abc import Mapping
import concurrent.futures
import dataclasses
import io
import random
import sys
import threading
import time
-from typing import Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
+from typing import Annotated, Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union

from langfun.core import component
import pyglove as pg
@@ -143,43 +144,33 @@ def with_retry(
    A function with the same signature as the input function, with retry
    capability.
  """
-  rand = random if seed is None else random.Random(seed)
-
-  def _func(*args, **kwargs) -> Any:
-    def base_interval() -> int:
-      if isinstance(retry_interval, tuple):
-        return rand.randint(retry_interval[0], retry_interval[1])
-      else:
-        assert isinstance(retry_interval, int)
-        return retry_interval
-
-    def next_wait_interval(attempt: int) -> float:
-      if not exponential_backoff:
-        attempt = 1
-      return min(max_retry_interval, base_interval() * (2 ** (attempt - 1)))
-
-    wait_intervals = []
-    errors = []
-    while True:
-      with pg.catch_errors(retry_on_errors) as error_context:
-        return func(*args, **kwargs)
-
-      # Branch when errors are met for retry.
-      errors.append(error_context.error)
-      if len(errors) < max_attempts:
-        wait_interval = next_wait_interval(len(errors))
-        wait_intervals.append(wait_interval)
-
-        pg.logging.warning(
-            f'Calling {func!r} encountered {error_context.error!r} '
-            f'(attempts={len(errors)}), retrying in {wait_interval} seconds...'
-        )
-
-        time.sleep(wait_interval)
-      else:
-        raise RetryError(func, errors, wait_intervals)
-
-  return _func
+  def _func(*args, **kwargs):
+    job = Job(
+        func,
+        args,
+        kwargs,
+        retry_on_errors=retry_on_errors,
+        max_attempts=max_attempts,
+        retry_interval=retry_interval,
+        exponential_backoff=exponential_backoff,
+        max_retry_interval=max_retry_interval,
+        seed=seed,
+    )
+    job()
+    if job.error:
+      raise job.error
+    return job.result
+
+  return _func
+
+
+class RetryEntry(pg.Object):
+  """Retry entry."""
+
+  call_interval: float
+  error: BaseException | None = None
+  wait_interval: float = 0.
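Since `with_retry` now funnels through a single-use `Job`, per-attempt telemetry lands on the job itself, which is what lets callers count the time impact of failed RPC calls. A hedged sketch of the accounting (`flaky_rpc` is a hypothetical function):

from langfun.core import concurrent

def flaky_rpc(x):  # Hypothetical function that sometimes times out.
  ...

job = concurrent.Job(flaky_rpc, (1,), retry_on_errors=(TimeoutError,))
job()

# retry_entries is populated once the call eventually succeeds; each entry
# records how long that attempt took (call_interval) and how long the job
# slept before it (wait_interval; the first wait is 0).
time_lost = sum(
    e.call_interval + e.wait_interval
    for e in job.retry_entries
    if e.error is not None
)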


def concurrent_execute(
@@ -197,6 +188,7 @@ def concurrent_execute(
    retry_interval: int | tuple[int, int] = (5, 60),
    exponential_backoff: bool = True,
    max_retry_interval: int = 300,
+    return_jobs: bool = False,
) -> list[Any]:
"""Executes a function concurrently under current component context.

@@ -220,32 +212,52 @@
    max_retry_interval: The max retry interval in seconds. This is useful when
      the retry interval is exponential, to avoid the wait time growing
      exponentially.
+    return_jobs: If True, return a list of `Job` objects. Otherwise, return a
+      list of outputs.

  Returns:
    A list of outputs. Each is the return value of `func` based on the input
    value. Order is preserved.
  """
-  if retry_on_errors is not None:
-    func = with_retry(
-        func,
-        retry_on_errors,
-        max_attempts=max_attempts,
-        retry_interval=retry_interval,
-        exponential_backoff=exponential_backoff,
-        max_retry_interval=max_retry_interval,
-    )
+  jobs = []
+  for inputs in parallel_inputs:
+    jobs.append(
+        Job(
+            func,
+            (inputs,),
+            retry_on_errors=retry_on_errors,
+            max_attempts=max_attempts,
+            retry_interval=retry_interval,
+            exponential_backoff=exponential_backoff,
+            max_retry_interval=max_retry_interval,
+        )
+    )

  # NOTE(daiyip): when executor is not specified and max_workers is 1,
  # we don't need to create an executor pool. Instead, the inputs will be
  # processed by the user function in sequence within the current thread.
  if executor is None and max_workers == 1:
-    return [func(i) for i in parallel_inputs]
+    for job in jobs:
+      job()
+      if job.error:
+        raise job.error
+    return jobs if return_jobs else [job.result for job in jobs]

  shutdown_after_finish = executor is None
  executor = _executor_pool.executor_from(executor, max_workers=max_workers)

  try:
-    return list(executor.map(with_context_access(func), parallel_inputs))
+    executed_jobs = list(
+        executor.map(
+            lambda job: job(), [with_context_access(job) for job in jobs]
+        )
+    )
+    for job in executed_jobs:
+      if job.error:
+        raise job.error
+    return (
+        executed_jobs if return_jobs else [job.result for job in executed_jobs]
+    )
finally:
if shutdown_after_finish:
      # Do not wait for threads to finish if they timed out.
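Downstream, `return_jobs=True` lets callers recover that telemetry from `concurrent_execute` without changing the default return shape. A sketch under the same assumptions (hypothetical `flaky_rpc`, arbitrary inputs):

jobs = concurrent_execute(
    flaky_rpc,
    [1, 2, 3],
    max_workers=8,
    retry_on_errors=(TimeoutError,),
    return_jobs=True,
)
for job in jobs:
  # job.elapse times the whole job; retry_entries break it down per attempt.
  retry_overhead = sum(e.wait_interval for e in job.retry_entries)
  print(job.result, job.elapse, retry_overhead)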
@@ -257,9 +269,61 @@ class Job:
"""Thread pool job."""

  func: Callable[[Any], Any]
-  arg: Any
+  args: Sequence[Any] = ()
+  kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)
+  _: dataclasses.KW_ONLY

  result: Any = pg.MISSING_VALUE
-  error: BaseException | None = None
+  error: Annotated[
+      BaseException | None,
+      'The non-retryable error encountered during the job execution.',
+  ] = None
+  retry_entries: Annotated[
+      Sequence[RetryEntry], 'Records of retry attempts.'
+  ] = dataclasses.field(default_factory=list)

+  retry_on_errors: Annotated[
+      Sequence[Type[BaseException] | str],
+      (
+          'A sequence of exception types or tuples of exception type and error '
+          'messages (described in regular expression) as the desired exception '
+          'types to retry.'
+      ),
+  ] = ()
+  max_attempts: Annotated[
+      int, 'Max number of attempts if an error to retry is encountered.'
+  ] = 5
+  retry_interval: Annotated[
+      int | tuple[int, int],
+      (
+          'The (base) retry interval in seconds. If a tuple, the retry '
+          'interval will be randomly chosen between the first and the second '
+          'element of the tuple.'
+      ),
+  ] = (5, 60)
+  exponential_backoff: Annotated[
+      bool,
+      (
+          'If True, exponential wait time will be applied on top of the base '
+          'retry interval.'
+      ),
+  ] = True
+  max_retry_interval: Annotated[
+      int,
+      (
+          'The max retry interval in seconds. This is useful when the retry '
+          'interval is exponential, to avoid the wait time growing '
+          'exponentially.'
+      ),
+  ] = 300
+  seed: Annotated[
+      int | None,
+      (
+          'Random seed to generate retry interval. If None, the seed will be '
+          'determined based on current time.'
+      ),
+  ] = None

timeit: pg.object_utils.TimeIt = dataclasses.field(
default_factory=lambda: pg.object_utils.TimeIt('job')
)
@@ -269,14 +333,70 @@ def elapse(self) -> float:
"""Returns the running time in seconds since the job get started."""
return self.timeit.elapse

-  def __call__(self) -> Any:
+  def _retry_call(self) -> 'Job':
+    """Retries func call on args."""
+    rand = random if self.seed is None else random.Random(self.seed)
+
+    def base_interval() -> int:
+      if isinstance(self.retry_interval, tuple):
+        return rand.randint(*self.retry_interval)
+      else:
+        assert isinstance(self.retry_interval, int)
+        return self.retry_interval
+
+    def next_wait_interval(attempt: int) -> float:
+      if not self.exponential_backoff:
+        attempt = 1
+      return min(
+          self.max_retry_interval, base_interval() * (2 ** (attempt - 1))
+      )
+
+    retry_entries = []
+    wait_interval = 0
+    while True:
+      with pg.catch_errors(self.retry_on_errors) as error_context:
+        begin_time = time.time()
+        self.result = self.func(*self.args, **self.kwargs)
+
+      end_time = time.time()
+      retry_entries.append(RetryEntry(
+          call_interval=end_time - begin_time,
+          wait_interval=wait_interval,
+          error=error_context.error,
+      ))
+      if error_context.error is None:
+        self.retry_entries = retry_entries
+        return self
+
+      # Branch when errors are met for retry.
+      if len(retry_entries) < self.max_attempts:
+        wait_interval = next_wait_interval(len(retry_entries))
+
+        pg.logging.warning(
+            f'Calling {self.func!r} encountered {error_context.error!r} '
+            f'(attempts={len(retry_entries)}), retrying in '
+            f'{wait_interval} seconds...'
+        )
+
+        time.sleep(wait_interval)
+      else:
+        errors = [e.error for e in retry_entries]
+        # First wait interval is 0.
+        wait_intervals = [e.wait_interval for e in retry_entries[1:]]
+        raise RetryError(self.func, errors, wait_intervals)
+
+  def __call__(self) -> 'Job':
+    if getattr(self, '_has_call', False):
+      raise ValueError('Job can only be called once.')
+    self._has_call = True
    try:
      with self.timeit:
-        self.result = self.func(self.arg)
-        return self.result
+        if self.retry_on_errors:
+          return self._retry_call()
+        self.result = self.func(*self.args, **self.kwargs)
    except BaseException as e:  # pylint: disable=broad-exception-caught
      self.error = e
-      return e
+    return self

def mark_canceled(self, error: BaseException) -> None:
"""Marks the job as canceled."""
@@ -537,7 +657,8 @@ def concurrent_map(
    max_attempts: int = 5,
    retry_interval: int | tuple[int, int] = (5, 60),
    exponential_backoff: bool = True,
-) -> Iterator[tuple[Any, Any, BaseException | None]]:
+    return_jobs: bool = False,
+) -> Iterator[Any]:
  """Maps inputs to outputs via func concurrently under current context.

Args:
@@ -580,9 +701,10 @@
      of the tuple.
    exponential_backoff: If True, exponential wait time will be applied on top
      of the base retry interval.
+    return_jobs: If True, the returned iterator will emit `Job` objects.

  Yields:
-    An iterator of (input, output, error).
+    An iterator of (input, output, error) tuples, or `Job` objects if
+      `return_jobs` is True.

Raises:
Exception: Errors that are not in `silence_on_errors` or `retry_on_errors`,
@@ -592,15 +714,6 @@
"""
# Internal usage logging.

-  if retry_on_errors:
-    func = with_retry(
-        func,
-        retry_on_errors,
-        max_attempts=max_attempts,
-        retry_interval=retry_interval,
-        exponential_backoff=exponential_backoff,
-    )

status_fn = status_fn or (lambda p: { # pylint: disable=g-long-lambda
'Succeeded': '%.2f%% (%d/%d)' % (
p.success_rate * 100, p.succeeded, p.completed),
@@ -615,7 +728,14 @@
pending_futures = []
total = 0
for inputs in parallel_inputs:
-    job = Job(func, inputs)
+    job = Job(
+        func,
+        (inputs,),
+        retry_on_errors=retry_on_errors,
+        max_attempts=max_attempts,
+        retry_interval=retry_interval,
+        exponential_backoff=exponential_backoff,
+    )
future = executor.submit(
with_context_access(job),
)
@@ -668,7 +788,7 @@ def update_progress_bar(progress: Progress) -> None:
silence_on_errors and isinstance(job.error, silence_on_errors)):
raise job.error # pylint: disable=g-doc-exception

-      yield job.arg, job.result, job.error
+      yield job if return_jobs else (job.args[0], job.result, job.error)
progress.update(job)
update_progress_bar(progress)
ProgressBar.refresh()
@@ -689,7 +809,7 @@ def update_progress_bar(progress: Progress) -> None:
if job.error is not None and not (
silence_on_errors and isinstance(job.error, silence_on_errors)):
raise job.error # pylint: disable=g-doc-exception
-          yield job.arg, job.result, job.error
+          yield job if return_jobs else (job.args[0], job.result, job.error)
progress.update(job)
update_progress_bar(progress)
completed_batch.add(future)
@@ -712,7 +832,7 @@ def update_progress_bar(progress: Progress) -> None:
and isinstance(job.error, silence_on_errors)):
raise job.error # pylint: disable=g-doc-exception

yield job.arg, job.result, job.error
-          yield job.arg, job.result, job.error
+          yield job.args[0], job.result, job.error
progress.update(job)
update_progress_bar(progress)
else:
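The same pattern applies to `concurrent_map`, sketched here with the hypothetical `flaky_rpc` from above:

for job in concurrent_map(
    flaky_rpc,
    [1, 2, 3],
    retry_on_errors=(TimeoutError,),
    return_jobs=True,
):
  if job.error is None:
    waits = [e.wait_interval for e in job.retry_entries]
    print(job.args[0], job.result, waits)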