diff --git a/RELEASE.md b/RELEASE.md index 636ea9bdac..223da9e9ec 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -24,6 +24,7 @@ and this project adheres to sampling. * `tff.program.ComputationArg`, which is helpful when creating a federated platform. +* Add `FilterToLiveChannels` to pybindings. ### Removed diff --git a/tensorflow_federated/cc/core/impl/executor_stacks/executor_stack_bindings.cc b/tensorflow_federated/cc/core/impl/executor_stacks/executor_stack_bindings.cc index cc469eb720..6f6253b4a9 100644 --- a/tensorflow_federated/cc/core/impl/executor_stacks/executor_stack_bindings.cc +++ b/tensorflow_federated/cc/core/impl/executor_stacks/executor_stack_bindings.cc @@ -33,6 +33,12 @@ namespace py = ::pybind11; namespace { PYBIND11_MODULE(executor_stack_bindings, m) { + m.def("filter_to_live_channels", + py::overload_cast< + const std::vector>&, int>( + &FilterToLiveChannels), + "Wait and filter channels that are ready or idle."); + m.def("create_remote_executor_stack", py::overload_cast< const std::vector>&, diff --git a/tensorflow_federated/cc/core/impl/executor_stacks/remote_stacks.cc b/tensorflow_federated/cc/core/impl/executor_stacks/remote_stacks.cc index 2d6ef5a89a..430f16856e 100644 --- a/tensorflow_federated/cc/core/impl/executor_stacks/remote_stacks.cc +++ b/tensorflow_federated/cc/core/impl/executor_stacks/remote_stacks.cc @@ -44,9 +44,9 @@ namespace tensorflow_federated { // which are ready or idle. This function blocks up to // `wait_connected_duration_millis` milliseconds, awaiting all channels in // `channels` to be ready. -std::vector> FilterToLiveChannels_( +std::vector> FilterToLiveChannels( const std::vector>& channels, - int wait_connected_duration_millis = 1000) { + int wait_connected_duration_millis) { std::vector> live_channels; auto wait_connected = [&wait_connected_duration_millis]( @@ -154,7 +154,7 @@ absl::StatusOr> CreateRemoteExecutorStack( } const std::vector> live_channels = - FilterToLiveChannels_(channels); + FilterToLiveChannels(channels); if (live_channels.empty()) { return absl::UnavailableError( "No TFF workers are ready; try again to reconnect"); diff --git a/tensorflow_federated/cc/core/impl/executor_stacks/remote_stacks.h b/tensorflow_federated/cc/core/impl/executor_stacks/remote_stacks.h index 2174267d07..6b5e033dba 100644 --- a/tensorflow_federated/cc/core/impl/executor_stacks/remote_stacks.h +++ b/tensorflow_federated/cc/core/impl/executor_stacks/remote_stacks.h @@ -35,6 +35,10 @@ using ComposingChildFn = std::function( using ComposingExecutorFn = std::function( std::shared_ptr, std::vector)>; +std::vector> FilterToLiveChannels( + const std::vector>& channels, + int wait_connected_duration_millis = 1000); + // Creates an executor stack which proxies for a group of remote workers. // // Upon object construction, the channels which represent connections to these diff --git a/tensorflow_federated/python/core/impl/executor_stacks/executor_stack_bindings.py b/tensorflow_federated/python/core/impl/executor_stacks/executor_stack_bindings.py index cf145e1c31..624199e43e 100644 --- a/tensorflow_federated/python/core/impl/executor_stacks/executor_stack_bindings.py +++ b/tensorflow_federated/python/core/impl/executor_stacks/executor_stack_bindings.py @@ -21,6 +21,16 @@ from tensorflow_federated.python.core.impl.types import placements +def filter_to_live_channels( + channels: Sequence[executor_bindings.GRPCChannel], + wait_connected_duration_millis: int = 1000, +) -> Sequence[executor_bindings.GRPCChannel]: + """Waits and filters channels that are ready or idle.""" + return executor_stack_bindings.filter_to_live_channels( + channels, wait_connected_duration_millis + ) + + def create_remote_executor_stack( channels: Sequence[executor_bindings.GRPCChannel], cardinalities: Mapping[placements.PlacementLiteral, int],