diff --git a/tsl/platform/BUILD b/tsl/platform/BUILD index 9d6f00280..c31b1d6d5 100644 --- a/tsl/platform/BUILD +++ b/tsl/platform/BUILD @@ -323,7 +323,7 @@ cc_library( deps = [ ":status", "//tsl/protobuf:error_codes_proto_impl_cc", - "//tsl/protobuf:status_proto_cc", + "@xla//xla/tsl/protobuf:status_proto_cc", ] + tf_platform_deps("status"), ) @@ -1349,10 +1349,10 @@ tsl_cc_test( ":test", ":test_main", "//tsl/protobuf:error_codes_proto_impl_cc", - "//tsl/protobuf:status_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", + "@xla//xla/tsl/protobuf:status_proto_cc", ], ) diff --git a/tsl/platform/default/build_config.bzl b/tsl/platform/default/build_config.bzl index 688750986..7806902b0 100644 --- a/tsl/platform/default/build_config.bzl +++ b/tsl/platform/default/build_config.bzl @@ -726,7 +726,7 @@ def tf_lib_proto_parsing_deps(): return [ ":protos_all_cc", clean_dep("@eigen_archive//:eigen3"), - clean_dep("//tsl/protobuf:protos_all_cc"), + clean_dep("@xla//xla/tsl/protobuf:protos_all_cc"), ] def tf_py_clif_cc(name, visibility = None, **kwargs): @@ -779,8 +779,8 @@ def tsl_cc_test( # TODO(ddunleavy) remove these and add proto deps to tests # granularly clean_dep("//tsl/protobuf:error_codes_proto_impl_cc_impl"), - clean_dep("//tsl/protobuf:histogram_proto_cc_impl"), - clean_dep("//tsl/protobuf:status_proto_cc_impl"), + clean_dep("@xla//xla/tsl/protobuf:histogram_proto_cc_impl"), + clean_dep("@xla//xla/tsl/protobuf:status_proto_cc_impl"), clean_dep("//tsl/profiler/protobuf:xplane_proto_cc_impl"), clean_dep("//tsl/profiler/protobuf:profiler_options_proto_cc_impl"), ], @@ -789,7 +789,7 @@ def tsl_cc_test( ) def tf_portable_proto_lib(): - return ["//tensorflow/core:protos_all_cc_impl", clean_dep("//tsl/protobuf:protos_all_cc_impl")] + return ["//tensorflow/core:protos_all_cc_impl", clean_dep("@xla//xla/tsl/protobuf:protos_all_cc_impl")] def tf_protobuf_compiler_deps(): return if_static( diff --git a/tsl/platform/status_test.cc b/tsl/platform/status_test.cc index 6d9948fa6..fbfdab891 100644 --- a/tsl/platform/status_test.cc +++ b/tsl/platform/status_test.cc @@ -18,13 +18,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/strings/str_format.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/stack_frame.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/test.h" #include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace tsl { namespace { diff --git a/tsl/platform/status_to_from_proto.cc b/tsl/platform/status_to_from_proto.cc index 96ad290f9..e83fa7d1b 100644 --- a/tsl/platform/status_to_from_proto.cc +++ b/tsl/platform/status_to_from_proto.cc @@ -16,9 +16,9 @@ limitations under the License. #include +#include "xla/tsl/protobuf/status.pb.h" #include "tsl/platform/status.h" #include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace tsl { diff --git a/tsl/platform/status_to_from_proto.h b/tsl/platform/status_to_from_proto.h index 9891737f0..021e002ae 100644 --- a/tsl/platform/status_to_from_proto.h +++ b/tsl/platform/status_to_from_proto.h @@ -15,8 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ #define TENSORFLOW_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ +#include "xla/tsl/protobuf/status.pb.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/status.pb.h" namespace tsl { diff --git a/tsl/protobuf/BUILD b/tsl/protobuf/BUILD index 302648504..c444dce89 100644 --- a/tsl/protobuf/BUILD +++ b/tsl/protobuf/BUILD @@ -1,4 +1,3 @@ -# Placeholder: load py_proto_library load( "@xla//xla/tsl:tsl.bzl", "if_google", @@ -20,14 +19,6 @@ package( licenses = ["notice"], ) -tf_proto_library( - name = "dnn_proto", - srcs = ["dnn.proto"], - make_default_target_header_only = True, - protodeps = if_google(["//google/protobuf:wrappers"]), - visibility = ["//visibility:public"], -) - tf_proto_library( name = "error_codes_proto_impl", srcs = ["error_codes.proto"], @@ -35,78 +26,3 @@ tf_proto_library( protodeps = if_google(["//google/protobuf:any"]), visibility = ["//visibility:public"], ) - -tf_proto_library( - name = "status_proto", - srcs = ["status.proto"], - make_default_target_header_only = True, - protodeps = [":error_codes_proto_impl"], - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "histogram_proto", - srcs = ["histogram.proto"], - make_default_target_header_only = True, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "coordination_config_proto", - srcs = ["coordination_config.proto"], - make_default_target_header_only = True, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "coordination_service_proto", - srcs = ["coordination_service.proto"], - has_services = 1, - create_grpc_library = True, - create_java_proto = False, - create_service = True, - protodeps = if_google(["//google/protobuf:any"]), - visibility = ["//visibility:public"], -) - -# copybara:uncomment_begin(google-only) -# py_proto_library( -# name = "coordination_service_py_pb2", -# api_version = 2, -# visibility = ["//visibility:public"], -# deps = [":coordination_service_proto"], -# ) -# copybara:uncomment_end - -tf_proto_library( - name = "distributed_runtime_payloads_proto", - srcs = ["distributed_runtime_payloads.proto"], - make_default_target_header_only = True, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "rpc_options_proto", - srcs = ["rpc_options.proto"], - make_default_target_header_only = True, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "protos_all", - create_go_proto = False, - make_default_target_header_only = True, - protodeps = [ - # TODO(tlongeri): Conceptually, these fit into protos_all but adding them currently causes - # breakages (and they are not actually used). - "@xla//xla/tsl/protobuf:bfc_memory_map_proto", - ":coordination_config_proto", - ":distributed_runtime_payloads_proto", - ":error_codes_proto_impl", - ":histogram_proto", - ":rpc_options_proto", - ":status_proto", - "@xla//xla/tsl/protobuf:test_log_proto", - ] + if_google(["//google/protobuf:any"]), - visibility = ["//visibility:public"], -) diff --git a/tsl/protobuf/coordination_config.proto b/tsl/protobuf/coordination_config.proto deleted file mode 100644 index 23aff65eb..000000000 --- a/tsl/protobuf/coordination_config.proto +++ /dev/null @@ -1,74 +0,0 @@ -syntax = "proto3"; - -package tensorflow; - -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; - -// Represents a job type and the number of tasks under this job. -// For example, ("worker", 20) implies that there will be 20 worker tasks. -message CoordinatedJob { - string name = 1; - int32 num_tasks = 2; -} - -// Coordination service configuration parameters. -// The system picks appropriate values for fields that are not set. -message CoordinationServiceConfig { - // Type of coordination service implementation to enable. - // For example, setting the service type as "standalone" starts a service - // instance on the leader task to provide the coordination services such as - // heartbeats and consistent key-value store. - string service_type = 1; - - // Address where the coordination service instance is hosted. - string service_leader = 2; - - // Whether to enable the health check mechanism. - bool enable_health_check = 3; - - // Maximum wait time for all members in the cluster to be registered. - int64 cluster_register_timeout_in_ms = 4; - - // Heartbeat timeout, if a task does not record heartbeat in this time - // window, it will be considered disconnected. - // Note: This is also used as a grace period to accept any heartbeats after - // the agent has disconnected, to account for the lag time between the service - // recording the state change and the agent stopping heartbeats. - int64 heartbeat_timeout_in_ms = 5; - - // The list of `CoordinatedJob`s that will register in coordination service. - reserved 6; - repeated CoordinatedJob coordinated_job_list = 10; - - // Denotes how long to wait for all coordination agents to reach the barriers - // (after the first shutdown request) before disconnecting together. If - // set to 0, no barrier is imposed upon shutdown and each worker can - // disconnect individually. - int64 shutdown_barrier_timeout_in_ms = 7; - - // If set, agents do not make an explicit Shutdown() call. Service will only - // find out about the disconnecte agent via stale heartbeats. Used for - // testing. - bool agent_destruction_without_shutdown = 8; - - // The list of jobs which are recoverable. If a task in this list fails, - // it will not propagate error to other tasks. - // If empty, no jobs will be recoverable and every task failure will cause - // error propagation to other tasks. - repeated string recoverable_jobs = 9; - - // If a task restarts with a new incarnation, we may allow it to reconnect - // silently. This is useful when we know that a task can immediately resume - // work upon re-connecting to the service. - bool allow_new_incarnation_to_reconnect = 11; - - // Disables coordination service. - // Some libraries enable coordination service by default even if the user did - // not specify any config. This field allows users to explicitly disable - // coordination service under all situations. - bool force_disable = 12; - - // Use long polling to get error from coordination service as the error - // propagation mechanism. - bool poll_for_error_from_service_at_startup = 13; -} diff --git a/tsl/protobuf/coordination_service.proto b/tsl/protobuf/coordination_service.proto deleted file mode 100644 index 2405cb936..000000000 --- a/tsl/protobuf/coordination_service.proto +++ /dev/null @@ -1,363 +0,0 @@ -syntax = "proto3"; - -package tensorflow; - -import "google/protobuf/any.proto"; - -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; - -// Represents a remote worker task, specified by job name and task id. -message CoordinatedTask { - string job_name = 1; - int32 task_id = 2; -} - -// Represents the state of a remote worker -enum CoordinatedTaskState { - // TASKSTATE_UNSPECIFIED is an invalid state such that indicates a bug. - TASKSTATE_UNSPECIFIED = 0; - // TASKSTATE_UNINITIALIZED is an agent-only state. While the agent is - // disconnected, the service has no way of knowing if the task is - // initialized/uninitialized. - TASKSTATE_UNINITIALIZED = 1; - TASKSTATE_DISCONNECTED = 2; - TASKSTATE_CONNECTED = 3; - TASKSTATE_ERROR = 4; -} - -// Status payload for all coordination service errors. -// Note: an empty proto may be set if the error is triggered by the task's own -// agent calls (i.e. not propagated by the service from another remote task). -message CoordinationServiceError { - // Removed fields which used to specify the error origin. - reserved 1, 2; - // If true, error is reported via the agent API by the user (and not an - // internal service error). - bool is_reported_error = 3; - // Denotes which task hit the error. If unset, the error originated from the - // same task that is processing this error. - CoordinatedTask source_task = 4; -} - -message CoordinatedTaskStateInfo { - CoordinatedTask task = 1; - CoordinatedTaskState state = 2; - int32 error_code = 3; - string error_message = 4; - CoordinationServiceError error_payload = 5; -} - -// Placeholder message to be extended by other runtimes' device representations. -message DeviceInfo { - repeated google.protobuf.Any device = 1; -} - -// Request and response messages for registering a task to the cluster leader. -// A task is uniquely represented by its `job_name`, `task_id` and -// `incarnation`. Leader responds with its `incarnation` to identify a leader -// process. -message RegisterTaskRequest { - // Removed fields which used to specify the task. - reserved 1, 2; - fixed64 incarnation = 3; - // Moved the field `local_device_attributes` from this request message to - // WaitForAllTasksRequest defined below. - reserved 4; - CoordinatedTask source_task = 5; -} - -message RegisterTaskResponse { - fixed64 leader_incarnation = 1; -} - -// Request and response messages for sending heartbeats. -message HeartbeatRequest { - // Removed fields which used to specify the remote task. - reserved 1, 2; - fixed64 incarnation = 3; - CoordinatedTask source_task = 4; -} - -message HeartbeatResponse { - fixed64 leader_incarnation = 1; - // If there are failures in cluster, use additional metadata in response to - // broadcast error code and message to other tasks. -} - -message PollForErrorRequest { - CoordinatedTask source_task = 1; -} - -message PollForErrorResponse {} - -// Request and response messages for waiting for all tasks. -message WaitForAllTasksRequest { - // Removed fields which used to specify the remote task. - reserved 1, 2; - // Removed field that specifically used TF device info. - reserved 3, 4; - CoordinatedTask source_task = 5; - // All local device attributes on the request sender; - DeviceInfo device_info = 6; -} - -message WaitForAllTasksResponse { - fixed64 leader_incarnation = 1; - // Removed field that specifically used TF device info. - reserved 2, 3; - // All devices in the cluster. - DeviceInfo device_info = 4; -} - -// Request and response messages for disconnecting a task from the service. -message ShutdownTaskRequest { - CoordinatedTask source_task = 1; -} - -message ShutdownTaskResponse {} - -// Request and response messages for resetting a task state in the service. -message ResetTaskRequest { - CoordinatedTask source_task = 1; -} - -message ResetTaskResponse {} - -// Request and response messages for reporting errors to task. -message ReportErrorToTaskRequest { - int32 error_code = 1; - string error_message = 2; - // Removed fields that are embedded in payload. - reserved 3, 4; - CoordinationServiceError error_payload = 5; -} - -message ReportErrorToTaskResponse {} - -// Request and response messages for reporting errors to service instance. -message ReportErrorToServiceRequest { - int32 error_code = 1; - string error_message = 2; - // Removed fields which used to specify the error origin. - reserved 3, 4; - CoordinatedTask error_origin = 5; -} - -message ReportErrorToServiceResponse {} - -// Request and response messages for getting state of a remote task. -message GetTaskStateRequest { - repeated CoordinatedTask source_task = 1; -} - -message GetTaskStateResponse { - repeated CoordinatedTaskStateInfo task_state = 1; -} - -// Message for configuration key value. -// Key is structured like Unix file system, with multiple levels of directory -// names separated by the slash ('/') characters. -message KeyValueEntry { - string key = 1; - bytes value = 2; -} - -// Request and response messages for inserting configuration key-value data. -message InsertKeyValueRequest { - KeyValueEntry kv = 1; - bool allow_overwrite = 2; -} - -message InsertKeyValueResponse {} - -// Request and response messages for getting configuration key-value data. -message GetKeyValueRequest { - string key = 1; -} - -message GetKeyValueResponse { - KeyValueEntry kv = 1; -} - -message TryGetKeyValueRequest { - string key = 1; -} - -message TryGetKeyValueResponse { - KeyValueEntry kv = 1; -} - -message GetKeyValueDirRequest { - string directory_key = 1; -} - -message GetKeyValueDirResponse { - string directory_key = 1; - repeated KeyValueEntry kv = 2; -} - -// Request and response messages for deleting configuration key-value data. -// When is_directory is true, delete key-values recursively under `key`. -message DeleteKeyValueRequest { - string key = 1; - bool is_directory = 2; -} - -message DeleteKeyValueResponse {} - -// Request and response messages for generic sync barriers. -message BarrierRequest { - string barrier_id = 1; - int64 barrier_timeout_in_ms = 2; - // Denotes list of tasks that will wait for the barrier. If unspecified, it - // implies that the entire cluster is participating in the barrier. - repeated CoordinatedTask tasks = 3; - // Task that is making the request. - CoordinatedTask source_task = 4; -} - -message BarrierResponse {} - -// Request and response messages for cancelling generic sync barriers. -message CancelBarrierRequest { - string barrier_id = 1; - // Task that is making the request. - CoordinatedTask source_task = 2; -} - -message CancelBarrierResponse {} - -// Coordination Service defines a TensorFlow service that controls and -// coordinates distributed execution in a cluster of multiple tasks. -// -// The service keeps track of the cluster configuration and the state of cluster -// members or the leader depending on the role of the current task. The -// distributed runtime leverages this service to coordinate and perform cluster -// initialization, check the healthiness of tasks, and propagate error -// messages to the cluster. -service CoordinationService { - // Register task to coordination service so that the service starts to track - // liveness of the task. RPC blocks and returns only when it registers to - // the service successfully, or error happens in the registering process. - rpc RegisterTask(RegisterTaskRequest) returns (RegisterTaskResponse) { - // [AUTOMATION]: Internal rpc option goes here. - } - - // Heartbeat message from task to coordination service. Heartbeat is sent from - // a task to refresh its timestamp on leader to avoid it becoming stale. - // RPC responds immediately after refreshing the timestamp on leader. - rpc Heartbeat(HeartbeatRequest) returns (HeartbeatResponse) { - // [AUTOMATION]: Internal rpc option goes here. - } - - // Wait for all tasks in the cluster to be up and running. The RPC request - // only gets responded when all tasks have registered, or some error occurs. - rpc WaitForAllTasks(WaitForAllTasksRequest) returns (WaitForAllTasksResponse); - - // Disconnects task from the service. If `shutdown_barrier_timeout_in_ms` is - // specified in the config, blocks until all tasks reach the barrier before - // disconnecting together. If the barrier times out, tasks at the barrier will - // still disconnect, while an error is reported to tasks that did not reach - // the barrier on time. - rpc ShutdownTask(ShutdownTaskRequest) returns (ShutdownTaskResponse) { - // [AUTOMATION]: Internal rpc option goes here. - } - - // Disconnects task from the service if it is in an ERROR state, thereby - // allowing it to reconnect via RegisterTask() in the future. - rpc ResetTask(ResetTaskRequest) returns (ResetTaskResponse); - - // Report error to the task. RPC sets the receiving instance of coordination - // service agent to error state permanently. - // TODO(b/195990880): Consider splitting this into a different RPC service. - rpc ReportErrorToTask(ReportErrorToTaskRequest) - returns (ReportErrorToTaskResponse); - - // Report task error to coordination service. RPC sets the service-side task - // state to error, and propagate the error to other tasks in the cluster. - rpc ReportErrorToService(ReportErrorToServiceRequest) - returns (ReportErrorToServiceResponse); - - // Get the state of a remote task. Specifically, RPC returns a - // CoordinatedTaskState, and if the task is in an error status, returns a - // non-OK error code, non-empty error message and error payload. - rpc GetTaskState(GetTaskStateRequest) returns (GetTaskStateResponse); - - // Insert configuration key-value that will be accessible to all cluster - // tasks. The key can be formatted as Unix file path with hierarchy. The - // coordination service key-value store should only be used for cluster - // configuration data. - rpc InsertKeyValue(InsertKeyValueRequest) returns (InsertKeyValueResponse) { - // [AUTOMATION]: Internal rpc option goes here. - } - - // Get configuration key-value. The request blocks until the key-value data - // becomes available (i.e., set by a task in the cluster). - rpc GetKeyValue(GetKeyValueRequest) returns (GetKeyValueResponse) { - // [AUTOMATION]: Internal rpc option goes here. - } - - // Get configuration key-value. The request does not block, but returns an - // error if the requested key does not exist. - rpc TryGetKeyValue(TryGetKeyValueRequest) returns (TryGetKeyValueResponse); - - // Same as GetKeyValue, but returns all values that have keys which are - // prefixed with the directory key. - rpc GetKeyValueDir(GetKeyValueDirRequest) returns (GetKeyValueDirResponse) { - // [AUTOMATION]: Internal rpc option goes here. - } - - // Delete configuration key-value. If is_directory is set in request, - // recursively clean up all key-values under the path specified by `key`. - rpc DeleteKeyValue(DeleteKeyValueRequest) returns (DeleteKeyValueResponse); - - // Blocks until all (or a subset of) tasks are at the barrier or the barrier - // fails. - // - // `barrier_id` should be unique across barriers. Once the barrier has passed - // or failed, subsequent calls will not block, and immediately respond with - // the previous response. - // - // The first WaitAtBarrier() call received by the service for a particular - // barrier id is special in that it determines the barrier deadline based on - // timeout duration. - // However, if subsequent calls by different agents specify a different set of - // `tasks` for the same `barrier_id`, the barrier will fail instantly. - // - // If no tasks are specified (default), the barrier will block for all the - // connected tasks. - // - // Possible service errors: - // - DeadlineExceeded: Timed out waiting for specified tasks at the barrier. - // Deadline is determined by the server timestamp when it receives the - // first WaitAtBarrier() + timeout duration. - // - Cancelled: One of the tasks called CancelBarrier(). - // - Aborted: Service is shutting down. - // - Internal: Any participating task is in ERROR state. - // - InvalidArgument: (1) Conflicting tasks specified by different agents - // for the same barrier, (2) one of the participating tasks is not in - // the cluster, or (3) task making the request is not included in the - // list of participating tasks. - rpc Barrier(BarrierRequest) returns (BarrierResponse) { - // [AUTOMATION]: Internal rpc option goes here. - } - - // Aborts the barrier if it is ongoing. - // Current and future WaitAtBarrier() calls with the same id will return a - // CANCELLED error status. - // Possible service errors: - // - FailedPrecondition: Barrier has already been passed. - rpc CancelBarrier(CancelBarrierRequest) returns (CancelBarrierResponse); - - // Polls the service for errors. - // - // This RPC is used by the coordination service agent to send long polling - // request to service for errors. The call will block until an error is - // reported by the service. - // - // Possible service errors: - // - Aborted: Service is shutting down. - rpc PollForError(PollForErrorRequest) returns (PollForErrorResponse) { - // [AUTOMATION]: Internal rpc option goes here. - } -} diff --git a/tsl/protobuf/distributed_runtime_payloads.proto b/tsl/protobuf/distributed_runtime_payloads.proto deleted file mode 100644 index 3a2aecdd2..000000000 --- a/tsl/protobuf/distributed_runtime_payloads.proto +++ /dev/null @@ -1,24 +0,0 @@ -syntax = "proto3"; - -package tensorflow.distributed_runtime; - -option cc_enable_arenas = true; -option go_package = "github.com/tsl/tsl/go/core/protobuf/for_core_protos_go_proto"; - -// Used to serialize and transmit tensorflow::Status payloads through -// grpc::Status `error_details` since grpc::Status lacks payload API. -// TODO(b/204231601): Use GRPC API once supported. -message GrpcPayloadContainer { - map payloads = 1; -} - -// If included as a payload, this message flags the Status to have lost payloads -// during the GRPC transmission. -// URI: "type.googleapis.com/tensorflow.distributed_runtime.GrpcPayloadsLost" -message GrpcPayloadsLost {} - -// If included as a payload, this message flags the Status to be a possible -// outcome of a worker restart. -// URI: -// "type.googleapis.com/tensorflow.distributed_runtime.WorkerPossiblyRestarted" -message WorkerPossiblyRestarted {} diff --git a/tsl/protobuf/dnn.proto b/tsl/protobuf/dnn.proto deleted file mode 100644 index 695db935f..000000000 --- a/tsl/protobuf/dnn.proto +++ /dev/null @@ -1,203 +0,0 @@ -// LINT: LEGACY_NAMES -syntax = "proto3"; - -package stream_executor.dnn; - -import "google/protobuf/wrappers.proto"; - -option go_package = "github.com/google/tsl/tsl/go/stream_executor"; - -// Specifies the data type used by an operation. -enum DataType { - kFloat = 0; - kDouble = 1; - kHalf = 2; - kInt8 = 3; - kInt32 = 4; - kComplexFloat = 5; - kComplexDouble = 6; - kBF16 = 7; - kF8E5M2 = 8; - kF8E4M3FN = 9; - kF8E5M2FNUZ = 10; - kF8E4M3FNUZ = 11; - kInt64 = 12; -} - -// Describes how a convolution input or output layer's data is formatted. -enum DataLayout { - // Naming convention: - // Y <-> row or height - // X <-> column or width - // Batch <-> batch, or N - // Depth <-> feature, or channel - // TODO(timshen): turn them into cuDNN names, e.g. kNCHW. - // - // Note: In cudnn, kBatchDepthYX4 and kBatchDepthYX32 are the same layout - // (namely, NCHW_VECT_C). It differentiates between these two by using a - // different data type (int8x4 vs int8x32). In StreamExecutor we use - // different layouts for these, because we don't usually pass an explicit data - // type to StreamExecutor functions. - kYXDepthBatch = 0; - kYXBatchDepth = 1; - kBatchYXDepth = 2; // cuDNN's NHWC layout - kBatchDepthYX = 3; // cuDNN's NCHW layout - kBatchDepthYX4 = 4; // cuDNN's NCHW_VECT_C with 4-elem vectors (e.g. int8x4) - kBatchDepthYX32 = 5; // cuDNN's NCHW_VECT_C with 32-elem vects (e.g. int8x32) -} - -// Describes how a convolution filter is laid out in the memory. -enum FilterLayout { - // Naming convention: - // Y <-> row or height - // X <-> column or width - // Output <-> output feature, or N - // Input <-> input feature, or N - // TODO(timshen): turn them into cuDNN names, e.g. kNCHW. - kOutputInputYX = 0; // cuDNN's NCHW layout - kOutputYXInput = 1; // cuDNN's NHWC layout - kOutputInputYX4 = 2; // cuDNN's NCHW_VECT_C layout with 4-elem vectors - kOutputInputYX32 = 5; // cuDNN's NCHW_VECT_C layout with 32-elem vectors - // cuDNN-specific filter reordering (using `cudnnReorderFilterAndBias`) - // When the filter is reordered, so is the bias (if present). - kOutputInputYX32_CudnnReordered = 6; - kInputYXOutput = 3; - kYXInputOutput = 4; -} - -// Describes a kind of non-linearity (threshold-like mathematical function). -enum ActivationMode { - kNone = 0; - kSigmoid = 1; - // Rectified linear activation: f(x) = x < 0 ? 0 : x - kRelu = 2; - // Rectified linear activation; where upper maximum is 6.0. - kRelu6 = 3; - // Rectified linear activation; where upper maximum specified by - // BatchDescriptor::value_max(). - kReluX = 4; - kTanh = 5; - // Like ReluX; but passes all values in the range [-X,X]. - kBandPass = 6; - // Exponential linear activation: f(x) = x < 0 ? e^x - 1 : x - kElu = 7; - // Leaky Rectified linear activation: f(x) = x < 0 ? alpha * x : x - kLeakyRelu = 8; - // Gaussian Error linear unit activation: - // x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2))), where P(X) ~ N(0, 1). - kGeluExact = 9; -} - -// Describe the math definition for the conv op. The popular behavior is -// actually called cross-correlation in math, despite the operation is often -// referred as convolution. See cuDNN cudnnConvolutionMode_t. -enum ConvolutionMode { - CROSS_CORRELATION = 0; - CONVOLUTION = 1; -} - -enum ConvolutionKind { - INVALID = 0; - FORWARD = 1; - BACKWARD_FILTER = 2; - BACKWARD_DATA = 3; - FORWARD_BIAS_ACTIVATION = 4; - FORWARD_GRAPH = 5; -} - -// Generic tensor representation. -message TensorDescriptorProto { - repeated int64 dimensions = 1; - DataType data_type = 2; - oneof layout_oneof { - DataLayout data_layout = 3; - FilterLayout filter_layout = 4; - } -} - -// Generic algorithm representation. -message AlgorithmProto { - enum MathType { - DEFAULT_MATH = 0; - // The GPU may operate 4x4 matrix FMA. - // See cuDNN's documentation for CUDNN_TENSOR_OP_MATH. - TENSOR_OP_MATH = 1; - } - int64 algo_id = 1; - MathType math_type = 2; - reserved 3; - - map tuning_knobs = 4; - // Legacy algorithm enums and cuDNN Frontend engine numbers need to coexist in - // the same proto medium-term, until we can be confident of no longer needing - // the legacy cuDNN convolution API. Once the migration is complete, we can - // stop producing legacy algorithm enums and remove this field. - bool is_cudnn_frontend = 5; - - // For ROCm only, it's impossible to re-query the required workspace size - // after running the algorithm search, so we must store the workspace size - // along with the choice of algorithm. For consistency and convenience, - // cuDNN uses this field in the same way, even though it would be possible to - // re-query the workspace size from cuDNN at each use. - // - // Since this message is persisted in files, we need to be able to distinguish - // 0 workspace size from unknown workspace size in an old message, so this is - // a message field. - google.protobuf.UInt64Value workspace_size = 6; -} - -// Proto definition of AlgorithmConfig in "dnn.h". -// TODO(ruochengw): After cl/380702564 is submitted, add support for algorithm -// configs with cuDNN Frontend APIs. -message AlgorithmConfigProto { - // Use oneof to emulate optional semantics in proto2 since older - // version of proto3 cannot distinguish "unset field" and "default field". - oneof optional_algorithm { - AlgorithmProto algorithm = 1; - } - oneof optional_algorithm_no_scratch { - AlgorithmProto algorithm_no_scratch = 2; - } - oneof optional_scratch_size { - int64 scratch_size = 3; - } -} - -// Convolution-specific parameters. -message ConvolutionDescriptorProto { - repeated int64 paddings = 1; - repeated int64 strides = 2; - repeated int64 dilations = 3; - // The "accumulator" type. For example, use F32 as an accumulator for F16 - // convolutions. - // See cuDNN's cudnnConvolutionMode_t. - DataType compute_mode = 4; - // See cuDNN's group count. - int32 group_count = 5; - ConvolutionMode convolution_mode = 6; - // Tensorflow node name, same as in NodeDef, for debugging purposes. - string name = 7; -} - -// NormKind kind -enum NormKind { - LAYER_FWD_INFER = 0; - LAYER_FWD_TRAIN = 1; - LAYER_BWD = 2; -} - -// FusedMHAKind kind -enum FusedMHAKind { - BMM1_OUTPUT_UNKNOWN = 0; - BMM1_OUTPUT_INPUT_TYPE = 1; - BMM1_OUTPUT_FLOAT = 2; -} - -// FusedMHAMaskKind kind -enum FMHAMaskKind { - NO_MASK = 0; - PADDING = 1; - CAUSAL = 2; - PADDING_CAUSAL = 3; - ALIBI = 4; -} diff --git a/tsl/protobuf/histogram.proto b/tsl/protobuf/histogram.proto deleted file mode 100644 index 2a5f6d936..000000000 --- a/tsl/protobuf/histogram.proto +++ /dev/null @@ -1,26 +0,0 @@ -syntax = "proto3"; - -package tensorflow; - -option cc_enable_arenas = true; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/google/tsl/tsl/go/core/protobuf/summary_go_proto"; - -// Serialization format for histogram module in -// tsl/lib/histogram/histogram.h -message HistogramProto { - double min = 1; - double max = 2; - double num = 3; - double sum = 4; - double sum_squares = 5; - - // Parallel arrays encoding the bucket boundaries and the bucket values. - // bucket(i) is the count for the bucket i. The range for - // a bucket is: - // i == 0: -DBL_MAX .. bucket_limit(0) - // i != 0: bucket_limit(i-1) .. bucket_limit(i) - repeated double bucket_limit = 6 [packed = true]; - repeated double bucket = 7 [packed = true]; -} diff --git a/tsl/protobuf/rpc_options.proto b/tsl/protobuf/rpc_options.proto deleted file mode 100644 index 35c5dbe3b..000000000 --- a/tsl/protobuf/rpc_options.proto +++ /dev/null @@ -1,41 +0,0 @@ -syntax = "proto3"; - -package tensorflow; - -option go_package = "github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto"; - -// RPC options for distributed runtime. -message RPCOptions { - // If true, always use RPC to contact the session target. - // - // If false (the default option), TensorFlow may use an optimized - // transport for client-master communication that avoids the RPC - // stack. This option is primarily for used testing the RPC stack. - bool use_rpc_for_inprocess_master = 1; - - // The compression algorithm to be used. One of "deflate", "gzip". - string compression_algorithm = 2; - - // If compression_algorithm is set, the compression level to be used. - // From 0 (no compression), up to 3. - int32 compression_level = 3; - - // Setting cache_rpc_response to true will enable sender side caching of - // response for RecvTensorAsync and RecvBufAsync to allow receiver to retry - // requests . This is only necessary when the network fabric is experiencing a - // significant error rate. Without it we'll fail a step on an network error, - // while with it we'll be able to complete long steps (like complex - // initializations) in the face of some network errors during RecvTensor. - bool cache_rpc_response = 4; - - // Disables TCP connection sharing when opening a new RPC channel. - bool disable_session_connection_sharing = 5; - - // Setting num_channels_per_target > 0 allows uses of multiple channels to - // communicate to the same target. This can be used to improve the aggregate - // throughput on high speed links (e.g 100G) where single connection is not - // sufficient to maximize link utilization. Note that a single RPC only goes - // on a single channel, this only helps in situations where there are multiple - // transfers to the same target overlapping in time. - int32 num_channels_per_target = 6; -} diff --git a/tsl/protobuf/status.proto b/tsl/protobuf/status.proto deleted file mode 100644 index 09d722189..000000000 --- a/tsl/protobuf/status.proto +++ /dev/null @@ -1,20 +0,0 @@ -syntax = "proto3"; - -package tensorflow; - -import "tsl/protobuf/error_codes.proto"; - -option cc_enable_arenas = true; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto"; - -// Wire-format for Status. -// Next tag: 3 -message StatusProto { - // Status code as defined in tensorflow/tsl/protobuf/error_codes.proto. - error.Code code = 1; - - // Detail error message. - string message = 2; -}