Skip to content

Commit

Permalink
allow passing source info to PyGloveTunableMixture
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 583328931
  • Loading branch information
tomvdw authored and SeqIO committed Nov 17, 2023
1 parent 44fecb3 commit 62f23fc
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions seqio/dataset_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2076,6 +2076,7 @@ def __init__(
sample_fn: SampleFn = functools.partial(
tf.data.Dataset.sample_from_datasets, stop_on_empty_dataset=True
),
source_info: Optional[SourceInfo] = None,
):
def hyper_ratio(task_name, hyper):
"""Function for converting PyGlove hyper primitive as ratio fn."""
Expand All @@ -2100,6 +2101,7 @@ def ratio_fn(unused_task):
tasks=converted_tasks,
default_rate=default_rate,
sample_fn=sample_fn,
source_info=source_info,
)

def _get_submixture_rate(self, mix: "Mixture") -> float:
Expand Down

0 comments on commit 62f23fc

Please sign in to comment.