From dfcbec863e259bdee2fba530b2b7d5320649bc30 Mon Sep 17 00:00:00 2001 From: Da Huang Date: Tue, 7 Jan 2025 21:26:01 -0800 Subject: [PATCH] Track query retry so that we can count the time impact of the failed rpc calls Also fuse retry logics into concurrent.Job Add the retry entries to LMSamplingUsage. PiperOrigin-RevId: 713146608 --- langfun/core/__init__.py | 1 + langfun/core/concurrent.py | 248 +++++++++++++++++++++------- langfun/core/concurrent_test.py | 31 +++- langfun/core/language_model.py | 75 ++++++++- langfun/core/language_model_test.py | 113 +++++++++---- 5 files changed, 367 insertions(+), 101 deletions(-) diff --git a/langfun/core/__init__.py b/langfun/core/__init__.py index b30bd311..e284b3ad 100644 --- a/langfun/core/__init__.py +++ b/langfun/core/__init__.py @@ -72,6 +72,7 @@ # Concurrent execute a function with parallel inputs with inheriting current # context's defaults and overrides. +from langfun.core.concurrent import RetryEntry from langfun.core.concurrent import concurrent_execute from langfun.core.concurrent import concurrent_map from langfun.core.concurrent import with_context_access diff --git a/langfun/core/concurrent.py b/langfun/core/concurrent.py index 3eed9e74..5268bd9d 100644 --- a/langfun/core/concurrent.py +++ b/langfun/core/concurrent.py @@ -15,6 +15,7 @@ import abc import collections +from collections.abc import Mapping import concurrent.futures import dataclasses import io @@ -22,7 +23,7 @@ import sys import threading import time -from typing import Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union +from typing import Annotated, Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union from langfun.core import component import pyglove as pg @@ -143,43 +144,33 @@ def with_retry( A function with the same signature of the input function, with the retry capability. """ - rand = random if seed is None else random.Random(seed) - def _func(*args, **kwargs) -> Any: - def base_interval() -> int: - if isinstance(retry_interval, tuple): - return rand.randint(retry_interval[0], retry_interval[1]) - else: - assert isinstance(retry_interval, int) - return retry_interval - - def next_wait_interval(attempt: int) -> float: - if not exponential_backoff: - attempt = 1 - return min(max_retry_interval, base_interval() * (2 ** (attempt - 1))) - - wait_intervals = [] - errors = [] - while True: - with pg.catch_errors(retry_on_errors) as error_context: - return func(*args, **kwargs) + def _func(*args, **kwargs): + job = Job( + func, + args, + kwargs, + retry_on_errors=retry_on_errors, + max_attempts=max_attempts, + retry_interval=retry_interval, + exponential_backoff=exponential_backoff, + max_retry_interval=max_retry_interval, + seed=seed, + ) + job() + if job.error: + raise job.error + return job.result - # Branch when errors are met for retry. - errors.append(error_context.error) - if len(errors) < max_attempts: - wait_interval = next_wait_interval(len(errors)) - wait_intervals.append(wait_interval) + return _func - pg.logging.warning( - f'Calling {func!r} encountered {error_context.error!r} ' - f'(attempts={len(errors)}), retrying in {wait_interval} seconds...' - ) - time.sleep(wait_interval) - else: - raise RetryError(func, errors, wait_intervals) +class RetryEntry(pg.Object): + """Retry entry.""" - return _func + call_interval: float + error: BaseException | None = None + wait_interval: float = 0. 
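The refactor above turns `with_retry` into a thin wrapper over `Job`: every attempt is recorded as a `RetryEntry`, and a terminal failure is re-raised as `RetryError`. A minimal usage sketch follows (illustrative only, not part of the patch; the flaky function and its counter are hypothetical):

```python
# Illustrative sketch; assumes the langfun.core.concurrent module as modified above.
from langfun.core import concurrent

_attempts = [0]

def flaky_fetch(x):
  # Hypothetical callable that fails twice before succeeding.
  _attempts[0] += 1
  if _attempts[0] < 3:
    raise ValueError('transient error')
  return x * 2

fetch = concurrent.with_retry(
    flaky_fetch,
    retry_on_errors=ValueError,
    max_attempts=5,
    retry_interval=1,          # Base wait in seconds.
    exponential_backoff=True,  # Waits 1s, then 2s, then 4s, ...
)
print(fetch(21))  # 42 after two retried ValueErrors; raises RetryError if all attempts fail.
```

`with_retry` keeps its original signature and return behavior; the per-attempt `RetryEntry` records are surfaced when a `Job` is used directly, e.g. via `concurrent_execute(..., return_jobs=True)` shown later in this patch.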
def concurrent_execute( @@ -197,6 +188,7 @@ def concurrent_execute( retry_interval: int | tuple[int, int] = (5, 60), exponential_backoff: bool = True, max_retry_interval: int = 300, + return_jobs: bool = False, ) -> list[Any]: """Executes a function concurrently under current component context. @@ -220,32 +212,52 @@ def concurrent_execute( max_retry_interval: The max retry interval in seconds. This is useful when the retry interval is exponential, to avoid the wait time to grow exponentially. + return_jobs: If True, return a list of `Job` objects. Otherwise, return a + list of outputs. Returns: A list of ouputs. Each is the return value of `func` based on the input value. Order is preserved. """ - if retry_on_errors is not None: - func = with_retry( - func, - retry_on_errors, - max_attempts=max_attempts, - retry_interval=retry_interval, - exponential_backoff=exponential_backoff, - max_retry_interval=max_retry_interval, + jobs = [] + for inputs in parallel_inputs: + jobs.append( + Job( + func, + (inputs,), + retry_on_errors=retry_on_errors, + max_attempts=max_attempts, + retry_interval=retry_interval, + exponential_backoff=exponential_backoff, + max_retry_interval=max_retry_interval, + ) ) # NOTE(daiyip): when executor is not specified and max_worker is 1, # we don't need to create a executor pool. Instead, the inputs will be # processed by the user function in sequence within the current thread. if executor is None and max_workers == 1: - return [func(i) for i in parallel_inputs] + for job in jobs: + job() + if job.error: + raise job.error + return jobs if return_jobs else [job.result for job in jobs] shutdown_after_finish = executor is None executor = _executor_pool.executor_from(executor, max_workers=max_workers) try: - return list(executor.map(with_context_access(func), parallel_inputs)) + executed_jobs = list( + executor.map( + lambda job: job(), [with_context_access(job) for job in jobs] + ) + ) + for job in executed_jobs: + if job.error: + raise job.error + return ( + executed_jobs if return_jobs else [job.result for job in executed_jobs] + ) finally: if shutdown_after_finish: # Do not wait threads to finish if they are timed out. @@ -257,9 +269,61 @@ class Job: """Thread pool job.""" func: Callable[[Any], Any] - arg: Any + args: Sequence[Any] = () + kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict) + _: dataclasses.KW_ONLY + result: Any = pg.MISSING_VALUE - error: BaseException | None = None + error: Annotated[ + BaseException | None, + 'The non-retryable error encountered during the job execution.', + ] = None + retry_entries: Annotated[ + Sequence[RetryEntry], 'Records of retry attempts.' + ] = dataclasses.field(default_factory=list) + + retry_on_errors: Annotated[ + Sequence[Type[BaseException] | str], + ( + 'A sequence of exception types or tuples of exception type and error ' + 'messages (described in regular expression) as the desired exception ' + 'types to retry.' + ), + ] = () + max_attempts: Annotated[ + int, 'Max number of attempts if an error to retry is encountered.' + ] = 5 + retry_interval: Annotated[ + int | tuple[int, int], + ( + 'The (base) retry interval in seconds. If a tuple, the retry ' + 'interval will be randomly chosen between the first and the second ' + 'element of the tuple.' + ), + ] = (5, 60) + exponential_backoff: Annotated[ + bool, + ( + 'If True, exponential wait time will be applied on top of the base ' + 'retry interval.' + ), + ] = True + max_retry_interval: Annotated[ + int, + ( + 'The max retry interval in seconds. 
This is useful when the retry ' + 'interval is exponential, to avoid the wait time to grow ' + 'exponentially.' + ), + ] = 300 + seed: Annotated[ + int | None, + ( + 'Random seed to generate retry interval. If None, the seed will be' + ' determined based on current time.' + ), + ] = None + timeit: pg.object_utils.TimeIt = dataclasses.field( default_factory=lambda: pg.object_utils.TimeIt('job') ) @@ -269,14 +333,70 @@ def elapse(self) -> float: """Returns the running time in seconds since the job get started.""" return self.timeit.elapse - def __call__(self) -> Any: + def _retry_call(self) -> 'Job': + """Retries func call on args.""" + rand = random if self.seed is None else random.Random(self.seed) + + def base_interval() -> int: + if isinstance(self.retry_interval, tuple): + return rand.randint(*self.retry_interval) + else: + assert isinstance(self.retry_interval, int) + return self.retry_interval + + def next_wait_interval(attempt: int) -> float: + if not self.exponential_backoff: + attempt = 1 + return min( + self.max_retry_interval, base_interval() * (2 ** (attempt - 1)) + ) + + retry_entries = [] + wait_interval = 0 + while True: + with pg.catch_errors(self.retry_on_errors) as error_context: + begin_time = time.time() + self.result = self.func(*self.args, **self.kwargs) + + end_time = time.time() + retry_entries.append(RetryEntry( + call_interval=end_time - begin_time, + wait_interval=wait_interval, + error=error_context.error, + )) + if error_context.error is None: + self.retry_entries = retry_entries + return self + + # Branch when errors are met for retry. + if len(retry_entries) < self.max_attempts: + wait_interval = next_wait_interval(len(retry_entries)) + + pg.logging.warning( + f'Calling {self.func!r} encountered {error_context.error!r} ' + f'(attempts={len(retry_entries)}), retrying in ' + f'{wait_interval} seconds...' + ) + + time.sleep(wait_interval) + else: + errors = [e.error for e in retry_entries] + # First wait interval is 0. + wait_intervals = [e.wait_interval for e in retry_entries[1:]] + raise RetryError(self.func, errors, wait_intervals) + + def __call__(self) -> 'Job': + if getattr(self, '_has_call', False): + raise ValueError('Job can only be called once.') + self._has_call = True try: with self.timeit: - self.result = self.func(self.arg) - return self.result + if self.retry_on_errors: + return self._retry_call() + self.result = self.func(*self.args, **self.kwargs) except BaseException as e: # pylint: disable=broad-exception-caught self.error = e - return e + return self def mark_canceled(self, error: BaseException) -> None: """Marks the job as canceled.""" @@ -537,7 +657,8 @@ def concurrent_map( max_attempts: int = 5, retry_interval: int | tuple[int, int] = (5, 60), exponential_backoff: bool = True, -) -> Iterator[tuple[Any, Any, BaseException | None]]: + return_jobs: bool = False, +) -> Iterator[Any]: """Maps inputs to outptus via func concurrently under current context. Args: @@ -580,9 +701,10 @@ def concurrent_map( of the tuple. exponential_backoff: If True, exponential wait time will be applied on top of the base retry interval. + return_jobs: If True, the returned iterator will emit `Job` objects. Yields: - An iterator of (input, output, error). + An iterator of (input, output, error) or Job object. Raises: Exception: Errors that are not in `silence_on_errors` or `retry_on_errors`, @@ -592,15 +714,6 @@ def concurrent_map( """ # Internal usage logging. 
- if retry_on_errors: - func = with_retry( - func, - retry_on_errors, - max_attempts=max_attempts, - retry_interval=retry_interval, - exponential_backoff=exponential_backoff, - ) - status_fn = status_fn or (lambda p: { # pylint: disable=g-long-lambda 'Succeeded': '%.2f%% (%d/%d)' % ( p.success_rate * 100, p.succeeded, p.completed), @@ -615,7 +728,14 @@ def concurrent_map( pending_futures = [] total = 0 for inputs in parallel_inputs: - job = Job(func, inputs) + job = Job( + func, + (inputs,), + retry_on_errors=retry_on_errors, + max_attempts=max_attempts, + retry_interval=retry_interval, + exponential_backoff=exponential_backoff, + ) future = executor.submit( with_context_access(job), ) @@ -668,7 +788,7 @@ def update_progress_bar(progress: Progress) -> None: silence_on_errors and isinstance(job.error, silence_on_errors)): raise job.error # pylint: disable=g-doc-exception - yield job.arg, job.result, job.error + yield job if return_jobs else job.args[0], job.result, job.error progress.update(job) update_progress_bar(progress) ProgressBar.refresh() @@ -689,7 +809,7 @@ def update_progress_bar(progress: Progress) -> None: if job.error is not None and not ( silence_on_errors and isinstance(job.error, silence_on_errors)): raise job.error # pylint: disable=g-doc-exception - yield job.arg, job.result, job.error + yield job if return_jobs else job.args[0], job.result, job.error progress.update(job) update_progress_bar(progress) completed_batch.add(future) @@ -712,7 +832,7 @@ def update_progress_bar(progress: Progress) -> None: and isinstance(job.error, silence_on_errors)): raise job.error # pylint: disable=g-doc-exception - yield job.arg, job.result, job.error + yield job.args[0], job.result, job.error progress.update(job) update_progress_bar(progress) else: diff --git a/langfun/core/concurrent_test.py b/langfun/core/concurrent_test.py index 27d04255..d88631d2 100644 --- a/langfun/core/concurrent_test.py +++ b/langfun/core/concurrent_test.py @@ -94,7 +94,7 @@ def test_eq(self): ) -class WithRetryTest(unittest.TestCase): +class RetryTest(unittest.TestCase): def assert_retry(self, func, expected_attempts, expected_wait_intervals): with pg.catch_errors(concurrent.RetryError) as error_context: @@ -162,6 +162,31 @@ def foo(): with self.assertRaises(ValueError): foo_with_retry() + def test_retry_with_job(self): + count = 0 + + def foo(): + nonlocal count + count += 1 + if count < 3: + raise ValueError('Foo temporary error.') + return 'Success' + + job = concurrent.Job( + foo, + retry_on_errors=ValueError, + retry_interval=1, + ) + job() + self.assertEqual(job.result, 'Success') + self.assertEqual( + [retry_entry.wait_interval for retry_entry in job.retry_entries], + [0, 1, 2], + ) + self.assertIsInstance(job.retry_entries[0].error, ValueError) + self.assertIsInstance(job.retry_entries[1].error, ValueError) + self.assertIsNone(job.retry_entries[2].error) + class ConcurrentExecuteTest(unittest.TestCase): @@ -217,8 +242,8 @@ def fun(x): def fun2(unused_x): raise ValueError('Intentional error.') - job1 = concurrent.Job(fun, 1) - job2 = concurrent.Job(fun2, 2) + job1 = concurrent.Job(fun, (1,)) + job2 = concurrent.Job(fun2, (2,)) job1() job2() diff --git a/langfun/core/language_model.py b/langfun/core/language_model.py index 9cf0e07f..7179d007 100644 --- a/langfun/core/language_model.py +++ b/langfun/core/language_model.py @@ -81,6 +81,55 @@ class LMSample(pg.Object): ] = None +class RetryStats(pg.Object): + """Retry stats, which is aggregated across multiple retry entries.""" + + num_occurences: Annotated[ + 
int, + 'Total number of retry attempts on LLM (excluding the first attempt).', + ] = 0 + total_wait_interval: Annotated[ + float, 'Total wait interval in seconds due to retry.' + ] = 0 + total_call_interval: Annotated[ + float, 'Total LLM call interval in seconds.' + ] = 0 + errors: Annotated[ + Sequence[str], + 'A list of error messages encountered during the retry attempts.', + ] = [] + + @classmethod + def from_retry_entries( + cls, retry_entries: Sequence[concurrent.RetryEntry] + ) -> 'RetryStats': + """Creates a RetryStats from a sequence of RetryEntry.""" + if not retry_entries: + return RetryStats() + return RetryStats( + num_occurences=len(retry_entries) - 1, + total_wait_interval=sum(e.wait_interval for e in retry_entries), + total_call_interval=sum(e.call_interval for e in retry_entries), + errors=[ + f'{retry.error.__class__.__name__}: {str(retry.error)}' + for retry in retry_entries if retry.error is not None + ], + ) + + def __add__(self, other: 'RetryStats') -> 'RetryStats': + return RetryStats( + num_occurences=self.num_occurences + other.num_occurences, + total_wait_interval=self.total_wait_interval + + other.total_wait_interval, + total_call_interval=self.total_call_interval + + other.total_call_interval, + errors=self.errors + other.errors, + ) + + def __radd__(self, other: 'RetryStats') -> 'RetryStats': + return self + other + + class LMSamplingUsage(pg.Object): """Usage information per completion.""" @@ -93,8 +142,9 @@ class LMSamplingUsage(pg.Object): ( 'Estimated cost in US dollars. If None, cost estimating is not ' 'suppported on the model being queried.' - ) + ), ] = None + retry_stats: RetryStats = RetryStats() def __bool__(self) -> bool: return self.num_requests > 0 @@ -136,6 +186,7 @@ def __add__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage': total_tokens=self.total_tokens + other.total_tokens, num_requests=self.num_requests + other.num_requests, estimated_cost=estimated_cost, + retry_stats=self.retry_stats + other.retry_stats, ) def __radd__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage': @@ -511,7 +562,14 @@ def sample( total_tokens=usage.total_tokens // n, estimated_cost=( usage.estimated_cost / n if usage.estimated_cost else None - ) + ), + retry_stats=RetryStats( + num_occurences=usage.retry_stats.num_occurences // n, + total_wait_interval=usage.retry_stats.total_wait_interval + / n, + total_call_interval=usage.retry_stats.total_call_interval + / n, + ), ) # Track usage. 
@@ -584,16 +642,16 @@ def _sample( def _parallel_execute_with_currency_control( self, - action: Callable[..., Any], + action: Callable[..., LMSamplingResult], inputs: Sequence[Any], retry_on_errors: Union[ None, Union[Type[BaseException], Tuple[Type[BaseException], str]], Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]], ] = RetryableLMError, - ) -> Any: + ) -> list[LMSamplingResult]: """Helper method for subclasses for implementing _sample.""" - return concurrent.concurrent_execute( + executed_jobs = concurrent.concurrent_execute( action, inputs, executor=self.resource_id if self.max_concurrency else None, @@ -603,7 +661,14 @@ def _parallel_execute_with_currency_control( retry_interval=self.retry_interval, exponential_backoff=self.exponential_backoff, max_retry_interval=self.max_retry_interval, + return_jobs=True, ) + for job in executed_jobs: + job.result.usage.rebind( + retry_stats=RetryStats.from_retry_entries(job.retry_entries), + skip_notification=True, + ) + return [job.result for job in executed_jobs] def __call__( self, prompt: message_lib.Message, *, cache_seed: int = 0, **kwargs diff --git a/langfun/core/language_model_test.py b/langfun/core/language_model_test.py index c61aa477..25f46e5c 100644 --- a/langfun/core/language_model_test.py +++ b/langfun/core/language_model_test.py @@ -35,34 +35,34 @@ def _sample(self, ) -> list[lm_lib.LMSamplingResult]: context = pg.Dict(attempt=0) - def fake_sample(prompts): + def fake_sample(prompt): if context.attempt >= self.failures_before_attempt: - return [ - lm_lib.LMSamplingResult( - [ - lm_lib.LMSample( # pylint: disable=g-complex-comprehension - response=prompt.text * self.sampling_options.top_k, - score=self.sampling_options.temperature or -1.0, - ) - ], - usage=lm_lib.LMSamplingUsage( - prompt_tokens=100, - completion_tokens=100, - total_tokens=200, - estimated_cost=1.0, - ), - ) - for prompt in prompts - ] - context.attempt += 1 + return lm_lib.LMSamplingResult( + [ + lm_lib.LMSample( # pylint: disable=g-complex-comprehension + response=prompt.text * self.sampling_options.top_k, + score=self.sampling_options.temperature or -1.0, + ) + ], + usage=lm_lib.LMSamplingUsage( + prompt_tokens=100, + completion_tokens=100, + total_tokens=200, + estimated_cost=1.0, + ), + ) + else: + context.attempt += 1 raise ValueError('Failed to sample prompts.') - return concurrent.with_retry( - fake_sample, - retry_on_errors=ValueError, - max_attempts=self.max_attempts, - retry_interval=1, - )(prompts) + results = self._parallel_execute_with_currency_control( + fake_sample, prompts, retry_on_errors=ValueError + ) + for result in results: + result.usage.retry_stats.rebind( + total_call_interval=0, skip_notification=True + ) + return results @property def model_id(self) -> str: @@ -448,13 +448,50 @@ def test_using_cache(self): def test_retry(self): lm = MockModel( - failures_before_attempt=1, top_k=1, + failures_before_attempt=1, top_k=1, max_attempts=2, retry_interval=1 ) with self.assertRaisesRegex( concurrent.RetryError, 'Calling .* failed after 1 attempts' ): lm('foo', max_attempts=1) - self.assertEqual(lm('foo', max_attempts=2), 'foo') + + usage = lm_lib.LMSamplingUsage( + prompt_tokens=100, + completion_tokens=100, + total_tokens=200, + num_requests=1, + estimated_cost=1.0, + retry_stats=lm_lib.RetryStats( + num_occurences=1, + total_wait_interval=1, + errors=['ValueError: Failed to sample prompts.'], + ), + ) + out = lm.sample(['foo']) + self.assertEqual( + # lm.sample(['foo'], max_attempts=2), + out, + [ + 
lm_lib.LMSamplingResult( + [ + lm_lib.LMSample( + message_lib.AIMessage( + 'foo', + score=-1.0, + logprobs=None, + is_cached=False, + usage=usage, + tags=['lm-response'], + ), + score=-1.0, + logprobs=None, + ) + ], + usage=usage, + is_cached=False, + ) + ], + ) def test_debug(self): class Image(modality.Modality): @@ -755,16 +792,34 @@ def test_basics(self): def test_add(self): usage1 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0) + usage1.rebind(retry_stats=lm_lib.RetryStats(1, 3, 4, ['e1'])) usage2 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0) self.assertEqual(usage1 + usage2, usage1 + usage2) self.assertIs(usage1 + None, usage1) self.assertIs(None + usage1, usage1) usage3 = lm_lib.LMSamplingUsage(100, 200, 300, 4, None) + usage3.rebind(retry_stats=lm_lib.RetryStats(2, 4, 5, ['e2', 'e3'])) self.assertEqual( - usage1 + usage3, lm_lib.LMSamplingUsage(200, 400, 600, 8, 5.0) + usage1 + usage3, + lm_lib.LMSamplingUsage( + 200, + 400, + 600, + 8, + 5.0, + retry_stats=lm_lib.RetryStats(3, 7, 9, ['e1', 'e2', 'e3']), + ), ) self.assertEqual( - usage3 + usage1, lm_lib.LMSamplingUsage(200, 400, 600, 8, 5.0) + usage3 + usage1, + lm_lib.LMSamplingUsage( + 200, + 400, + 600, + 8, + 5.0, + retry_stats=lm_lib.RetryStats(3, 7, 9, ['e2', 'e3', 'e1']), + ), ) def test_usage_not_available(self):
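Taken together, the patch lets callers recover the time cost of failed RPC attempts per request. A rough end-to-end sketch (illustrative only; the backend callable is hypothetical) of how `concurrent_execute(return_jobs=True)`, `RetryStats.from_retry_entries`, and the new `LMSamplingUsage.retry_stats` field compose:

```python
# Illustrative sketch; assumes the APIs added by this patch.
from langfun.core import concurrent
from langfun.core import language_model as lm_lib


def call_backend(prompt: str) -> str:
  # Hypothetical RPC that may raise a retryable error (e.g. ValueError).
  return prompt.upper()


jobs = concurrent.concurrent_execute(
    call_backend,
    ['foo', 'bar'],
    max_workers=2,
    retry_on_errors=ValueError,
    retry_interval=1,
    return_jobs=True,  # New: returns Job objects instead of bare results.
)

total_usage = lm_lib.LMSamplingUsage(0, 0, 0, num_requests=0)
for job in jobs:
  total_usage += lm_lib.LMSamplingUsage(
      prompt_tokens=0,
      completion_tokens=0,
      total_tokens=0,
      retry_stats=lm_lib.RetryStats.from_retry_entries(job.retry_entries),
  )

# total_usage.retry_stats.num_occurences counts retries (first attempts excluded);
# total_wait_interval / total_call_interval sum time spent waiting vs. calling.
```

Inside `LanguageModel`, `_parallel_execute_with_currency_control` performs this fold automatically: it runs the sampling action with `return_jobs=True` and rebinds each result's `usage.retry_stats` from the executed job's `retry_entries`, which is what the updated `language_model_test.py` expectations exercise.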