Skip to content

Commit

Permalink
Add symbols to TFF google API which are needed by the borg context fo…
Browse files Browse the repository at this point in the history
…r trusted programs.

PiperOrigin-RevId: 691864813
  • Loading branch information
suxinguo authored and copybara-github committed Nov 1, 2024
1 parent 352c7a6 commit 27955da
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ namespace py = ::pybind11;
namespace {

PYBIND11_MODULE(executor_stack_bindings, m) {
m.def("filter_to_live_channels",
py::overload_cast<
const std::vector<std::shared_ptr<grpc::ChannelInterface>>&, int>(
&FilterToLiveChannels),
"Wait and filter channels that are ready or idle.");

m.def("create_remote_executor_stack",
py::overload_cast<
const std::vector<std::shared_ptr<grpc::ChannelInterface>>&,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ namespace tensorflow_federated {
// which are ready or idle. This function blocks up to
// `wait_connected_duration_millis` milliseconds, awaiting all channels in
// `channels` to be ready.
std::vector<std::shared_ptr<grpc::ChannelInterface>> FilterToLiveChannels_(
std::vector<std::shared_ptr<grpc::ChannelInterface>> FilterToLiveChannels(
const std::vector<std::shared_ptr<grpc::ChannelInterface>>& channels,
int wait_connected_duration_millis = 1000) {
int wait_connected_duration_millis) {
std::vector<std::shared_ptr<grpc::ChannelInterface>> live_channels;
auto wait_connected =
[&wait_connected_duration_millis](
Expand Down Expand Up @@ -154,7 +154,7 @@ absl::StatusOr<std::shared_ptr<Executor>> CreateRemoteExecutorStack(
}

const std::vector<std::shared_ptr<grpc::ChannelInterface>> live_channels =
FilterToLiveChannels_(channels);
FilterToLiveChannels(channels);
if (live_channels.empty()) {
return absl::UnavailableError(
"No TFF workers are ready; try again to reconnect");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ using ComposingChildFn = std::function<absl::StatusOr<ComposingChild>(
using ComposingExecutorFn = std::function<std::shared_ptr<Executor>(
std::shared_ptr<Executor>, std::vector<ComposingChild>)>;

std::vector<std::shared_ptr<grpc::ChannelInterface>> FilterToLiveChannels(
const std::vector<std::shared_ptr<grpc::ChannelInterface>>& channels,
int wait_connected_duration_millis = 1000);

// Creates an executor stack which proxies for a group of remote workers.
//
// Upon object construction, the channels which represent connections to these
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@
from tensorflow_federated.python.core.impl.types import placements


def filter_to_live_channels(
channels: Sequence[executor_bindings.GRPCChannel],
wait_connected_duration_millis: int = 1000,
) -> Sequence[executor_bindings.GRPCChannel]:
"""Waits and filters channels that are ready or idle."""
return executor_stack_bindings.filter_to_live_channels(
channels, wait_connected_duration_millis
)


def create_remote_executor_stack(
channels: Sequence[executor_bindings.GRPCChannel],
cardinalities: Mapping[placements.PlacementLiteral, int],
Expand Down

0 comments on commit 27955da

Please sign in to comment.