Skip to content

Commit

Permalink
Move transform_args and transform_result from ConcreteComputation t…
Browse files Browse the repository at this point in the history
…o the appropriate execution contexts.

This allows the context to transform the args of any computation it invokes.

PiperOrigin-RevId: 663340076
  • Loading branch information
michaelreneer authored and copybara-github committed Aug 15, 2024
1 parent c006062 commit c10f4dd
Show file tree
Hide file tree
Showing 14 changed files with 71 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,6 @@ def call_concrete(*args):
concrete = computation_impl.ConcreteComputation(
computation_proto=comp.proto,
context_stack=context_stack_impl.context_stack,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)
result = concrete(*args)
if isinstance(comp.type_signature.result, computation_types.StructType):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ def _uniquify_reference_names(comp: computation_impl.ConcreteComputation):
return computation_impl.ConcreteComputation(
computation_proto=transformed_comp.proto,
context_stack=context_stack_impl.context_stack,
transform_args=comp.transform_args,
transform_result=comp.transform_result,
)

return DistributeAggregateFormExample(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,6 @@ def get_state_initialization_computation(
return computation_impl.ConcreteComputation(
computation_proto=initialize_tree.proto,
context_stack=context_stack_impl.context_stack,
transform_args=initialize_computation.transform_args,
transform_result=initialize_computation.transform_result,
)


Expand Down Expand Up @@ -1029,8 +1027,6 @@ def _create_comp(proto):
return computation_impl.ConcreteComputation(
computation_proto=proto,
context_stack=context_stack_impl.context_stack,
transform_args=comp.transform_args,
transform_result=comp.transform_result,
)

compute_server_context, client_processing = (
Expand Down Expand Up @@ -1117,8 +1113,6 @@ def _create_comp(proto):
return computation_impl.ConcreteComputation(
computation_proto=proto,
context_stack=context_stack_impl.context_stack,
transform_args=comp.transform_args,
transform_result=comp.transform_result,
)

blocks = (
Expand Down Expand Up @@ -1454,8 +1448,6 @@ def _create_comp(proto):
return computation_impl.ConcreteComputation(
computation_proto=proto,
context_stack=context_stack_impl.context_stack,
transform_args=comp.transform_args,
transform_result=comp.transform_result,
)

comps = [_create_comp(bb.proto) for bb in blocks]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ def _uniquify_reference_names(comp: computation_impl.ConcreteComputation):
return computation_impl.ConcreteComputation(
computation_proto=transformed_comp.proto,
context_stack=context_stack_impl.context_stack,
transform_args=comp.transform_args,
transform_result=comp.transform_result,
)

return MapReduceFormExample(
Expand Down
1 change: 0 additions & 1 deletion tensorflow_federated/python/core/backends/native/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ py_library(
"//tensorflow_federated/python/core/backends/mapreduce:compiler",
"//tensorflow_federated/python/core/environments/tensorflow_backend:compiled_computation_transformations",
"//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_tree_transformations",
"//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation",
"//tensorflow_federated/python/core/impl/compiler:building_blocks",
"//tensorflow_federated/python/core/impl/compiler:transformations",
"//tensorflow_federated/python/core/impl/computation:computation_impl",
Expand Down
5 changes: 0 additions & 5 deletions tensorflow_federated/python/core/backends/native/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from tensorflow_federated.python.core.backends.mapreduce import compiler
from tensorflow_federated.python.core.environments.tensorflow_backend import compiled_computation_transformations
from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_tree_transformations
from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation
from tensorflow_federated.python.core.impl.compiler import building_blocks
from tensorflow_federated.python.core.impl.compiler import transformations
from tensorflow_federated.python.core.impl.computation import computation_impl
Expand Down Expand Up @@ -111,8 +110,6 @@ def transform_to_native_form(
return computation_impl.ConcreteComputation(
computation_proto=form_with_ids.proto,
context_stack=context_stack_impl.context_stack,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)
except ValueError as e:
logging.debug('Compilation for native runtime failed with error %s', e)
Expand Down Expand Up @@ -148,8 +145,6 @@ def desugar_and_transform_to_native(comp):
computation_impl.ConcreteComputation(
computation_proto=intrinsics_desugared_bb.proto,
context_stack=context_stack_impl.context_stack,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
),
grappler_config=grappler_config,
)
Expand Down
1 change: 0 additions & 1 deletion tensorflow_federated/python/core/backends/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ py_library(
"//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_building_block_factory",
"//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_computation_factory",
"//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_tree_transformations",
"//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation",
"//tensorflow_federated/python/core/impl/compiler:building_block_factory",
"//tensorflow_federated/python/core/impl/compiler:building_blocks",
"//tensorflow_federated/python/core/impl/compiler:intrinsic_defs",
Expand Down
3 changes: 0 additions & 3 deletions tensorflow_federated/python/core/backends/test/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_building_block_factory
from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_factory
from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_tree_transformations
from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation
from tensorflow_federated.python.core.impl.compiler import building_block_factory
from tensorflow_federated.python.core.impl.compiler import building_blocks
from tensorflow_federated.python.core.impl.compiler import intrinsic_defs
Expand Down Expand Up @@ -344,6 +343,4 @@ def replace_secure_intrinsics_with_bodies(comp):
return computation_impl.ConcreteComputation(
computation_proto=replaced_intrinsic_bodies.proto,
context_stack=context_stack_impl.context_stack,
transform_args=tensorflow_computation.transform_args,
transform_result=tensorflow_computation.transform_result,
)
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ def _tf_wrapper_fn(
computation_proto=comp_pb,
context_stack=context_stack,
annotated_type=extra_type_spec,
transform_args=transform_args,
transform_result=transform_result,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
"""Defines the abstract interface for classes that represent computations."""

import abc
from collections.abc import Callable
from typing import Optional

from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import typed_object
Expand All @@ -30,16 +28,6 @@ def type_signature(self) -> computation_types.FunctionType:
"""Returns the TFF type of this object."""
raise NotImplementedError

@property
def transform_args(self) -> Optional[Callable[[object], object]]:
"""A Callable used to transform the arguments to the computation."""
return None

@property
def transform_result(self) -> Optional[Callable[[object], object]]:
"""A Callable used to transform the result of the computation."""
return None

@abc.abstractmethod
def __call__(self, *args, **kwargs):
"""Invokes the computation with the given arguments in the given context.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
"""Defines the implementation of the base Computation interface."""

from collections.abc import Callable
from typing import Optional

from tensorflow_federated.proto.v0 import computation_pb2 as pb
Expand Down Expand Up @@ -90,8 +89,6 @@ def __init__(
computation_proto: pb.Computation,
context_stack: context_stack_base.ContextStack,
annotated_type: Optional[computation_types.FunctionType] = None,
transform_args: Optional[Callable[[object], object]] = None,
transform_result: Optional[Callable[[object], object]] = None,
):
"""Constructs a new instance of ConcreteComputation from the computation_proto.
Expand All @@ -101,10 +98,6 @@ def __init__(
context_stack: The context stack to use.
annotated_type: Optional, type information with additional annotations
that replaces the information in `computation_proto.type`.
transform_args: An `Optional` `Callable` used to transform the args before
they are passed to the computation.
transform_result: An `Optional` `Callable` used to transform the result
before it is returned.
Raises:
TypeError: If `annotated_type` is not `None` and is not compatible with
Expand Down Expand Up @@ -135,8 +128,6 @@ def __init__(
self._type_signature = type_spec
self._context_stack = context_stack
self._computation_proto = computation_proto
self._transform_args = transform_args
self._transform_result = transform_result

def __eq__(self, other: object) -> bool:
if self is other:
Expand All @@ -149,14 +140,6 @@ def __eq__(self, other: object) -> bool:
def type_signature(self) -> computation_types.FunctionType:
return self._type_signature

@property
def transform_args(self):
return self._transform_args

@property
def transform_result(self):
return self._transform_result

def __call__(self, *args, **kwargs):
arg = function_utils.pack_args(self._type_signature.parameter, args, kwargs)
result = self._context_stack.current.invoke(self, arg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ py_library(
"//tensorflow_federated/python/core/impl/compiler:building_blocks",
"//tensorflow_federated/python/core/impl/compiler:tree_analysis",
"//tensorflow_federated/python/core/impl/computation:computation_base",
"//tensorflow_federated/python/core/impl/computation:function_utils",
"//tensorflow_federated/python/core/impl/context_stack:context_base",
"//tensorflow_federated/python/core/impl/executors:cardinalities_utils",
"//tensorflow_federated/python/core/impl/types:computation_types",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""A context for execution based on an embedded executor instance."""

import asyncio
from collections.abc import Callable
from collections.abc import Callable, Mapping, Sequence
import contextlib
from typing import Generic, Optional, TypeVar

Expand Down Expand Up @@ -154,13 +154,19 @@ def __init__(
executor_fn: executor_factory.ExecutorFactory,
compiler_fn: Optional[Callable[[_Computation], object]] = None,
*,
transform_args: Optional[Callable[[object], object]] = None,
transform_result: Optional[Callable[[object], object]] = None,
cardinality_inference_fn: cardinalities_utils.CardinalityInferenceFnType = cardinalities_utils.infer_cardinalities,
):
"""Initializes an execution context.
Args:
executor_fn: Instance of `executor_factory.ExecutorFactory`.
compiler_fn: A Python function that will be used to compile a computation.
transform_args: An `Optional` `Callable` used to transform the args before
they are passed to the computation.
transform_result: An `Optional` `Callable` used to transform the result
before it is returned.
cardinality_inference_fn: A Python function specifying how to infer
cardinalities from arguments (and their associated types). The value
returned by this function will be passed to the `create_executor` method
Expand All @@ -173,6 +179,8 @@ def __init__(
self._compiler_pipeline = compiler_pipeline.CompilerPipeline(compiler_fn)
else:
self._compiler_pipeline = None
self._transform_args = transform_args
self._transform_result = transform_result
self._cardinality_inference_fn = cardinality_inference_fn

@contextlib.contextmanager
Expand Down Expand Up @@ -205,7 +213,7 @@ async def invoke(self, comp, arg):
f'Expected a `tff.FunctionType`, found {comp.type_signature}.'
)

if arg is not None and comp.transform_args is not None:
if arg is not None and self._transform_args is not None:
# `transform_args` is not intended to handle `tff.structure.Struct`.
# Normalize to a Python structure to make it simpler to handle; `args` is
# sometimes a `tff.structure.Struct` and sometimes it is not, other times
Expand All @@ -219,13 +227,21 @@ def _to_python(obj):
if isinstance(arg, structure.Struct):
args, kwargs = function_utils.unpack_args_from_struct(arg)
args = tree.traverse(_to_python, args)
args = comp.transform_args(args)
args = self._transform_args(args)
if not isinstance(args, Sequence):
raise ValueError(
f'Expected `args` to be a `Sequence`, found {type(args)}'
)
kwargs = tree.traverse(_to_python, kwargs)
kwargs = comp.transform_args(kwargs)
kwargs = self._transform_args(kwargs)
if not isinstance(kwargs, Mapping):
raise ValueError(
f'Expected `kwargs` to be a `Mapping`, found {type(kwargs)}'
)
arg = function_utils.pack_args_into_struct(args, kwargs)
else:
arg = tree.traverse(_to_python, arg)
arg = comp.transform_args(arg)
arg = self._transform_args(arg)

# Save the type signature before compiling. Compilation currently loses
# container types, so we must remember them here so that they can be
Expand Down Expand Up @@ -257,6 +273,6 @@ def _to_python(obj):
_invoke(executor, comp, arg, result_type)
)

if comp.transform_result is not None:
result = comp.transform_result(result)
if self._transform_result is not None:
result = self._transform_result(result)
return result
Loading

0 comments on commit c10f4dd

Please sign in to comment.