diff --git a/RELEASE.md b/RELEASE.md index 4ee6afb11b..f89df937e0 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -8,6 +8,12 @@ and this project adheres to ## Unreleased +### Added + +* `tff.tensorflow.transform_args` and `tff.tnesorflow.transform_result`, these + functions are intended to be used when instantiating and execution context + in a TensorFlow environment. + ## Release 0.85.0 ### Added diff --git a/tensorflow_federated/python/core/backends/mapreduce/compiler_test.py b/tensorflow_federated/python/core/backends/mapreduce/compiler_test.py index 37a834f827..b8fd425a08 100644 --- a/tensorflow_federated/python/core/backends/mapreduce/compiler_test.py +++ b/tensorflow_federated/python/core/backends/mapreduce/compiler_test.py @@ -21,6 +21,7 @@ from tensorflow_federated.python.core.backends.mapreduce import form_utils from tensorflow_federated.python.core.backends.mapreduce import mapreduce_test_utils from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_factory +from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.impl.compiler import building_block_factory from tensorflow_federated.python.core.impl.compiler import building_block_test_utils from tensorflow_federated.python.core.impl.compiler import building_blocks @@ -43,7 +44,11 @@ def _create_test_context(): factory = executor_factory.local_cpp_executor_factory() - return sync_execution_context.SyncExecutionContext(executor_fn=factory) + return sync_execution_context.SyncExecutionContext( + executor_fn=factory, + transform_args=tensorflow_computation.transform_args, + transform_result=tensorflow_computation.transform_result, + ) class CheckExtractionResultTest(absltest.TestCase): diff --git a/tensorflow_federated/python/core/backends/native/BUILD b/tensorflow_federated/python/core/backends/native/BUILD index d9ca88c2dc..0df9d6c427 100644 --- a/tensorflow_federated/python/core/backends/native/BUILD +++ b/tensorflow_federated/python/core/backends/native/BUILD @@ -68,6 +68,7 @@ py_library( deps = [ ":compiler", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_executor_bindings", + "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/impl/context_stack:set_default_context", "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", "//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context", @@ -103,6 +104,7 @@ py_library( deps = [ ":compiler", ":mergeable_comp_compiler", + "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/impl/context_stack:context_base", "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", diff --git a/tensorflow_federated/python/core/backends/native/cpp_execution_contexts.py b/tensorflow_federated/python/core/backends/native/cpp_execution_contexts.py index fca28416ca..d61dc3c46e 100644 --- a/tensorflow_federated/python/core/backends/native/cpp_execution_contexts.py +++ b/tensorflow_federated/python/core/backends/native/cpp_execution_contexts.py @@ -17,6 +17,7 @@ from tensorflow_federated.python.core.backends.native import compiler from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_executor_bindings +from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.impl.context_stack import set_default_context from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context @@ -65,7 +66,10 @@ def create_sync_local_cpp_execution_context( leaf_executor_fn=_create_tensorflow_backend_execution_stack, ) context = sync_execution_context.SyncExecutionContext( - executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native + executor_fn=factory, + compiler_fn=compiler.desugar_and_transform_to_native, + transform_args=tensorflow_computation.transform_args, + transform_result=tensorflow_computation.transform_result, ) return context @@ -120,7 +124,10 @@ def create_async_local_cpp_execution_context( leaf_executor_fn=_create_tensorflow_backend_execution_stack, ) context = async_execution_context.AsyncExecutionContext( - executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native + executor_fn=factory, + compiler_fn=compiler.desugar_and_transform_to_native, + transform_args=tensorflow_computation.transform_args, + transform_result=tensorflow_computation.transform_result, ) return context @@ -158,7 +165,10 @@ def create_sync_remote_cpp_execution_context( channels=channels, default_num_clients=default_num_clients ) context = sync_execution_context.SyncExecutionContext( - executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native + executor_fn=factory, + compiler_fn=compiler.desugar_and_transform_to_native, + transform_args=tensorflow_computation.transform_args, + transform_result=tensorflow_computation.transform_result, ) return context @@ -182,6 +192,9 @@ def create_async_remote_cpp_execution_context( channels=channels, default_num_clients=default_num_clients ) context = async_execution_context.AsyncExecutionContext( - executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native + executor_fn=factory, + compiler_fn=compiler.desugar_and_transform_to_native, + transform_args=tensorflow_computation.transform_args, + transform_result=tensorflow_computation.transform_result, ) return context diff --git a/tensorflow_federated/python/core/backends/native/execution_contexts.py b/tensorflow_federated/python/core/backends/native/execution_contexts.py index 1958ca2722..f09be387e2 100644 --- a/tensorflow_federated/python/core/backends/native/execution_contexts.py +++ b/tensorflow_federated/python/core/backends/native/execution_contexts.py @@ -18,6 +18,7 @@ from tensorflow_federated.python.core.backends.native import compiler from tensorflow_federated.python.core.backends.native import mergeable_comp_compiler +from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.impl.context_stack import context_base from tensorflow_federated.python.core.impl.context_stack import context_stack_impl from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context @@ -103,7 +104,10 @@ def create_async_local_cpp_execution_context( stream_structs=stream_structs, ) return async_execution_context.AsyncExecutionContext( - executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native + executor_fn=factory, + compiler_fn=compiler.desugar_and_transform_to_native, + transform_args=tensorflow_computation.transform_args, + transform_result=tensorflow_computation.transform_result, ) @@ -150,7 +154,10 @@ def create_sync_local_cpp_execution_context( stream_structs=stream_structs, ) return sync_execution_context.SyncExecutionContext( - executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native + executor_fn=factory, + compiler_fn=compiler.desugar_and_transform_to_native, + transform_args=tensorflow_computation.transform_args, + transform_result=tensorflow_computation.transform_result, ) diff --git a/tensorflow_federated/python/core/backends/native/mergeable_comp_compiler_test.py b/tensorflow_federated/python/core/backends/native/mergeable_comp_compiler_test.py index 6539fa89bd..9a1e744f07 100644 --- a/tensorflow_federated/python/core/backends/native/mergeable_comp_compiler_test.py +++ b/tensorflow_federated/python/core/backends/native/mergeable_comp_compiler_test.py @@ -29,7 +29,11 @@ def _create_test_context(): factory = executor_factory.local_cpp_executor_factory() - context = async_execution_context.AsyncExecutionContext(factory) + context = async_execution_context.AsyncExecutionContext( + executor_fn=factory, + transform_args=tensorflow_computation.transform_args, + transform_result=tensorflow_computation.transform_result, + ) return mergeable_comp_execution_context.MergeableCompExecutionContext( [context] ) diff --git a/tensorflow_federated/python/core/backends/test/BUILD b/tensorflow_federated/python/core/backends/test/BUILD index 3bf14c7dd3..6af599223f 100644 --- a/tensorflow_federated/python/core/backends/test/BUILD +++ b/tensorflow_federated/python/core/backends/test/BUILD @@ -75,6 +75,7 @@ py_library( ":compiler", "//tensorflow_federated/python/core/backends/native:compiler", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_executor_bindings", + "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", "//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context", @@ -108,6 +109,7 @@ py_library( deps = [ ":compiler", "//tensorflow_federated/python/core/backends/native:compiler", + "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", "//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context", diff --git a/tensorflow_federated/python/core/backends/test/cpp_execution_contexts.py b/tensorflow_federated/python/core/backends/test/cpp_execution_contexts.py index 0606db3e43..322ac3ca2d 100644 --- a/tensorflow_federated/python/core/backends/test/cpp_execution_contexts.py +++ b/tensorflow_federated/python/core/backends/test/cpp_execution_contexts.py @@ -26,6 +26,7 @@ from tensorflow_federated.python.core.backends.native import compiler as native_compiler from tensorflow_federated.python.core.backends.test import compiler as test_compiler from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_executor_bindings +from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.impl.context_stack import context_stack_impl from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context @@ -93,7 +94,10 @@ def _compile(comp): leaf_executor_fn=_create_tensorflow_backend_execution_stack, ) context = async_execution_context.AsyncExecutionContext( - executor_fn=factory, compiler_fn=_compile + executor_fn=factory, + compiler_fn=_compile, + transform_args=tensorflow_computation.transform_args, + transform_result=tensorflow_computation.transform_result, ) return context @@ -228,6 +232,8 @@ def initialize_channel(self) -> None: return sync_execution_context.SyncExecutionContext( executor_fn=ManagedServiceContext(), compiler_fn=native_compiler.desugar_and_transform_to_native, + transform_args=tensorflow_computation.transform_args, + transform_result=tensorflow_computation.transform_result, ) @@ -273,7 +279,10 @@ def _compile(comp): leaf_executor_fn=_create_tensorflow_backend_execution_stack, ) context = sync_execution_context.SyncExecutionContext( - executor_fn=factory, compiler_fn=_compile + executor_fn=factory, + compiler_fn=_compile, + transform_args=tensorflow_computation.transform_args, + transform_result=tensorflow_computation.transform_result, ) return context diff --git a/tensorflow_federated/python/core/backends/test/execution_contexts.py b/tensorflow_federated/python/core/backends/test/execution_contexts.py index b48e5882c6..f62bb2008b 100644 --- a/tensorflow_federated/python/core/backends/test/execution_contexts.py +++ b/tensorflow_federated/python/core/backends/test/execution_contexts.py @@ -15,6 +15,7 @@ from tensorflow_federated.python.core.backends.native import compiler as native_compiler from tensorflow_federated.python.core.backends.test import compiler as test_compiler +from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.impl.context_stack import context_stack_impl from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context @@ -44,6 +45,8 @@ def _compile(comp): return async_execution_context.AsyncExecutionContext( executor_fn=factory, compiler_fn=_compile, + transform_args=tensorflow_computation.transform_args, + transform_result=tensorflow_computation.transform_result, ) @@ -85,6 +88,8 @@ def _compile(comp): return sync_execution_context.SyncExecutionContext( executor_fn=factory, compiler_fn=_compile, + transform_args=tensorflow_computation.transform_args, + transform_result=tensorflow_computation.transform_result, ) diff --git a/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts.py b/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts.py index c2bfcb9a3f..2240672813 100644 --- a/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts.py +++ b/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts.py @@ -55,7 +55,8 @@ def create_async_local_cpp_execution_context( leaf_executor_fn=_create_xla_backend_execution_stack, ) return async_execution_context.AsyncExecutionContext( - executor_fn=factory, compiler_fn=compiler.transform_to_native_form + executor_fn=factory, + compiler_fn=compiler.transform_to_native_form, ) @@ -95,7 +96,8 @@ def create_sync_local_cpp_execution_context( # computations instead of TensorFlow, similar to "desugar intrinsics" in the # native backend. return sync_execution_context.SyncExecutionContext( - executor_fn=factory, compiler_fn=compiler.transform_to_native_form + executor_fn=factory, + compiler_fn=compiler.transform_to_native_form, ) diff --git a/tensorflow_federated/python/core/environments/tensorflow/__init__.py b/tensorflow_federated/python/core/environments/tensorflow/__init__.py index 468daf4915..ade25e0f2d 100644 --- a/tensorflow_federated/python/core/environments/tensorflow/__init__.py +++ b/tensorflow_federated/python/core/environments/tensorflow/__init__.py @@ -16,4 +16,6 @@ # pylint: disable=g-importing-member from tensorflow_federated.python.core.environments.tensorflow_backend.tensorflow_tree_transformations import replace_intrinsics_with_bodies from tensorflow_federated.python.core.environments.tensorflow_frontend.tensorflow_computation import tf_computation as computation +from tensorflow_federated.python.core.environments.tensorflow_frontend.tensorflow_computation import transform_args +from tensorflow_federated.python.core.environments.tensorflow_frontend.tensorflow_computation import transform_result # pylint: enable=g-importing-member diff --git a/tensorflow_federated/python/core/impl/execution_contexts/BUILD b/tensorflow_federated/python/core/impl/execution_contexts/BUILD index feb1d56475..43c2b77393 100644 --- a/tensorflow_federated/python/core/impl/execution_contexts/BUILD +++ b/tensorflow_federated/python/core/impl/execution_contexts/BUILD @@ -85,7 +85,6 @@ py_library( "//tensorflow_federated/python/core/impl/compiler:building_blocks", "//tensorflow_federated/python/core/impl/compiler:tree_analysis", "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/computation:function_utils", "//tensorflow_federated/python/core/impl/context_stack:context_base", "//tensorflow_federated/python/core/impl/executors:cardinalities_utils", "//tensorflow_federated/python/core/impl/types:computation_types", diff --git a/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py b/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py index 7a81390d58..537bdb0bb8 100644 --- a/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py +++ b/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py @@ -14,13 +14,12 @@ """Execution context for single-aggregation computations.""" import asyncio -from collections.abc import Awaitable, Callable, Mapping, Sequence +from collections.abc import Awaitable, Callable, Sequence import functools import math from typing import Generic, Optional, TypeVar, Union import attrs -import tree from tensorflow_federated.python.common_libs import async_utils from tensorflow_federated.python.common_libs import py_typecheck @@ -28,7 +27,6 @@ from tensorflow_federated.python.core.impl.compiler import building_blocks from tensorflow_federated.python.core.impl.compiler import tree_analysis from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.computation import function_utils from tensorflow_federated.python.core.impl.context_stack import context_base from tensorflow_federated.python.core.impl.execution_contexts import compiler_pipeline from tensorflow_federated.python.core.impl.executors import cardinalities_utils @@ -695,36 +693,6 @@ def invoke( comp, (MergeableCompForm, computation_base.Computation) ) - if arg is not None and self._transform_args is not None: - # `transform_args` is not intended to handle `tff.structure.Struct`. - # Normalize to a Python structure to make it simpler to handle; `args` is - # sometimes a `tff.structure.Struct` and sometimes it is not, other times - # it is a Python structure that contains a `tff.structure.Struct`. - def _to_python(obj): - if isinstance(obj, structure.Struct): - return structure.to_odict_or_tuple(obj) - else: - return None - - if isinstance(arg, structure.Struct): - args, kwargs = function_utils.unpack_args_from_struct(arg) - args = tree.traverse(_to_python, args) - args = self._transform_args(args) - if not isinstance(args, Sequence): - raise ValueError( - f'Expected `args` to be a `Sequence`, found {type(args)}' - ) - kwargs = tree.traverse(_to_python, kwargs) - kwargs = self._transform_args(kwargs) - if not isinstance(kwargs, Mapping): - raise ValueError( - f'Expected `kwargs` to be a `Mapping`, found {type(kwargs)}' - ) - arg = function_utils.pack_args_into_struct(args, kwargs) - else: - arg = tree.traverse(_to_python, arg) - arg = self._transform_args(arg) - if isinstance(comp, computation_base.Computation): if self._compiler_pipeline is None: raise ValueError( diff --git a/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context.py b/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context.py index 7b4d61df4e..87e20148e5 100644 --- a/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context.py +++ b/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context.py @@ -36,6 +36,8 @@ def __init__( executor_fn: executor_factory.ExecutorFactory, compiler_fn: Optional[Callable[[_Computation], object]] = None, *, + transform_args: Optional[Callable[[object], object]] = None, + transform_result: Optional[Callable[[object], object]] = None, cardinality_inference_fn: cardinalities_utils.CardinalityInferenceFnType = cardinalities_utils.infer_cardinalities, ): """Initializes a synchronous execution context which retries invocations. @@ -43,6 +45,10 @@ def __init__( Args: executor_fn: Instance of `executor_factory.ExecutorFactory`. compiler_fn: A Python function that will be used to compile a computation. + transform_args: An `Optional` `Callable` used to transform the args before + they are passed to the computation. + transform_result: An `Optional` `Callable` used to transform the result + before it is returned. cardinality_inference_fn: A Python function specifying how to infer cardinalities from arguments (and their associated types). The value returned by this function will be passed to the `create_executor` method @@ -53,6 +59,8 @@ def __init__( self._async_context = async_execution_context.AsyncExecutionContext( executor_fn=executor_fn, compiler_fn=compiler_fn, + transform_args=transform_args, + transform_result=transform_result, cardinality_inference_fn=cardinality_inference_fn, ) self._async_runner = async_utils.AsyncThreadRunner() diff --git a/tensorflow_federated/python/tests/mergeable_comp_execution_context_integration_test.py b/tensorflow_federated/python/tests/mergeable_comp_execution_context_integration_test.py index 7876024bab..70d3c02ff3 100644 --- a/tensorflow_federated/python/tests/mergeable_comp_execution_context_integration_test.py +++ b/tensorflow_federated/python/tests/mergeable_comp_execution_context_integration_test.py @@ -334,6 +334,8 @@ def test_runs_computation_with_clients_placed_return_values( max_concurrent_computation_calls=1 ), compiler_fn=tff.backends.native.desugar_and_transform_to_native, + transform_args=tff.tensorflow.transform_args, + transform_result=tff.tensorflow.transform_result, ) contexts.append(context) mergeable_comp_context = tff.framework.MergeableCompExecutionContext( @@ -410,6 +412,8 @@ def test_computes_sum_of_clients_values( max_concurrent_computation_calls=1 ), compiler_fn=tff.backends.native.desugar_and_transform_to_native, + transform_args=tff.tensorflow.transform_args, + transform_result=tff.tensorflow.transform_result, ) contexts.append(context) mergeable_comp_context = tff.framework.MergeableCompExecutionContext( @@ -464,6 +468,8 @@ def test_computes_sum_of_all_values(self, arg, expected_sum, num_subrounds): max_concurrent_computation_calls=1 ), compiler_fn=tff.backends.native.desugar_and_transform_to_native, + transform_args=tff.tensorflow.transform_args, + transform_result=tff.tensorflow.transform_result, ) contexts.append(context) mergeable_comp_context = tff.framework.MergeableCompExecutionContext( @@ -500,6 +506,8 @@ def test_counts_clients_with_noarg_computation(self, num_subrounds): max_concurrent_computation_calls=1, ), compiler_fn=tff.backends.native.desugar_and_transform_to_native, + transform_args=tff.tensorflow.transform_args, + transform_result=tff.tensorflow.transform_result, ) contexts.append(context) mergeable_comp_context = tff.framework.MergeableCompExecutionContext( @@ -516,7 +524,11 @@ def return_one(): return 1 factory = tff.framework.local_cpp_executor_factory() - context = tff.framework.AsyncExecutionContext(factory) + context = tff.framework.AsyncExecutionContext( + executor_fn=factory, + transform_args=tff.tensorflow.transform_args, + transform_result=tff.tensorflow.transform_result, + ) context = tff.framework.MergeableCompExecutionContext([context]) with self.assertRaises(ValueError): @@ -528,7 +540,11 @@ def return_one(): return 1 factory = tff.framework.local_cpp_executor_factory() - context = tff.framework.AsyncExecutionContext(factory) + context = tff.framework.AsyncExecutionContext( + executor_fn=factory, + transform_args=tff.tensorflow.transform_args, + transform_result=tff.tensorflow.transform_result, + ) context = tff.framework.MergeableCompExecutionContext( [context], compiler_fn=lambda x: x )