Skip to content

Commit

Permalink
Revert part of some breaking changes in using rate_per_task_name
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 596903315
  • Loading branch information
tomvdw authored and SeqIO committed Jan 9, 2024
1 parent 19e6aec commit 24619f5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion seqio/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def add_fully_cached_mixture(
return MixtureRegistry.add(
new_name,
[
(new_t.name, mixture.rate_per_task_name[old_t.name])
(new_t.name, mixture._task_to_rate[old_t.name]) # pylint:disable=protected-access
for old_t, new_t in zip(mixture.tasks, new_tasks)
],
)
Expand Down
6 changes: 3 additions & 3 deletions seqio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _validate_output_features(og_output_features, new_output_features):
# This is a Mixture. Create and register new sub-Tasks/Mixtures with the
# provided vocab/output_features, then create a new Mixture.
new_tasks_and_rates = []
for task_name, rate in mixture_or_task.rate_per_task_name.items():
for task_name, rate in mixture_or_task._task_to_rate.items():
new_task_name = f"{new_mixture_or_task_name}.{task_name}"
new_task = mixture_or_task_with_new_vocab(
task_name,
Expand Down Expand Up @@ -306,7 +306,7 @@ def mixture_or_task_with_truncated_data(
# This is a Mixture. Create and register new sub-Tasks/Mixtures with the
# provided vocab/output_features, then create a new Mixture.
new_tasks_and_rates = []
for task_name, rate in mixture_or_task.rate_per_task_name.items():
for task_name, rate in mixture_or_task._task_to_rate.items():
new_task = mixture_or_task_with_truncated_data(
task_name,
f"{new_mixture_or_task_name}.{task_name}",
Expand Down Expand Up @@ -354,7 +354,7 @@ def mixture_with_missing_task_splits_removed(
"""
og_mix: dp.Mixture = dp.get_mixture_or_task(mixture_name) # pytype: disable=annotation-type-mismatch # always-use-return-annotations
new_tasks_and_rates = []
for task_name, rate in og_mix.rate_per_task_name.items():
for task_name, rate in og_mix._task_to_rate.items():
subtask: dp.Task = dp.get_mixture_or_task(task_name) # pytype: disable=annotation-type-mismatch # always-use-return-annotations
if split in subtask.splits:
new_tasks_and_rates.append((subtask.name, rate))
Expand Down

0 comments on commit 24619f5

Please sign in to comment.