Better error message when SeqIO metric computation fails
1. Attach the task name to the error. There is some "Computing metrics for ..." logging before evaluating each task, but (1) that's hard to discover when trawling through verbose logs, and (2) it's potentially far away due to the delay explained below.
2. Immediately log the error. Because of `self._metrics_future`, each round only checks that the _previous_ step's metrics have finished, so the error reaches the user much later than it actually happened. This can make debugging more difficult, so additionally log the error right when it happens, before propagating it out to the Future (see the sketch below the commit message).

PiperOrigin-RevId: 579187158
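
To make the deferred-error behavior concrete, here is a minimal, self-contained Python sketch. It is not SeqIO's actual Evaluator code; the function and task names (`compute_all_metrics`, `good_task`, `bad_task`) are made up for illustration. It shows why an exception raised inside a metrics Future only surfaces when a later round calls `result()`, and how logging at raise time and wrapping the exception with the task name (the pattern added to `compute_metrics_fn` in the diff below) keeps the failure easy to find.

```python
# Minimal standalone sketch (not SeqIO code) of why a Future-based metrics
# pipeline surfaces errors late, and how logging at raise time plus wrapping
# the exception with the task name helps. All names here are illustrative.
import concurrent.futures
import logging


def compute_metrics(task_name: str) -> dict:
  """Stand-in for per-task metric computation that may fail."""
  if task_name == "bad_task":
    raise KeyError("missing model output")
  return {"accuracy": 1.0}


def compute_all_metrics(task_names) -> dict:
  all_metrics = {}
  for name in task_names:
    try:
      all_metrics[name] = compute_metrics(name)
    except Exception as e:
      # Log immediately: the Future below is only inspected on the *next*
      # evaluation round, so without this line the failure would show up in
      # the logs much later than when it actually happened.
      logging.error("Failed to evaluate task %s: %s", name, e)
      # Attach the task name so the eventual traceback is self-explanatory.
      raise ValueError(f"Failed to evaluate task {name}: {e}") from e
  return all_metrics


if __name__ == "__main__":
  with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    # Round N: kick off metric computation asynchronously.
    metrics_future = executor.submit(
        compute_all_metrics, ["good_task", "bad_task"]
    )
    # ... the caller goes on with the next training/eval round ...
    # Round N+1: only now is the previous round's Future checked, so without
    # the logging.error above, this would be the first sign of the failure.
    try:
      print(metrics_future.result())
    except ValueError as e:
      print(f"Surfaced one round late: {e}")
```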
SeqIO Team authored and SeqIO committed Feb 5, 2024
1 parent 7639986 commit e30f08c
Showing 2 changed files with 107 additions and 92 deletions.
178 changes: 92 additions & 86 deletions seqio/evaluation.py
@@ -20,7 +20,7 @@
import inspect
import itertools
import time
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Mapping, MutableMapping, Optional, Sequence, Tuple, Type, Union

from absl import logging
import clu.metrics
@@ -44,7 +44,8 @@
AllOutputTokensType = Mapping[str, Sequence[Sequence[int]]]
AllOutputScoresType = Any # Mapping[str, Sequence[float]]
AllOutputAuxValuesType = Mapping[str, Mapping[str, Sequence[Any]]]
AllMetricsType = Mapping[str, Mapping[str, Any]]
TaskMetricsType = MutableMapping[str, Any]
AllMetricsType = MutableMapping[str, TaskMetricsType]


class AllMetricsFuture(typing_extensions.Protocol):
@@ -656,9 +657,17 @@ def evaluate(

def compute_metrics_fn():
tick = time.time()
metrics = self._compute_clu_metrics(all_output, step)
all_metrics: AllMetricsType = {}
for task in self.eval_tasks:
try:
all_metrics[task.name] = self._compute_clu_metrics(
task, all_output[task.name], step
)
except Exception as e:
logging.error("Failed to evaluate task %s: %s", task.name, e)
raise ValueError(f"Failed to evaluate task {task.name}: {e}") from e
logging.info("Time computing metrics: %f secs.", time.time() - tick)
return metrics
return all_metrics

def wrap_graph(fn):
graph = tf.compat.v1.get_default_graph()
@@ -684,102 +693,99 @@ def wrapped_fn():
return all_metrics, all_output

def _compute_clu_metrics(
self, all_output, step: Optional[int] = None
) -> AllMetricsType:
self, task: Task, task_output, step: Optional[int] = None
) -> TaskMetricsType:
"""Computes and logs metrics given the predicted tokens and scores.
Args:
all_output: a mapping from task name to the model outputs needed by
metrics of that task.
task: the task being evaluated.
task_output: the model outputs needed by metrics of that task.
step: an optional step number of the current evaluation. If unspecified, a
dummy value of -1 will be used.
Returns:
A mapping from task name to computed metrics.
The computed metrics.
"""
all_metrics = {}

for task in self.eval_tasks:
logging.info("Computing metrics for %s", task.name)
task_dataset = self.cached_task_datasets[task.name]

task_metrics = []
inferences = {}
for metric_obj in task.metric_objs:
model_output = all_output[task.name][metric_obj.model_output_type]
# When model output type is PREDICTION_WITH_AUX or
# SCORE_WITH_INTERMEDIATES, model output is a tuple of two arrays/lists.
if isinstance(model_output, tuple):
prediction_or_score, aux_value = model_output
aux_value = jax.tree_map(
np.array,
aux_value,
is_leaf=lambda x: isinstance(x, list),
)
model_output = (np.array(prediction_or_score), aux_value)
else:
model_output = np.array(model_output)
metric_instance = metric_obj.from_model_output(
logging.info("Computing metrics for %s", task.name)
task_dataset = self.cached_task_datasets[task.name]

task_metrics = []
inferences = {}
targets_and_inferences = None
for metric_obj in task.metric_objs:
model_output = task_output[metric_obj.model_output_type]
# When model output type is PREDICTION_WITH_AUX or
# SCORE_WITH_INTERMEDIATES, model output is a tuple of two arrays/lists.
if isinstance(model_output, tuple):
prediction_or_score, aux_value = model_output
aux_value = jax.tree_map(
np.array,
aux_value,
is_leaf=lambda x: isinstance(x, list),
)
model_output = (np.array(prediction_or_score), aux_value)
else:
model_output = np.array(model_output)
metric_instance = metric_obj.from_model_output(
tfds.as_numpy(task_dataset),
model_output,
task.output_features,
self._target_field_name,
)
if isinstance(metric_instance, metrics_lib.CollectingMetric):
metric_value, targets_and_inferences = metric_instance.actual_compute(
tfds.as_numpy(task_dataset),
model_output,
task.output_features,
self._target_field_name,
self._cached_targets[task.name],
)
if isinstance(metric_instance, metrics_lib.CollectingMetric):
metric_value, targets_and_inferences = metric_instance.actual_compute(
tfds.as_numpy(task_dataset),
task.output_features,
self._target_field_name,
self._cached_targets[task.name],
)
self._cached_targets[task.name] = targets_and_inferences["targets"]
else:
metric_value = metric_instance.compute()
targets_and_inferences = None
if hasattr(metric_instance, "targets_and_inferences"):
targets_and_inferences = metric_instance.targets_and_inferences
task_metrics.append(metric_value)
# Records inferences for legacy logging compatibility.
# common ones are score, output, prediction.
if targets_and_inferences:
for key, val in targets_and_inferences.items():
if key == "targets":
continue
inferences[key] = (
val.tolist() if isinstance(val, np.ndarray) else val
)
# Records targets for legacy logging compatibility.
# Each targets_and_inferences should have identical targets.
# Chooses the last targets_and_inferences for this recording purpose.
if targets_and_inferences:
targets = targets_and_inferences["targets"]
self._cached_targets[task.name] = targets_and_inferences["targets"]
else:
targets = None

all_metrics[task.name] = {}
for k, v in itertools.chain(*[m.items() for m in task_metrics]):
if k in all_metrics[task.name]:
raise ValueError(f"Duplicate metric key '{k}' in Task '{task.name}'.")
all_metrics[task.name][k] = v

# pyformat: disable
metrics = {
k: metrics_lib.Scalar(v)
if not isinstance(v, metrics_lib.MetricValue) else v
for k, v in all_metrics[task.name].items()
}
# pyformat: enable
for logger in self.loggers:
logger(
task_name=task.name,
step=step,
metrics=metrics,
dataset=task_dataset,
targets=targets,
inferences=inferences,
metric_value = metric_instance.compute()
targets_and_inferences = None
if hasattr(metric_instance, "targets_and_inferences"):
targets_and_inferences = metric_instance.targets_and_inferences
task_metrics.append(metric_value)
# Records inferences for legacy logging compatibility.
# common ones are score, output, prediction.
if targets_and_inferences:
for key, val in targets_and_inferences.items():
if key == "targets":
continue
inferences[key] = val.tolist() if isinstance(val, np.ndarray) else val
# Records targets for legacy logging compatibility.
# Each targets_and_inferences should have identical targets.
# Chooses the last targets_and_inferences for this recording purpose.
if targets_and_inferences:
targets = targets_and_inferences["targets"]
else:
targets = None

result: TaskMetricsType = {}
# TODO(b/309107492): Fix that attribute-error instead of silencing it.
for k, v in itertools.chain(*[m.items() for m in task_metrics]): # pytype: disable=attribute-error
if k in result:
raise ValueError(f"Duplicate metric key '{k}' in Task '{task.name}'.")
result[k] = v

metrics = {
k: (
metrics_lib.Scalar(v)
if not isinstance(v, metrics_lib.MetricValue)
else v
)

return all_metrics
for k, v in result.items()
}
for logger in self.loggers:
logger(
task_name=task.name,
step=step,
metrics=metrics,
dataset=task_dataset,
targets=targets,
inferences=inferences,
)
return result

@property
def eval_tasks(self) -> Sequence[Task]:
21 changes: 15 additions & 6 deletions seqio/evaluation_test.py
@@ -1175,29 +1175,38 @@ def mixing_order_predict_fn(
all_outputs[task.name][metrics_lib.ModelOutputType.PREDICTION],
)

def test_includes_task_name_in_error_message(self):
task = get_mocked_task(
name="test_task_number_42",
predict_metric_fns=[_accuracy_metric, _accuracy_metric],
)
with self.assertRaisesRegex(ValueError, ".*test_task_number_42*"):
self._evaluate_single_task(task)

def test_duplicate_metric(self):
task = get_mocked_task(
predict_metric_fns=[_accuracy_metric, _accuracy_metric]
)
with self.assertRaisesWithLiteralMatch(
ValueError, "Duplicate metric key 'accuracy' in Task 'mocked_test'."
with self.assertRaisesRegex(
ValueError, ".*Duplicate metric key 'accuracy' in Task 'mocked_test'.*"
):
self._evaluate_single_task(task)

task = get_mocked_task(
score_metric_fns=[_sum_scores_metric, _sum_scores_metric]
)
with self.assertRaisesWithLiteralMatch(
ValueError, "Duplicate metric key 'total_score' in Task 'mocked_test'."
with self.assertRaisesRegex(
ValueError,
".*Duplicate metric key 'total_score' in Task 'mocked_test'.*",
):
self._evaluate_single_task(task)

task = get_mocked_task(
predict_metric_fns=[_accuracy_metric],
score_metric_fns=[lambda targets, scores: {"accuracy": 0}],
)
with self.assertRaisesWithLiteralMatch(
ValueError, "Duplicate metric key 'accuracy' in Task 'mocked_test'."
with self.assertRaisesRegex(
ValueError, ".*Duplicate metric key 'accuracy' in Task 'mocked_test'.*"
):
self._evaluate_single_task(task)

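A note on the test changes above: the assertions moved from `assertRaisesWithLiteralMatch` to `assertRaisesRegex` because the duplicate-metric `ValueError` is now re-raised wrapped with the task name ("Failed to evaluate task ...: ..."), so the original text is a substring of the final message rather than the whole message. A hypothetical condensed illustration (plain `unittest`, not the actual SeqIO test code):

```python
# Illustrative only: shows why a literal full-message match no longer fits
# once the original error is wrapped with the task name.
import re
import unittest


class WrappedErrorTest(unittest.TestCase):

  def _evaluate(self):
    # Stand-in for the evaluator: the original duplicate-key error gets
    # wrapped in a task-level ValueError before reaching the caller.
    try:
      raise ValueError("Duplicate metric key 'accuracy' in Task 'mocked_test'.")
    except Exception as e:
      raise ValueError(f"Failed to evaluate task mocked_test: {e}") from e

  def test_inner_message_still_matches_as_regex(self):
    # The surfaced message is "Failed to evaluate task mocked_test: Duplicate
    # metric key 'accuracy' in Task 'mocked_test'.", so a substring/regex
    # match on the inner message passes while a literal match would not.
    with self.assertRaisesRegex(
        ValueError, re.escape("Duplicate metric key 'accuracy'")
    ):
      self._evaluate()


if __name__ == "__main__":
  unittest.main()
```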
