Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add symbols to TFF google API which are needed by the borg context for trusted programs. #4951

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ namespace py = ::pybind11;
namespace {

PYBIND11_MODULE(executor_stack_bindings, m) {
m.def("filter_to_live_channels",
py::overload_cast<
const std::vector<std::shared_ptr<grpc::ChannelInterface>>&, int>(
&FilterToLiveChannels),
"Wait and filter channels that are ready or idle.");

m.def("create_remote_executor_stack",
py::overload_cast<
const std::vector<std::shared_ptr<grpc::ChannelInterface>>&,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ namespace tensorflow_federated {
// which are ready or idle. This function blocks up to
// `wait_connected_duration_millis` milliseconds, awaiting all channels in
// `channels` to be ready.
std::vector<std::shared_ptr<grpc::ChannelInterface>> FilterToLiveChannels_(
std::vector<std::shared_ptr<grpc::ChannelInterface>> FilterToLiveChannels(
const std::vector<std::shared_ptr<grpc::ChannelInterface>>& channels,
int wait_connected_duration_millis = 1000) {
int wait_connected_duration_millis) {
std::vector<std::shared_ptr<grpc::ChannelInterface>> live_channels;
auto wait_connected =
[&wait_connected_duration_millis](
Expand Down Expand Up @@ -154,7 +154,7 @@ absl::StatusOr<std::shared_ptr<Executor>> CreateRemoteExecutorStack(
}

const std::vector<std::shared_ptr<grpc::ChannelInterface>> live_channels =
FilterToLiveChannels_(channels);
FilterToLiveChannels(channels);
if (live_channels.empty()) {
return absl::UnavailableError(
"No TFF workers are ready; try again to reconnect");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ using ComposingChildFn = std::function<absl::StatusOr<ComposingChild>(
using ComposingExecutorFn = std::function<std::shared_ptr<Executor>(
std::shared_ptr<Executor>, std::vector<ComposingChild>)>;

std::vector<std::shared_ptr<grpc::ChannelInterface>> FilterToLiveChannels(
const std::vector<std::shared_ptr<grpc::ChannelInterface>>& channels,
int wait_connected_duration_millis = 1000);

// Creates an executor stack which proxies for a group of remote workers.
//
// Upon object construction, the channels which represent connections to these
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@
from tensorflow_federated.python.core.impl.types import placements


def filter_to_live_channels(
channels: Sequence[executor_bindings.GRPCChannel],
wait_connected_duration_millis: int = 1000,
) -> Sequence[executor_bindings.GRPCChannel]:
"""Waits and filters channels that are ready or idle."""
return executor_stack_bindings.filter_to_live_channels(
channels, wait_connected_duration_millis
)


def create_remote_executor_stack(
channels: Sequence[executor_bindings.GRPCChannel],
cardinalities: Mapping[placements.PlacementLiteral, int],
Expand Down