diff --git a/tsl/platform/BUILD b/tsl/platform/BUILD index 1cc83c02f..eb208e6ef 100644 --- a/tsl/platform/BUILD +++ b/tsl/platform/BUILD @@ -16,7 +16,6 @@ load( "tf_error_logging_deps", "tf_fingerprint_deps", "tf_google_mobile_srcs_no_runtime", - "tf_logging_deps", "tf_platform_deps", "tf_protobuf_compiler_deps", "tf_resource_deps", @@ -26,7 +25,6 @@ load( "tsl_grpc_credentials_deps", "tsl_protobuf_deps", ) -load("@xla//xla/tsl/platform:build_config_root.bzl", "if_static") load( "@xla//xla/tsl/platform:rules_cc.bzl", "cc_library", @@ -143,57 +141,40 @@ cc_library( "file_system_helper.h", "threadpool.h", ], - deps = tf_windows_aware_platform_deps("env") + if_static([":env_impl"]), + deps = [ + "@xla//xla/tsl/platform:env", + ], ) cc_library( name = "threadpool_async_executor", hdrs = ["threadpool_async_executor.h"], deps = [ - ":env", - "@xla//xla/tsl/concurrency:async_value", - ], -) - -tsl_cc_test( - name = "threadpool_async_executor_test", - srcs = ["threadpool_async_executor_test.cc"], - deps = [ - ":env", - ":env_impl", - ":test", - ":test_main", - ":threadpool_async_executor", - "@com_google_absl//absl/synchronization", + "@xla//xla/tsl/platform:threadpool_async_executor", ], ) cc_library( name = "env_impl", - deps = tf_windows_aware_platform_deps("env_impl"), + deps = [ + "@xla//xla/tsl/platform:env_impl", + ], ) cc_library( name = "env_time", compatible_with = get_compatible_with_portable(), textual_hdrs = ["env_time.h"], - deps = tf_windows_aware_platform_deps("env_time"), + deps = [ + "@xla//xla/tsl/platform:env_time", + ], ) cc_library( name = "errors", - srcs = ["errors.cc"], hdrs = ["errors.h"], deps = [ - ":logging", - ":macros", - ":status", - ":str_util", - ":strcat", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", + "@xla//xla/tsl/platform:errors", ], ) @@ -290,55 +271,26 @@ cc_library( cc_library( name = "status", - srcs = ["status.cc"], hdrs = ["status.h"], deps = [ - ":logging", - ":macros", - ":mutex", - ":platform", - ":stack_frame", - ":stacktrace", - ":str_util", - ":strcat", - ":stringprintf", - ":types", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", - "@com_google_absl//absl/types:optional", - "@xla//xla/tsl/protobuf:error_codes_proto_impl_cc", - ] + tf_platform_deps("status"), + "@xla//xla/tsl/platform:status", + ], ) cc_library( name = "status_to_from_proto", - srcs = [ - "status_to_from_proto.cc", - ], hdrs = ["status_to_from_proto.h"], deps = [ - ":status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", - "@xla//xla/tsl/protobuf:error_codes_proto_impl_cc", - "@xla//xla/tsl/protobuf:status_proto_cc", - ] + tf_platform_deps("status"), + "@xla//xla/tsl/platform:status_to_from_proto", + ], ) cc_library( name = "status_matchers", testonly = 1, - srcs = ["status_matchers.cc"], hdrs = ["status_matchers.h"], deps = [ - ":status", - ":statusor", - ":test", - "@xla//xla/tsl/protobuf:error_codes_proto_impl_cc", + "@xla//xla/tsl/platform:status_matchers", ], ) @@ -346,17 +298,8 @@ cc_library( name = "statusor", hdrs = ["statusor.h"], deps = [ - ":errors", - ":logging", - ":macros", - ":platform", - ":status", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - ] + tf_platform_deps("statusor"), + "@xla//xla/tsl/platform:statusor", + ], ) cc_library( @@ -368,17 +311,10 @@ cc_library( cc_library( name = "test", testonly = True, - srcs = ["test.cc"], compatible_with = get_compatible_with_portable(), textual_hdrs = ["test.h"], deps = [ - ":logging", - ":macros", - ":net", - ":path", - ":platform", - ":types", - "@com_google_googletest//:gtest", + "@xla//xla/tsl/platform:test", ], ) @@ -388,8 +324,7 @@ cc_library( hdrs = ["test_benchmark.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":platform", - "@com_google_benchmark//:benchmark", + "@xla//xla/tsl/platform:test_benchmark", ], ) @@ -554,13 +489,10 @@ filegroup( "denormal.cc", "denormal.h", "dynamic_annotations.h", - "env.cc", "env.h", "env_time.h", - "errors.cc", "errors.h", "file_statistics.h", - "file_system.cc", "file_system.h", "file_system_helper.h", "hash.cc", @@ -593,7 +525,6 @@ filegroup( "setround.h", "snappy.h", "stacktrace.h", - "status.cc", "status.h", "statusor.h", "str_util.cc", @@ -604,7 +535,6 @@ filegroup( "stringprintf.cc", "stringprintf.h", "thread_annotations.h", - "threadpool.cc", "threadpool.h", "threadpool_interface.h", "tracing.h", @@ -612,7 +542,6 @@ filegroup( ] + select({ "@xla//xla/tsl:fuchsia": tf_google_mobile_srcs_no_runtime(), "//conditions:default": [ - "file_system_helper.cc", "tracing.cc", "@xla//xla/tsl/platform/default:mobile_srcs_no_runtime", ], @@ -674,13 +603,11 @@ exports_files( "criticality.h", "cuda_root_path.h", "demangle.h", - "env.cc", "env.h", "env_time.h", "error_logging.h", "file_system.cc", "file_system.h", - "file_system_helper.cc", "file_system_helper.h", "grpc_credentials.h", "host_info.h", @@ -813,6 +740,9 @@ cc_library( name = "macros", hdrs = ["macros.h"], compatible_with = get_compatible_with_portable(), + deps = [ + "@xla//xla/tsl/platform:macros", + ], ) filegroup( @@ -1005,9 +935,7 @@ cc_library( hdrs = ["threadpool_interface.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":mutex", - ":types", - "@eigen_archive//:eigen3", + "@xla//xla/tsl/platform:threadpool_interface", ], ) @@ -1016,11 +944,8 @@ cc_library( hdrs = ["types.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":bfloat16", - ":ml_dtypes", - ":platform", - ":tstring", - ] + tf_platform_deps("types"), + "@xla//xla/tsl/platform:types", + ], ) cc_library( @@ -1078,7 +1003,9 @@ cc_library( visibility = [ "//visibility:public", ], - deps = tf_logging_deps(), + deps = [ + "@xla//xla/tsl/platform:logging", + ], ) cc_library( @@ -1205,7 +1132,7 @@ cc_library( name = "file_statistics", hdrs = ["file_statistics.h"], deps = [ - ":types", + "@xla//xla/tsl/platform:file_statistics", ], ) @@ -1332,72 +1259,12 @@ tsl_cc_test( cc_library( name = "test_main", testonly = 1, - srcs = ["test_main.cc"], - copts = tsl_copts(), - linkopts = select({ - "@xla//xla/tsl:windows": [], - "//conditions:default": ["-lm"], - }), deps = [ - ":platform", - ":stacktrace_handler", - ":test", - ":test_benchmark", - "@com_google_absl//absl/strings", + "@xla//xla/tsl/platform:test_main", ], alwayslink = 1, ) -tsl_cc_test( - name = "status_test", - size = "small", - srcs = ["status_test.cc"], - deps = [ - ":errors", - ":stack_frame", - ":status", - ":status_matchers", - ":status_to_from_proto", - ":test", - ":test_main", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:cord", - "@com_google_absl//absl/strings:str_format", - "@xla//xla/tsl/protobuf:error_codes_proto_impl_cc", - "@xla//xla/tsl/protobuf:status_proto_cc", - ], -) - -tsl_cc_test( - name = "statusor_test", - size = "small", - srcs = ["statusor_test.cc"], - deps = [ - ":errors", - ":macros", - ":statusor", - ":test", - ":test_benchmark", - ":test_main", - "@com_google_absl//absl/base:config", - ], -) - -tsl_cc_test( - name = "status_matchers_test", - size = "small", - srcs = ["status_matchers_test.cc"], - deps = [ - ":errors", - ":status", - ":status_matchers", - ":statusor", - ":test", - ":test_main", - "@xla//xla/tsl/protobuf:error_codes_proto_impl_cc", - ], -) - cc_library( name = "notification", hdrs = ["notification.h"], @@ -1413,7 +1280,7 @@ cc_library( hdrs = ["threadpool_options.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":threadpool_interface", + "@xla//xla/tsl/platform:threadpool_options", ], ) @@ -1483,18 +1350,6 @@ cc_library( ], ) -tsl_cc_test( - name = "errors_test", - size = "small", - srcs = ["errors_test.cc"], - deps = [ - ":errors", - ":test", - ":test_main", - "@com_google_absl//absl/status", - ], -) - tsl_cc_test( name = "intrusive_ptr_test", size = "small", @@ -1566,26 +1421,6 @@ tsl_cc_test( ], ) -tsl_cc_test( - name = "logging_test", - size = "small", - srcs = [ - "logging_test.cc", - ], - deps = [ - ":logging", - ":path", - ":stacktrace_handler", - ":statusor", - ":test", - "@com_google_absl//absl/base:log_severity", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - ], -) - tsl_cc_test( name = "mutex_test", size = "small", diff --git a/tsl/platform/env.cc b/tsl/platform/env.cc deleted file mode 100644 index 29d5d6ff4..000000000 --- a/tsl/platform/env.cc +++ /dev/null @@ -1,649 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tsl/platform/env.h" - -#include - -#include -#include -#include -#include - -#include "absl/strings/str_format.h" -#include "tsl/platform/env_time.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/host_info.h" -#include "tsl/platform/path.h" -#include "tsl/platform/platform.h" -#include "tsl/platform/protobuf.h" -#include "tsl/platform/stringprintf.h" - -#if defined(__APPLE__) -#include -#endif -#if defined(__FreeBSD__) -#include -#endif -#if defined(PLATFORM_WINDOWS) -#include -#undef DeleteFile -#undef CopyFile -#include "xla/tsl/platform/windows/wide_char.h" -#define PATH_MAX MAX_PATH -#else -#include -#include -#include -#include -#endif - -namespace tsl { - -// 128KB copy buffer -constexpr size_t kCopyFileBufferSize = 128 * 1024; - -class FileSystemRegistryImpl : public FileSystemRegistry { - public: - absl::Status Register(const std::string& scheme, Factory factory) override; - absl::Status Register(const std::string& scheme, - std::unique_ptr filesystem) override; - FileSystem* Lookup(const std::string& scheme) override; - absl::Status GetRegisteredFileSystemSchemes( - std::vector* schemes) override; - - private: - mutable mutex mu_; - mutable std::unordered_map> registry_ - TF_GUARDED_BY(mu_); -}; - -absl::Status FileSystemRegistryImpl::Register( - const std::string& scheme, FileSystemRegistry::Factory factory) { - mutex_lock lock(mu_); - if (!registry_.emplace(scheme, std::unique_ptr(factory())) - .second) { - return errors::AlreadyExists("File factory for ", scheme, - " already registered"); - } - return absl::OkStatus(); -} - -absl::Status FileSystemRegistryImpl::Register( - const std::string& scheme, std::unique_ptr filesystem) { - mutex_lock lock(mu_); - if (!registry_.emplace(scheme, std::move(filesystem)).second) { - return errors::AlreadyExists("File system for ", scheme, - " already registered"); - } - return absl::OkStatus(); -} - -FileSystem* FileSystemRegistryImpl::Lookup(const std::string& scheme) { - mutex_lock lock(mu_); - const auto found = registry_.find(scheme); - if (found == registry_.end()) { - return nullptr; - } - return found->second.get(); -} - -absl::Status FileSystemRegistryImpl::GetRegisteredFileSystemSchemes( - std::vector* schemes) { - mutex_lock lock(mu_); - for (const auto& e : registry_) { - schemes->push_back(e.first); - } - return absl::OkStatus(); -} - -Env::Env() : file_system_registry_(new FileSystemRegistryImpl) {} - -absl::Status Env::GetFileSystemForFile(const std::string& fname, - FileSystem** result) { - absl::string_view scheme, host, path; - io::ParseURI(fname, &scheme, &host, &path); - FileSystem* file_system = file_system_registry_->Lookup(std::string(scheme)); - if (!file_system) { - if (scheme.empty()) { - scheme = "[local]"; - } - - return errors::Unimplemented("File system scheme '", scheme, - "' not implemented (file: '", fname, "')"); - } - *result = file_system; - return absl::OkStatus(); -} - -absl::Status Env::GetRegisteredFileSystemSchemes( - std::vector* schemes) { - return file_system_registry_->GetRegisteredFileSystemSchemes(schemes); -} - -absl::Status Env::RegisterFileSystem(const std::string& scheme, - FileSystemRegistry::Factory factory) { - return file_system_registry_->Register(scheme, std::move(factory)); -} - -absl::Status Env::RegisterFileSystem(const std::string& scheme, - std::unique_ptr filesystem) { - return file_system_registry_->Register(scheme, std::move(filesystem)); -} - -absl::Status Env::SetOption(const std::string& scheme, const std::string& key, - const std::string& value) { - FileSystem* file_system = file_system_registry_->Lookup(scheme); - if (!file_system) { - return errors::Unimplemented("File system scheme '", scheme, - "' not found to set configuration"); - } - return file_system->SetOption(key, value); -} - -absl::Status Env::SetOption(const std::string& scheme, const std::string& key, - const std::vector& values) { - FileSystem* file_system = file_system_registry_->Lookup(scheme); - if (!file_system) { - return errors::Unimplemented("File system scheme '", scheme, - "' not found to set configuration"); - } - return file_system->SetOption(key, values); -} - -absl::Status Env::SetOption(const std::string& scheme, const std::string& key, - const std::vector& values) { - FileSystem* file_system = file_system_registry_->Lookup(scheme); - if (!file_system) { - return errors::Unimplemented("File system scheme '", scheme, - "' not found to set configuration"); - } - return file_system->SetOption(key, values); -} - -absl::Status Env::SetOption(const std::string& scheme, const std::string& key, - const std::vector& values) { - FileSystem* file_system = file_system_registry_->Lookup(scheme); - if (!file_system) { - return errors::Unimplemented("File system scheme '", scheme, - "' not found to set configuration"); - } - return file_system->SetOption(key, values); -} - -absl::Status Env::FlushFileSystemCaches() { - std::vector schemes; - TF_RETURN_IF_ERROR(GetRegisteredFileSystemSchemes(&schemes)); - for (const string& scheme : schemes) { - FileSystem* fs = nullptr; - TF_RETURN_IF_ERROR( - GetFileSystemForFile(io::CreateURI(scheme, "", ""), &fs)); - fs->FlushCaches(); - } - return absl::OkStatus(); -} - -absl::Status Env::NewRandomAccessFile( - const string& fname, std::unique_ptr* result) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs)); - return fs->NewRandomAccessFile(fname, result); -} - -absl::Status Env::NewReadOnlyMemoryRegionFromFile( - const string& fname, std::unique_ptr* result) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs)); - return fs->NewReadOnlyMemoryRegionFromFile(fname, result); -} - -absl::Status Env::NewWritableFile(const string& fname, - std::unique_ptr* result) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs)); - return fs->NewWritableFile(fname, result); -} - -absl::Status Env::NewAppendableFile(const string& fname, - std::unique_ptr* result) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs)); - return fs->NewAppendableFile(fname, result); -} - -absl::Status Env::FileExists(const string& fname) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs)); - return fs->FileExists(fname); -} - -bool Env::FilesExist(const std::vector& files, - std::vector* status) { - std::unordered_map> files_per_fs; - for (const auto& file : files) { - absl::string_view scheme, host, path; - io::ParseURI(file, &scheme, &host, &path); - files_per_fs[string(scheme)].push_back(file); - } - - std::unordered_map per_file_status; - bool result = true; - for (auto itr : files_per_fs) { - FileSystem* file_system = file_system_registry_->Lookup(itr.first); - bool fs_result; - std::vector local_status; - std::vector* fs_status = status ? &local_status : nullptr; - if (!file_system) { - fs_result = false; - if (fs_status) { - absl::Status s = errors::Unimplemented("File system scheme '", - itr.first, "' not implemented"); - local_status.resize(itr.second.size(), s); - } - } else { - fs_result = file_system->FilesExist(itr.second, fs_status); - } - if (fs_status) { - result &= fs_result; - for (size_t i = 0; i < itr.second.size(); ++i) { - per_file_status[itr.second[i]] = fs_status->at(i); - } - } else if (!fs_result) { - // Return early - return false; - } - } - - if (status) { - for (const auto& file : files) { - status->push_back(per_file_status[file]); - } - } - - return result; -} - -absl::Status Env::GetChildren(const string& dir, std::vector* result) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(dir, &fs)); - return fs->GetChildren(dir, result); -} - -absl::Status Env::GetMatchingPaths(const string& pattern, - std::vector* results) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(pattern, &fs)); - return fs->GetMatchingPaths(pattern, results); -} - -absl::Status Env::DeleteFile(const string& fname) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs)); - return fs->DeleteFile(fname); -} - -absl::Status Env::RecursivelyCreateDir(const string& dirname) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(dirname, &fs)); - return fs->RecursivelyCreateDir(dirname); -} - -absl::Status Env::CreateDir(const string& dirname) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(dirname, &fs)); - return fs->CreateDir(dirname); -} - -absl::Status Env::DeleteDir(const string& dirname) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(dirname, &fs)); - return fs->DeleteDir(dirname); -} - -absl::Status Env::Stat(const string& fname, FileStatistics* stat) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs)); - return fs->Stat(fname, stat); -} - -absl::Status Env::IsDirectory(const string& fname) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs)); - return fs->IsDirectory(fname); -} - -absl::Status Env::HasAtomicMove(const string& path, bool* has_atomic_move) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(path, &fs)); - return fs->HasAtomicMove(path, has_atomic_move); -} - -absl::Status Env::CanCreateTempFile(const string& fname, - bool* can_create_temp_file) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs)); - return fs->CanCreateTempFile(fname, can_create_temp_file); -} - -absl::Status Env::DeleteRecursively(const string& dirname, - int64_t* undeleted_files, - int64_t* undeleted_dirs) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(dirname, &fs)); - return fs->DeleteRecursively(dirname, undeleted_files, undeleted_dirs); -} - -absl::Status Env::GetFileSize(const string& fname, uint64* file_size) { - FileSystem* fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(fname, &fs)); - return fs->GetFileSize(fname, file_size); -} - -absl::Status Env::RenameFile(const string& src, const string& target) { - FileSystem* src_fs; - FileSystem* target_fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(src, &src_fs)); - TF_RETURN_IF_ERROR(GetFileSystemForFile(target, &target_fs)); - if (src_fs != target_fs) { - return errors::Unimplemented("Renaming ", src, " to ", target, - " not implemented"); - } - return src_fs->RenameFile(src, target); -} - -absl::Status Env::CopyFile(const string& src, const string& target) { - FileSystem* src_fs; - FileSystem* target_fs; - TF_RETURN_IF_ERROR(GetFileSystemForFile(src, &src_fs)); - TF_RETURN_IF_ERROR(GetFileSystemForFile(target, &target_fs)); - if (src_fs == target_fs) { - return src_fs->CopyFile(src, target); - } - return FileSystemCopyFile(src_fs, src, target_fs, target); -} - -string Env::GetExecutablePath() { - char exe_path[PATH_MAX] = {0}; -#ifdef __APPLE__ - uint32_t buffer_size(0U); - _NSGetExecutablePath(nullptr, &buffer_size); - std::vector unresolved_path(buffer_size); - _NSGetExecutablePath(unresolved_path.data(), &buffer_size); - CHECK(realpath(unresolved_path.data(), exe_path)); -#elif defined(__FreeBSD__) - int mib[4] = {CTL_KERN, KERN_PROC, KERN_PROC_PATHNAME, -1}; - size_t exe_path_size = PATH_MAX; - - if (sysctl(mib, 4, exe_path, &exe_path_size, NULL, 0) != 0) { - // Resolution of path failed - return ""; - } -#elif defined(PLATFORM_WINDOWS) - HMODULE hModule = GetModuleHandleW(NULL); - WCHAR wc_file_path[MAX_PATH] = {0}; - GetModuleFileNameW(hModule, wc_file_path, MAX_PATH); - string file_path = WideCharToUtf8(wc_file_path); - std::copy(file_path.begin(), file_path.end(), exe_path); -#else - char buf[PATH_MAX] = {0}; - int path_length = readlink("/proc/self/exe", buf, sizeof(buf) - 1); - CHECK_NE(-1, path_length); - - if (strstr(buf, "python") != nullptr) { - // Discard the path of the python binary, and any flags. - int fd = open("/proc/self/cmdline", O_RDONLY); - CHECK_NE(-1, fd); - int cmd_length = read(fd, buf, PATH_MAX - 1); - CHECK_NE(-1, cmd_length); - close(fd); - int token_pos = 0; - for (bool token_is_first_or_flag = true; token_is_first_or_flag;) { - // Get token length, including null - int token_len = strlen(&buf[token_pos]) + 1; - token_is_first_or_flag = false; - // Check if we can skip without overshooting - if (token_pos + token_len < cmd_length) { - token_pos += token_len; - token_is_first_or_flag = (buf[token_pos] == '-'); // token is a flag - } - } - snprintf(exe_path, sizeof(exe_path), "%s", &buf[token_pos]); - } else { - snprintf(exe_path, sizeof(exe_path), "%s", buf); - } - -#endif - // Make sure it's null-terminated: - exe_path[sizeof(exe_path) - 1] = 0; - - return exe_path; -} - -bool Env::LocalTempFilename(string* filename) { - std::vector dirs; - GetLocalTempDirectories(&dirs); - - // Try each directory, as they might be full, have inappropriate - // permissions or have different problems at times. - for (const string& dir : dirs) { - *filename = io::JoinPath(dir, "tempfile-"); - if (CreateUniqueFileName(filename, "")) { - return true; - } - } - return false; -} - -bool Env::CreateUniqueFileName(string* prefix, const string& suffix) { - int64_t tid = GetCurrentThreadId(); - int32_t pid = GetProcessId(); - long long now_microsec = NowMicros(); // NOLINT - - absl::StrAppendFormat(prefix, "%s-%x-%d-%llx", port::Hostname(), tid, pid, - now_microsec); - - if (!suffix.empty()) { - *prefix += suffix; - } - if (FileExists(*prefix).ok()) { - prefix->clear(); - return false; - } else { - return true; - } -} - -int32 Env::GetProcessId() { -#ifdef PLATFORM_WINDOWS - return static_cast(GetCurrentProcessId()); -#else - return static_cast(getpid()); -#endif -} - -Thread::~Thread() {} - -EnvWrapper::~EnvWrapper() {} - -absl::Status ReadFileToString(Env* env, const string& fname, string* data) { - uint64 file_size; - absl::Status s = env->GetFileSize(fname, &file_size); - if (!s.ok()) { - return s; - } - std::unique_ptr file; - s = env->NewRandomAccessFile(fname, &file); - if (!s.ok()) { - return s; - } - data->resize(file_size); - char* p = &*data->begin(); - absl::string_view result; - s = file->Read(0, file_size, &result, p); - if (!s.ok()) { - data->clear(); - } else if (result.size() != file_size) { - s = errors::Aborted("File ", fname, " changed while reading: ", file_size, - " vs. ", result.size()); - data->clear(); - } else if (result.data() == p) { - // Data is already in the correct location - } else { - memmove(p, result.data(), result.size()); - } - return s; -} - -absl::Status WriteStringToFile(Env* env, const string& fname, - const absl::string_view& data) { - std::unique_ptr file; - absl::Status s = env->NewWritableFile(fname, &file); - if (!s.ok()) { - return s; - } - s = file->Append(data); - if (s.ok()) { - s = file->Close(); - } - return s; -} - -absl::Status FileSystemCopyFile(FileSystem* src_fs, const string& src, - FileSystem* target_fs, const string& target) { - std::unique_ptr src_file; - TF_RETURN_IF_ERROR(src_fs->NewRandomAccessFile(src, &src_file)); - - // When `target` points to a directory, we need to create a file within. - string target_name; - if (target_fs->IsDirectory(target).ok()) { - target_name = io::JoinPath(target, io::Basename(src)); - } else { - target_name = target; - } - - std::unique_ptr target_file; - TF_RETURN_IF_ERROR(target_fs->NewWritableFile(target_name, &target_file)); - - uint64 offset = 0; - std::unique_ptr scratch(new char[kCopyFileBufferSize]); - absl::Status s = absl::OkStatus(); - while (s.ok()) { - absl::string_view result; - s = src_file->Read(offset, kCopyFileBufferSize, &result, scratch.get()); - if (!(s.ok() || s.code() == error::OUT_OF_RANGE)) { - return s; - } - TF_RETURN_IF_ERROR(target_file->Append(result)); - offset += result.size(); - } - return target_file->Close(); -} - -// A ZeroCopyInputStream on a RandomAccessFile. -namespace { -class FileStream : public protobuf::io::ZeroCopyInputStream { - public: - explicit FileStream(RandomAccessFile* file) : file_(file), pos_(0) {} - - void BackUp(int count) override { pos_ -= count; } - bool Skip(int count) override { - pos_ += count; - return true; - } - int64_t ByteCount() const override { return pos_; } - absl::Status status() const { return status_; } - - bool Next(const void** data, int* size) override { - absl::string_view result; - absl::Status s = file_->Read(pos_, kBufSize, &result, scratch_); - if (result.empty()) { - status_ = s; - return false; - } - pos_ += result.size(); - *data = result.data(); - *size = result.size(); - return true; - } - - private: - static constexpr int kBufSize = 512 << 10; - - RandomAccessFile* file_; - int64_t pos_; - absl::Status status_; - char scratch_[kBufSize]; -}; - -} // namespace - -absl::Status WriteBinaryProto(Env* env, const string& fname, - const protobuf::MessageLite& proto) { - string serialized; - proto.AppendToString(&serialized); - return WriteStringToFile(env, fname, serialized); -} - -absl::Status ReadBinaryProto(Env* env, const string& fname, - protobuf::MessageLite* proto) { - std::unique_ptr file; - TF_RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file)); - std::unique_ptr stream(new FileStream(file.get())); - protobuf::io::CodedInputStream coded_stream(stream.get()); - - if (!proto->ParseFromCodedStream(&coded_stream) || - !coded_stream.ConsumedEntireMessage()) { - TF_RETURN_IF_ERROR(stream->status()); - return errors::DataLoss("Can't parse ", fname, " as binary proto"); - } - return absl::OkStatus(); -} - -absl::Status WriteTextProto(Env* env, const string& fname, - const protobuf::Message& proto) { - string serialized; - if (!protobuf::TextFormat::PrintToString(proto, &serialized)) { - return errors::FailedPrecondition("Unable to convert proto to text."); - } - return WriteStringToFile(env, fname, serialized); -} - -absl::Status ReadTextProto(Env* env, const string& fname, - protobuf::Message* proto) { - std::unique_ptr file; - TF_RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file)); - std::unique_ptr stream(new FileStream(file.get())); - - if (!protobuf::TextFormat::Parse(stream.get(), proto)) { - TF_RETURN_IF_ERROR(stream->status()); - return errors::DataLoss("Can't parse ", fname, " as text proto"); - } - return absl::OkStatus(); -} - -absl::Status ReadTextOrBinaryProto(Env* env, const string& fname, - protobuf::Message* proto) { - if (ReadTextProto(env, fname, proto).ok()) { - return absl::OkStatus(); - } - return ReadBinaryProto(env, fname, proto); -} - -absl::Status ReadTextOrBinaryProto(Env* env, const string& fname, - protobuf::MessageLite* proto) { - return ReadBinaryProto(env, fname, proto); -} - -} // namespace tsl diff --git a/tsl/platform/env.h b/tsl/platform/env.h index 874a80ac3..806cbb1c9 100644 --- a/tsl/platform/env.h +++ b/tsl/platform/env.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,722 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_ENV_H_ #define TENSORFLOW_TSL_PLATFORM_ENV_H_ -#include - -#include -#include -#include -#include -#include - -#include "absl/functional/any_invocable.h" -#include "tsl/platform/env_time.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/file_system.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/mutex.h" -#include "tsl/platform/numa.h" -#include "tsl/platform/platform.h" -#include "tsl/platform/protobuf.h" -#include "tsl/platform/status.h" -#include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" - -// Delete leaked Windows definitions. -#ifdef PLATFORM_WINDOWS -#undef CopyFile -#undef DeleteFile -#endif - -namespace tsl { - -class Thread; -struct ThreadOptions; - -/// \brief An interface used by the tensorflow implementation to -/// access operating system functionality like the filesystem etc. -/// -/// Callers may wish to provide a custom Env object to get fine grain -/// control. -/// -/// All Env implementations of file-system modifying functionality are safe -/// for concurrent access from multiple threads without any external -/// synchronization, *however*, Envs and their underlying file systems are -/// global objects, and therefore, if any thread modifies options, the modified -/// options take effect process-wide. The SetOption functions themselves are -/// also *not* thread safe. -class Env { - public: - Env(); - virtual ~Env() = default; - - /// \brief Returns a default environment suitable for the current operating - /// system. - /// - /// Sophisticated users may wish to provide their own Env - /// implementation instead of relying on this default environment. - /// - /// The result of Default() belongs to this library and must never be deleted. - static Env* Default(); - - /// \brief Returns the FileSystem object to handle operations on the file - /// specified by 'fname'. The FileSystem object is used as the implementation - /// for the file system related (non-virtual) functions that follow. - /// Returned FileSystem object is still owned by the Env object and will - // (might) be destroyed when the environment is destroyed. - virtual absl::Status GetFileSystemForFile(const std::string& fname, - FileSystem** result); - - /// \brief Returns the file system schemes registered for this Env. - virtual absl::Status GetRegisteredFileSystemSchemes( - std::vector* schemes); - - /// \brief Register a file system for a scheme. - virtual absl::Status RegisterFileSystem(const std::string& scheme, - FileSystemRegistry::Factory factory); - - /// \brief Register a modular file system for a scheme. - /// - /// Same as `RegisterFileSystem` but for filesystems provided by plugins. - /// - /// TODO(b/139060984): After all filesystems are converted, make this be the - /// canonical registration function. - virtual absl::Status RegisterFileSystem( - const std::string& scheme, std::unique_ptr filesystem); - - absl::Status SetOption(const std::string& scheme, const std::string& key, - const std::string& value); - - absl::Status SetOption(const std::string& scheme, const std::string& key, - const std::vector& values); - - absl::Status SetOption(const std::string& scheme, const std::string& key, - const std::vector& values); - - absl::Status SetOption(const std::string& scheme, const std::string& key, - const std::vector& values); - - /// \brief Flush filesystem caches for all registered filesystems. - absl::Status FlushFileSystemCaches(); - - /// \brief Creates a brand new random access read-only file with the - /// specified name. - - /// On success, stores a pointer to the new file in - /// *result and returns OK. On failure stores NULL in *result and - /// returns non-OK. If the file does not exist, returns a non-OK - /// status. - /// - /// The returned file may be concurrently accessed by multiple threads. - /// - /// The ownership of the returned RandomAccessFile is passed to the caller - /// and the object should be deleted when is not used. The file object - /// shouldn't live longer than the Env object. - absl::Status NewRandomAccessFile(const std::string& fname, - std::unique_ptr* result); - - absl::Status NewRandomAccessFile(const std::string& fname, - TransactionToken* token, - std::unique_ptr* result) { - // We duplicate these methods due to Google internal coding style prevents - // virtual functions with default arguments. See PR #41615. - return absl::OkStatus(); - } - - /// \brief Creates an object that writes to a new file with the specified - /// name. - /// - /// Deletes any existing file with the same name and creates a - /// new file. On success, stores a pointer to the new file in - /// *result and returns OK. On failure stores NULL in *result and - /// returns non-OK. - /// - /// The returned file will only be accessed by one thread at a time. - /// - /// The ownership of the returned WritableFile is passed to the caller - /// and the object should be deleted when is not used. The file object - /// shouldn't live longer than the Env object. - absl::Status NewWritableFile(const std::string& fname, - std::unique_ptr* result); - - absl::Status NewWritableFile(const std::string& fname, - TransactionToken* token, - std::unique_ptr* result) { - return absl::OkStatus(); - } - - /// \brief Creates an object that either appends to an existing file, or - /// writes to a new file (if the file does not exist to begin with). - /// - /// On success, stores a pointer to the new file in *result and - /// returns OK. On failure stores NULL in *result and returns - /// non-OK. - /// - /// The returned file will only be accessed by one thread at a time. - /// - /// The ownership of the returned WritableFile is passed to the caller - /// and the object should be deleted when is not used. The file object - /// shouldn't live longer than the Env object. - absl::Status NewAppendableFile(const std::string& fname, - std::unique_ptr* result); - - absl::Status NewAppendableFile(const std::string& fname, - TransactionToken* token, - std::unique_ptr* result) { - return absl::OkStatus(); - } - /// \brief Creates a readonly region of memory with the file context. - /// - /// On success, it returns a pointer to read-only memory region - /// from the content of file fname. The ownership of the region is passed to - /// the caller. On failure stores nullptr in *result and returns non-OK. - /// - /// The returned memory region can be accessed from many threads in parallel. - /// - /// The ownership of the returned ReadOnlyMemoryRegion is passed to the caller - /// and the object should be deleted when is not used. The memory region - /// object shouldn't live longer than the Env object. - absl::Status NewReadOnlyMemoryRegionFromFile( - const std::string& fname, std::unique_ptr* result); - - absl::Status NewReadOnlyMemoryRegionFromFile( - const std::string& fname, TransactionToken* token, - std::unique_ptr* result) { - return absl::OkStatus(); - } - - /// Returns OK if the named path exists and NOT_FOUND otherwise. - absl::Status FileExists(const std::string& fname); - - absl::Status FileExists(const std::string& fname, TransactionToken* token) { - return absl::OkStatus(); - } - - /// Returns true if all the listed files exist, false otherwise. - /// if status is not null, populate the vector with a detailed status - /// for each file. - bool FilesExist(const std::vector& files, - std::vector* status); - - bool FilesExist(const std::vector& files, TransactionToken* token, - std::vector* status) { - return true; - } - - /// \brief Stores in *result the names of the children of the specified - /// directory. The names are relative to "dir". - /// - /// Original contents of *results are dropped. - absl::Status GetChildren(const std::string& dir, std::vector* result); - - absl::Status GetChildren(const std::string& dir, TransactionToken* token, - std::vector* result) { - return absl::OkStatus(); - } - - /// \brief Returns true if the path matches the given pattern. The wildcards - /// allowed in pattern are described in FileSystem::GetMatchingPaths. - virtual bool MatchPath(const std::string& path, - const std::string& pattern) = 0; - - /// \brief Given a pattern, stores in *results the set of paths that matches - /// that pattern. *results is cleared. - /// - /// More details about `pattern` in FileSystem::GetMatchingPaths. - virtual absl::Status GetMatchingPaths(const std::string& pattern, - std::vector* results); - - absl::Status GetMatchingPaths(const std::string& pattern, - TransactionToken* token, - std::vector* results) { - return absl::OkStatus(); - } - - /// Deletes the named file. - absl::Status DeleteFile(const std::string& fname); - - absl::Status DeleteFile(const std::string& fname, TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Deletes the specified directory and all subdirectories and files - /// underneath it. This is accomplished by traversing the directory tree - /// rooted at dirname and deleting entries as they are encountered. - /// - /// If dirname itself is not readable or does not exist, *undeleted_dir_count - /// is set to 1, *undeleted_file_count is set to 0 and an appropriate status - /// (e.g. NOT_FOUND) is returned. - /// - /// If dirname and all its descendants were successfully deleted, TF_OK is - /// returned and both error counters are set to zero. - /// - /// Otherwise, while traversing the tree, undeleted_file_count and - /// undeleted_dir_count are updated if an entry of the corresponding type - /// could not be deleted. The returned error status represents the reason that - /// any one of these entries could not be deleted. - /// - /// REQUIRES: undeleted_files, undeleted_dirs to be not null. - /// - /// Typical return codes: - /// * OK - dirname exists and we were able to delete everything underneath. - /// * NOT_FOUND - dirname doesn't exist - /// * PERMISSION_DENIED - dirname or some descendant is not writable - /// * UNIMPLEMENTED - Some underlying functions (like Delete) are not - /// implemented - absl::Status DeleteRecursively(const std::string& dirname, - int64_t* undeleted_files, - int64_t* undeleted_dirs); - - absl::Status DeleteRecursively(const std::string& dirname, - TransactionToken* token, - int64_t* undeleted_files, - int64_t* undeleted_dirs) { - return absl::OkStatus(); - } - - /// \brief Creates the specified directory and all the necessary - /// subdirectories. Typical return codes. - /// * OK - successfully created the directory and sub directories, even if - /// they were already created. - /// * PERMISSION_DENIED - dirname or some subdirectory is not writable. - absl::Status RecursivelyCreateDir(const std::string& dirname); - - absl::Status RecursivelyCreateDir(const std::string& dirname, - TransactionToken* token) { - return absl::OkStatus(); - } - /// \brief Creates the specified directory. Typical return codes - /// * OK - successfully created the directory. - /// * ALREADY_EXISTS - directory already exists. - /// * PERMISSION_DENIED - dirname is not writable. - absl::Status CreateDir(const std::string& dirname); - - absl::Status CreateDir(const std::string& dirname, TransactionToken* token) { - return absl::OkStatus(); - } - - /// Deletes the specified directory. - absl::Status DeleteDir(const std::string& dirname); - - absl::Status DeleteDir(const std::string& dirname, TransactionToken* token) { - return absl::OkStatus(); - } - - /// Obtains statistics for the given path. - absl::Status Stat(const std::string& fname, FileStatistics* stat); - - absl::Status Stat(const std::string& fname, TransactionToken* token, - FileStatistics* stat) { - return absl::OkStatus(); - } - - /// \brief Returns whether the given path is a directory or not. - /// Typical return codes (not guaranteed exhaustive): - /// * OK - The path exists and is a directory. - /// * FAILED_PRECONDITION - The path exists and is not a directory. - /// * NOT_FOUND - The path entry does not exist. - /// * PERMISSION_DENIED - Insufficient permissions. - /// * UNIMPLEMENTED - The file factory doesn't support directories. - absl::Status IsDirectory(const std::string& fname); - - /// \brief Returns whether the given path is on a file system - /// that has atomic move capabilities. This can be used - /// to determine if there needs to be a temp location to safely write objects. - /// The second boolean argument has_atomic_move contains this information. - /// - /// Returns one of the following status codes (not guaranteed exhaustive): - /// * OK - The path is on a recognized file system, - /// so has_atomic_move holds the above information. - /// * UNIMPLEMENTED - The file system of the path hasn't been implemented in - /// TF - absl::Status HasAtomicMove(const std::string& path, bool* has_atomic_move); - - /// Returns whether the give path is on a file system - /// that has ability to create a new temp file. This can be used - /// to determine if there needs to be a temp location to safely write objects. - /// If this returns false, TensorFlow will write directly to output files - /// instead of creating a temporary file and swapping it in. This may mean - /// that incomplete writes are visible to consumers. - absl::Status CanCreateTempFile(const std::string& fname, - bool* can_create_temp_file); - - /// Stores the size of `fname` in `*file_size`. - absl::Status GetFileSize(const std::string& fname, uint64* file_size); - - absl::Status GetFileSize(const std::string& fname, TransactionToken* token, - uint64* file_size) { - return absl::OkStatus(); - } - - /// \brief Renames file src to target. If target already exists, it will be - /// replaced. - absl::Status RenameFile(const std::string& src, const std::string& target); - - absl::Status RenameFile(const std::string& src, const std::string& target, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Copy the src to target. - absl::Status CopyFile(const std::string& src, const std::string& target); - - absl::Status CopyFile(const std::string& src, const std::string& target, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief starts a new transaction on the filesystem that handles filename - absl::Status StartTransaction(const std::string& filename, - TransactionToken** token) { - *token = nullptr; - return absl::OkStatus(); - } - - /// \brief Adds `path` to transaction in `token` if token belongs to - /// filesystem that handles the path. - absl::Status AddToTransaction(const std::string& path, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Get token for `path` or start a new transaction and add `path` to - /// it. - absl::Status GetTokenOrStartTransaction(const std::string& path, - TransactionToken** token) { - *token = nullptr; - return absl::OkStatus(); - } - - /// \brief Returns the transaction for `path` or nullptr in `token` - absl::Status GetTransactionForPath(const std::string& path, - TransactionToken** token) { - *token = nullptr; - return absl::OkStatus(); - } - - /// \brief Finalizes the transaction - absl::Status EndTransaction(TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Returns the absolute path of the current executable. It resolves - /// symlinks if there is any. - std::string GetExecutablePath(); - - /// Creates a local unique temporary file name. Returns true if success. - bool LocalTempFilename(std::string* filename); - - /// Creates a local unique file name that starts with |prefix| and ends with - /// |suffix|. Returns true if success. - bool CreateUniqueFileName(std::string* prefix, const std::string& suffix); - - /// \brief Return the runfiles directory if running under bazel. Returns - /// the directory the executable is located in if not running under bazel. - virtual std::string GetRunfilesDir() = 0; - - // TODO(jeff,sanjay): Add back thread/thread-pool support if needed. - // TODO(jeff,sanjay): if needed, tighten spec so relative to epoch, or - // provide a routine to get the absolute time. - - /// \brief Returns the number of nano-seconds since the Unix epoch. - virtual uint64 NowNanos() const { return EnvTime::NowNanos(); } - - /// \brief Returns the number of micro-seconds since the Unix epoch. - virtual uint64 NowMicros() const { return EnvTime::NowMicros(); } - - /// \brief Returns the number of seconds since the Unix epoch. - virtual uint64 NowSeconds() const { return EnvTime::NowSeconds(); } - - /// Sleeps/delays the thread for the prescribed number of micro-seconds. - virtual void SleepForMicroseconds(int64_t micros) = 0; - - /// Returns the process ID of the calling process. - int32 GetProcessId(); - - /// \brief Returns a new thread that is running fn() and is identified - /// (for debugging/performance-analysis) by "name". - /// - /// Caller takes ownership of the result and must delete it eventually - /// (the deletion will block until fn() stops running). - virtual Thread* StartThread( - const ThreadOptions& thread_options, const std::string& name, - absl::AnyInvocable fn) TF_MUST_USE_RESULT = 0; - - // Returns the thread id of calling thread. - // Posix: Returns pthread id which is only guaranteed to be unique within a - // process. - // Windows: Returns thread id which is unique. - virtual int64_t GetCurrentThreadId() = 0; - - // Copies current thread name to "name". Returns true if success. - virtual bool GetCurrentThreadName(std::string* name) = 0; - - // \brief Schedules the given closure on a thread-pool. - // - // NOTE(mrry): This closure may block. - virtual void SchedClosure(absl::AnyInvocable closure) = 0; - - // \brief Schedules the given closure on a thread-pool after the given number - // of microseconds. - // - // NOTE(mrry): This closure must not block. - virtual void SchedClosureAfter(int64_t micros, - absl::AnyInvocable closure) = 0; - - // \brief Load a dynamic library. - // - // Pass "library_filename" to a platform-specific mechanism for dynamically - // loading a library. The rules for determining the exact location of the - // library are platform-specific and are not documented here. - // - // On success, returns a handle to the library in "*handle" and returns - // OK from the function. - // Otherwise returns nullptr in "*handle" and an error status from the - // function. - virtual absl::Status LoadDynamicLibrary(const char* library_filename, - void** handle) = 0; - - // \brief Get a pointer to a symbol from a dynamic library. - // - // "handle" should be a pointer returned from a previous call to LoadLibrary. - // On success, store a pointer to the located symbol in "*symbol" and return - // OK from the function. Otherwise, returns nullptr in "*symbol" and an error - // status from the function. - virtual absl::Status GetSymbolFromLibrary(void* handle, - const char* symbol_name, - void** symbol) = 0; - - // \brief build the name of dynamic library. - // - // "name" should be name of the library. - // "version" should be the version of the library or NULL - // returns the name that LoadLibrary() can use - virtual std::string FormatLibraryFileName(const std::string& name, - const std::string& version) = 0; - - // Returns a possible list of local temporary directories. - virtual void GetLocalTempDirectories(std::vector* list) = 0; - - private: - std::unique_ptr file_system_registry_; - Env(const Env&) = delete; - void operator=(const Env&) = delete; -}; - -/// \brief An implementation of Env that forwards all calls to another Env. -/// -/// May be useful to clients who wish to override just part of the -/// functionality of another Env. -class EnvWrapper : public Env { - public: - /// Initializes an EnvWrapper that delegates all calls to *t - explicit EnvWrapper(Env* t) : target_(t) {} - ~EnvWrapper() override; - - /// Returns the target to which this Env forwards all calls - Env* target() const { return target_; } - - absl::Status GetFileSystemForFile(const std::string& fname, - FileSystem** result) override { - return target_->GetFileSystemForFile(fname, result); - } - - absl::Status GetRegisteredFileSystemSchemes( - std::vector* schemes) override { - return target_->GetRegisteredFileSystemSchemes(schemes); - } - - absl::Status RegisterFileSystem( - const std::string& scheme, FileSystemRegistry::Factory factory) override { - return target_->RegisterFileSystem(scheme, factory); - } - - bool MatchPath(const std::string& path, const std::string& pattern) override { - return target_->MatchPath(path, pattern); - } - - uint64 NowMicros() const override { return target_->NowMicros(); } - void SleepForMicroseconds(int64_t micros) override { - target_->SleepForMicroseconds(micros); - } - Thread* StartThread(const ThreadOptions& thread_options, - const std::string& name, - absl::AnyInvocable fn) override { - return target_->StartThread(thread_options, name, std::move(fn)); - } - int64_t GetCurrentThreadId() override { - return target_->GetCurrentThreadId(); - } - bool GetCurrentThreadName(std::string* name) override { - return target_->GetCurrentThreadName(name); - } - void SchedClosure(absl::AnyInvocable closure) override { - target_->SchedClosure(std::move(closure)); - } - void SchedClosureAfter(int64_t micros, - absl::AnyInvocable closure) override { - target_->SchedClosureAfter(micros, std::move(closure)); - } - absl::Status LoadDynamicLibrary(const char* library_filename, - void** handle) override { - return target_->LoadDynamicLibrary(library_filename, handle); - } - absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, - void** symbol) override { - return target_->GetSymbolFromLibrary(handle, symbol_name, symbol); - } - std::string FormatLibraryFileName(const std::string& name, - const std::string& version) override { - return target_->FormatLibraryFileName(name, version); - } - - std::string GetRunfilesDir() override { return target_->GetRunfilesDir(); } - - private: - void GetLocalTempDirectories(std::vector* list) override { - target_->GetLocalTempDirectories(list); - } - - Env* target_; -}; - -/// Represents a thread used to run a TSL function. -class Thread { - public: - Thread() {} - - /// Blocks until the thread of control stops running. - virtual ~Thread(); - - private: - Thread(const Thread&) = delete; - void operator=(const Thread&) = delete; -}; - -/// \brief Cross-platform setenv. -/// -/// Since setenv() is not available on windows, we provide an -/// alternative with platform specific implementations here. -int setenv(const char* name, const char* value, int overwrite); - -/// Cross-platform unsetenv. -int unsetenv(const char* name); - -/// \brief Options to configure a Thread. -/// -/// Note that the options are all hints, and the -/// underlying implementation may choose to ignore it. -struct ThreadOptions { - /// Thread stack size to use (in bytes). - size_t stack_size = 0; // 0: use system default value - /// Guard area size to use near thread stacks to use (in bytes) - size_t guard_size = 0; // 0: use system default value - int numa_node = port::kNUMANoAffinity; -}; - -/// A utility routine: copy contents of `src` in file system `src_fs` -/// to `target` in file system `target_fs`. -absl::Status FileSystemCopyFile(FileSystem* src_fs, const std::string& src, - FileSystem* target_fs, - const std::string& target); - -/// A utility routine: reads contents of named file into `*data` -absl::Status ReadFileToString(Env* env, const std::string& fname, - std::string* data); - -/// A utility routine: write contents of `data` to file named `fname` -/// (overwriting existing contents, if any). -absl::Status WriteStringToFile(Env* env, const std::string& fname, - const absl::string_view& data); - -/// Write binary representation of "proto" to the named file. -absl::Status WriteBinaryProto(Env* env, const std::string& fname, - const protobuf::MessageLite& proto); - -/// Reads contents of named file and parse as binary encoded proto data -/// and store into `*proto`. -absl::Status ReadBinaryProto(Env* env, const std::string& fname, - protobuf::MessageLite* proto); - -/// Write the text representation of "proto" to the named file. -inline absl::Status WriteTextProto(Env* /* env */, - const std::string& /* fname */, - const protobuf::MessageLite& /* proto */) { - return errors::Unimplemented("Can't write text protos with protolite."); -} -absl::Status WriteTextProto(Env* env, const std::string& fname, - const protobuf::Message& proto); - -/// Read contents of named file and parse as text encoded proto data -/// and store into `*proto`. -inline absl::Status ReadTextProto(Env* /* env */, - const std::string& /* fname */, - protobuf::MessageLite* /* proto */) { - return errors::Unimplemented("Can't parse text protos with protolite."); -} -absl::Status ReadTextProto(Env* env, const std::string& fname, - protobuf::Message* proto); - -/// Read contents of named file and parse as either text or binary encoded proto -/// data and store into `*proto`. -absl::Status ReadTextOrBinaryProto(Env* env, const std::string& fname, - protobuf::Message* proto); -absl::Status ReadTextOrBinaryProto(Env* env, const std::string& fname, - protobuf::MessageLite* proto); - -// START_SKIP_DOXYGEN - -// The following approach to register filesystems is deprecated and will be -// replaced with modular filesystem plugins registration. -// TODO(b/139060984): After all filesystems are converted, remove this. -namespace register_file_system { - -template -struct Register { - Register(Env* env, const std::string& scheme, bool try_modular_filesystems) { - // TODO(yongtang): Remove legacy file system registration for hdfs/s3/gcs - // after TF 2.6+. - if (try_modular_filesystems) { - const char* env_value = getenv("TF_USE_MODULAR_FILESYSTEM"); - string load_plugin = env_value ? absl::AsciiStrToLower(env_value) : ""; - if (load_plugin == "true" || load_plugin == "1") { - // We don't register the static filesystem and wait for SIG IO one - LOG(WARNING) << "Using modular file system for '" << scheme << "'." - << " Please switch to tensorflow-io" - << " (https://github.com/tensorflow/io) for file system" - << " support of '" << scheme << "'."; - return; - } - // If the envvar is missing or not "true"/"1", then fall back to legacy - // implementation to be backwards compatible. - } - // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object! - env->RegisterFileSystem(scheme, []() -> FileSystem* { return new Factory; }) - .IgnoreError(); - } -}; - -} // namespace register_file_system - -// END_SKIP_DOXYGEN - -} // namespace tsl - -// Register a FileSystem implementation for a scheme. Files with names that have -// "scheme://" prefixes are routed to use this implementation. -#define REGISTER_FILE_SYSTEM_ENV(env, scheme, factory, modular) \ - REGISTER_FILE_SYSTEM_UNIQ_HELPER(__COUNTER__, env, scheme, factory, modular) -#define REGISTER_FILE_SYSTEM_UNIQ_HELPER(ctr, env, scheme, factory, modular) \ - REGISTER_FILE_SYSTEM_UNIQ(ctr, env, scheme, factory, modular) -#define REGISTER_FILE_SYSTEM_UNIQ(ctr, env, scheme, factory, modular) \ - static ::tsl::register_file_system::Register register_ff##ctr \ - TF_ATTRIBUTE_UNUSED = \ - ::tsl::register_file_system::Register(env, scheme, modular) - -#define REGISTER_FILE_SYSTEM(scheme, factory) \ - REGISTER_FILE_SYSTEM_ENV(::tsl::Env::Default(), scheme, factory, false); - -#define REGISTER_LEGACY_FILE_SYSTEM(scheme, factory) \ - REGISTER_FILE_SYSTEM_ENV(::tsl::Env::Default(), scheme, factory, true); +#include "xla/tsl/platform/env.h" #endif // TENSORFLOW_TSL_PLATFORM_ENV_H_ diff --git a/tsl/platform/env_time.h b/tsl/platform/env_time.h index 2ec888069..eaadae805 100644 --- a/tsl/platform/env_time.h +++ b/tsl/platform/env_time.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,54 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #ifndef TENSORFLOW_TSL_PLATFORM_ENV_TIME_H_ #define TENSORFLOW_TSL_PLATFORM_ENV_TIME_H_ -#include - -#include "tsl/platform/types.h" - -namespace tsl { - -/// \brief An interface used by the tsl implementation to -/// access timer related operations. -class EnvTime { - public: - static constexpr uint64 kMicrosToPicos = 1000ULL * 1000ULL; - static constexpr uint64 kMicrosToNanos = 1000ULL; - static constexpr uint64 kMillisToMicros = 1000ULL; - static constexpr uint64 kMillisToNanos = 1000ULL * 1000ULL; - static constexpr uint64 kNanosToPicos = 1000ULL; - static constexpr uint64 kSecondsToMillis = 1000ULL; - static constexpr uint64 kSecondsToMicros = 1000ULL * 1000ULL; - static constexpr uint64 kSecondsToNanos = 1000ULL * 1000ULL * 1000ULL; - - EnvTime() = default; - virtual ~EnvTime() = default; - - /// \brief Returns the number of nano-seconds since the Unix epoch. - static uint64 NowNanos(); - - /// \brief Returns the number of micro-seconds since the Unix epoch. - static uint64 NowMicros() { return NowNanos() / kMicrosToNanos; } - - /// \brief Returns the number of seconds since the Unix epoch. - static uint64 NowSeconds() { return NowNanos() / kSecondsToNanos; } - - /// \brief A version of NowNanos() that may be overridden by a subclass. - virtual uint64 GetOverridableNowNanos() const { return NowNanos(); } - - /// \brief A version of NowMicros() that may be overridden by a subclass. - virtual uint64 GetOverridableNowMicros() const { - return GetOverridableNowNanos() / kMicrosToNanos; - } - - /// \brief A version of NowSeconds() that may be overridden by a subclass. - virtual uint64 GetOverridableNowSeconds() const { - return GetOverridableNowNanos() / kSecondsToNanos; - } -}; - -} // namespace tsl +#include "xla/tsl/platform/env_time.h" #endif // TENSORFLOW_TSL_PLATFORM_ENV_TIME_H_ diff --git a/tsl/platform/errors.cc b/tsl/platform/errors.cc deleted file mode 100644 index 6c732a478..000000000 --- a/tsl/platform/errors.cc +++ /dev/null @@ -1,249 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tsl/platform/errors.h" - -#include -#include - -#include "tsl/platform/status.h" -#include "tsl/platform/strcat.h" - -namespace tsl { -namespace errors { - -namespace { - -absl::StatusCode ErrnoToCode(int err_number) { - absl::StatusCode code; - switch (err_number) { - case 0: - code = absl::StatusCode::kOk; - break; - case EINVAL: // Invalid argument - case ENAMETOOLONG: // Filename too long - case E2BIG: // Argument list too long - case EDESTADDRREQ: // Destination address required - case EDOM: // Mathematics argument out of domain of function - case EFAULT: // Bad address - case EILSEQ: // Illegal byte sequence - case ENOPROTOOPT: // Protocol not available - case ENOSTR: // Not a STREAM - case ENOTSOCK: // Not a socket - case ENOTTY: // Inappropriate I/O control operation - case EPROTOTYPE: // Protocol wrong type for socket - case ESPIPE: // Invalid seek - code = absl::StatusCode::kInvalidArgument; - break; - case ETIMEDOUT: // Connection timed out - case ETIME: // Timer expired - code = absl::StatusCode::kDeadlineExceeded; - break; - case ENODEV: // No such device - case ENOENT: // No such file or directory - case ENXIO: // No such device or address - case ESRCH: // No such process - code = absl::StatusCode::kNotFound; - break; - case EEXIST: // File exists - case EADDRNOTAVAIL: // Address not available - case EALREADY: // Connection already in progress - code = absl::StatusCode::kAlreadyExists; - break; - case EPERM: // Operation not permitted - case EACCES: // Permission denied - case EROFS: // Read only file system - code = absl::StatusCode::kPermissionDenied; - break; - case ENOTEMPTY: // Directory not empty - case EISDIR: // Is a directory - case ENOTDIR: // Not a directory - case EADDRINUSE: // Address already in use - case EBADF: // Invalid file descriptor - case EBUSY: // Device or resource busy - case ECHILD: // No child processes - case EISCONN: // Socket is connected -#if !defined(_WIN32) && !defined(__HAIKU__) - case ENOTBLK: // Block device required -#endif - case ENOTCONN: // The socket is not connected - case EPIPE: // Broken pipe -#if !defined(_WIN32) - case ESHUTDOWN: // Cannot send after transport endpoint shutdown -#endif - case ETXTBSY: // Text file busy - code = absl::StatusCode::kFailedPrecondition; - break; - case ENOSPC: // No space left on device -#if !defined(_WIN32) - case EDQUOT: // Disk quota exceeded -#endif - case EMFILE: // Too many open files - case EMLINK: // Too many links - case ENFILE: // Too many open files in system - case ENOBUFS: // No buffer space available - case ENODATA: // No message is available on the STREAM read queue - case ENOMEM: // Not enough space - case ENOSR: // No STREAM resources -#if !defined(_WIN32) && !defined(__HAIKU__) - case EUSERS: // Too many users -#endif - code = absl::StatusCode::kResourceExhausted; - break; - case EFBIG: // File too large - case EOVERFLOW: // Value too large to be stored in data type - case ERANGE: // Result too large - code = absl::StatusCode::kOutOfRange; - break; - case ENOSYS: // Function not implemented - case ENOTSUP: // Operation not supported - case EAFNOSUPPORT: // Address family not supported -#if !defined(_WIN32) - case EPFNOSUPPORT: // Protocol family not supported -#endif - case EPROTONOSUPPORT: // Protocol not supported -#if !defined(_WIN32) && !defined(__HAIKU__) - case ESOCKTNOSUPPORT: // Socket type not supported -#endif - case EXDEV: // Improper link - code = absl::StatusCode::kUnimplemented; - break; - case EAGAIN: // Resource temporarily unavailable - case ECONNREFUSED: // Connection refused - case ECONNABORTED: // Connection aborted - case ECONNRESET: // Connection reset - case EINTR: // Interrupted function call -#if !defined(_WIN32) - case EHOSTDOWN: // Host is down -#endif - case EHOSTUNREACH: // Host is unreachable - case ENETDOWN: // Network is down - case ENETRESET: // Connection aborted by network - case ENETUNREACH: // Network unreachable - case ENOLCK: // No locks available - case ENOLINK: // Link has been severed -#if !(defined(__APPLE__) || defined(__FreeBSD__) || defined(_WIN32) || \ - defined(__HAIKU__)) - case ENONET: // Machine is not on the network -#endif - code = absl::StatusCode::kUnavailable; - break; - case EDEADLK: // Resource deadlock avoided -#if !defined(_WIN32) - case ESTALE: // Stale file handle -#endif - code = absl::StatusCode::kAborted; - break; - case ECANCELED: // Operation cancelled - code = absl::StatusCode::kCancelled; - break; - // NOTE: If you get any of the following (especially in a - // reproducible way) and can propose a better mapping, - // please email the owners about updating this mapping. - case EBADMSG: // Bad message - case EIDRM: // Identifier removed - case EINPROGRESS: // Operation in progress - case EIO: // I/O error - case ELOOP: // Too many levels of symbolic links - case ENOEXEC: // Exec format error - case ENOMSG: // No message of the desired type - case EPROTO: // Protocol error -#if !defined(_WIN32) && !defined(__HAIKU__) - case EREMOTE: // Object is remote -#endif - code = absl::StatusCode::kUnknown; - break; - default: { - code = absl::StatusCode::kUnknown; - break; - } - } - return code; -} - -} // namespace - -absl::Status IOError(const string& context, int err_number) { - auto code = ErrnoToCode(err_number); - return absl::Status(code, - strings::StrCat(context, "; ", strerror(err_number))); -} - -bool IsAborted(const absl::Status& status) { - return status.code() == tsl::error::Code::ABORTED; -} - -bool IsAlreadyExists(const absl::Status& status) { - return status.code() == tsl::error::Code::ALREADY_EXISTS; -} - -bool IsCancelled(const absl::Status& status) { - return status.code() == tsl::error::Code::CANCELLED; -} - -bool IsDataLoss(const absl::Status& status) { - return status.code() == tsl::error::Code::DATA_LOSS; -} - -bool IsDeadlineExceeded(const absl::Status& status) { - return status.code() == tsl::error::Code::DEADLINE_EXCEEDED; -} - -bool IsFailedPrecondition(const absl::Status& status) { - return status.code() == tsl::error::Code::FAILED_PRECONDITION; -} - -bool IsInternal(const absl::Status& status) { - return status.code() == tsl::error::Code::INTERNAL; -} - -bool IsInvalidArgument(const absl::Status& status) { - return status.code() == tsl::error::Code::INVALID_ARGUMENT; -} - -bool IsNotFound(const absl::Status& status) { - return status.code() == tsl::error::Code::NOT_FOUND; -} - -bool IsOutOfRange(const absl::Status& status) { - return status.code() == tsl::error::Code::OUT_OF_RANGE; -} - -bool IsPermissionDenied(const absl::Status& status) { - return status.code() == tsl::error::Code::PERMISSION_DENIED; -} - -bool IsResourceExhausted(const absl::Status& status) { - return status.code() == tsl::error::Code::RESOURCE_EXHAUSTED; -} - -bool IsUnauthenticated(const absl::Status& status) { - return status.code() == tsl::error::Code::UNAUTHENTICATED; -} - -bool IsUnavailable(const absl::Status& status) { - return status.code() == tsl::error::Code::UNAVAILABLE; -} - -bool IsUnimplemented(const absl::Status& status) { - return status.code() == tsl::error::Code::UNIMPLEMENTED; -} - -bool IsUnknown(const absl::Status& status) { - return status.code() == tsl::error::Code::UNKNOWN; -} - -} // namespace errors -} // namespace tsl diff --git a/tsl/platform/errors.h b/tsl/platform/errors.h index 9be699596..0c28bd418 100644 --- a/tsl/platform/errors.h +++ b/tsl/platform/errors.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,631 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_ERRORS_H_ #define TENSORFLOW_TSL_PLATFORM_ERRORS_H_ -#include -#include -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/status/status.h" -#include "absl/strings/cord.h" -#include "absl/strings/str_join.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/status.h" -#include "tsl/platform/str_util.h" -#include "tsl/platform/strcat.h" - -namespace tsl { -namespace error { -// NOLINTBEGIN(misc-unused-using-decls) -// TODO(aminim): figure out the protobuf migration story. -using tensorflow::error::ABORTED; -using tensorflow::error::ALREADY_EXISTS; -using tensorflow::error::CANCELLED; -using tensorflow::error::Code; -using tensorflow::error::DATA_LOSS; -using tensorflow::error::DEADLINE_EXCEEDED; -using tensorflow::error::FAILED_PRECONDITION; -using tensorflow::error::INTERNAL; -using tensorflow::error::INVALID_ARGUMENT; -using tensorflow::error::NOT_FOUND; -using tensorflow::error::OK; -using tensorflow::error::OUT_OF_RANGE; -using tensorflow::error::PERMISSION_DENIED; -using tensorflow::error::RESOURCE_EXHAUSTED; -using tensorflow::error::UNAUTHENTICATED; -using tensorflow::error::UNAVAILABLE; -using tensorflow::error::UNIMPLEMENTED; -using tensorflow::error::UNKNOWN; -// NOLINTEND(misc-unused-using-decls) -} // namespace error - -namespace errors { - -namespace internal { - -// The DECLARE_ERROR macro below only supports types that can be converted -// into StrCat's AlphaNum. For the other types we rely on a slower path -// through std::stringstream. To add support of a new type, it is enough to -// make sure there is an operator<<() for it: -// -// std::ostream& operator<<(std::ostream& os, const MyType& foo) { -// os << foo.ToString(); -// return os; -// } -// Eventually absl::strings will have native support for this and we will be -// able to completely remove PrepareForStrCat(). -template -typename std::enable_if::value, - std::string>::type -PrepareForStrCat(const T& t) { - std::stringstream ss; - ss << t; - return ss.str(); -} -inline const strings::AlphaNum& PrepareForStrCat(const strings::AlphaNum& a) { - return a; -} - -} // namespace internal - -// Maps UNIX errors into a Status. -absl::Status IOError(const string& context, int err_number); - -// Returns all payloads from a Status as a key-value map. -inline std::unordered_map GetPayloads( - const absl::Status& status) { - std::unordered_map payloads; - status.ForEachPayload( - [&payloads](absl::string_view key, const absl::Cord& value) { - payloads[std::string(key)] = std::string(value); - }); - return payloads; -} - -// Inserts all given payloads into the given status. Will overwrite existing -// payloads if they exist with the same key. -inline void InsertPayloads( - absl::Status& status, - const std::unordered_map& payloads) { - for (const auto& payload : payloads) { - status.SetPayload(payload.first, absl::Cord(payload.second)); - } -} - -// Copies all payloads from one Status to another. Will overwrite existing -// payloads in the destination if they exist with the same key. -inline void CopyPayloads(const absl::Status& from, absl::Status& to) { - from.ForEachPayload([&to](absl::string_view key, const absl::Cord& value) { - to.SetPayload(key, value); - }); -} - -#if defined(PLATFORM_GOOGLE) -// Creates a new status with the given code, message and payloads. -inline absl::Status Create( - absl::StatusCode code, absl::string_view message, - const std::unordered_map& payloads, - absl::SourceLocation loc = absl::SourceLocation::current()) { - absl::Status status(code, message, loc); - InsertPayloads(status, payloads); - return status; -} -// Returns a new Status, replacing its message with the given. -inline absl::Status CreateWithUpdatedMessage(const absl::Status& status, - absl::string_view message) { - auto locations = status.GetSourceLocations(); - auto initial_loc = - locations.empty() ? absl::SourceLocation::current() : locations[0]; - absl::Status new_status = Create(static_cast(status.code()), - message, GetPayloads(status), initial_loc); - if (locations.size() > 1) { - for (auto loc : locations.subspan(1)) { - new_status.AddSourceLocation(loc); - } - } - return new_status; -} - -#else -inline ::absl::Status Create( - absl::StatusCode code, ::tsl::StringPiece message, - const std::unordered_map& payloads) { - Status status(code, message); - InsertPayloads(status, payloads); - return status; -} -// Returns a new Status, replacing its message with the given. -inline ::tsl::Status CreateWithUpdatedMessage(const ::tsl::Status& status, - ::tsl::StringPiece message) { - return Create(static_cast(status.code()), message, - GetPayloads(status)); -} -#endif - -// Append some context to an error message. Each time we append -// context put it on a new line, since it is possible for there -// to be several layers of additional context. -template -void AppendToMessage(absl::Status* status, Args... args) { - auto new_status = CreateWithUpdatedMessage( - *status, ::tsl::strings::StrCat(status->message(), "\n\t", args...)); - CopyPayloads(*status, new_status); - *status = std::move(new_status); -} - -// For propagating errors when calling a function. -#define TF_RETURN_IF_ERROR(...) \ - do { \ - ::absl::Status _status = (__VA_ARGS__); \ - if (TF_PREDICT_FALSE(!_status.ok())) { \ - MAYBE_ADD_SOURCE_LOCATION(_status) \ - return _status; \ - } \ - } while (0) - -#define TF_RETURN_WITH_CONTEXT_IF_ERROR(expr, ...) \ - do { \ - ::tsl::Status _status = (expr); \ - if (TF_PREDICT_FALSE(!_status.ok())) { \ - ::tsl::errors::AppendToMessage(&_status, __VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -// Convenience functions for generating and using error status. -// Example usage: -// status.Update(errors::InvalidArgument("The ", foo, " isn't right.")); -// if (errors::IsInvalidArgument(status)) { ... } -// switch (status.code()) { case error::INVALID_ARGUMENT: ... } - -// CANCELLED -template -absl::Status Cancelled(Args... args) { - return absl::Status(absl::StatusCode::kCancelled, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status CancelledWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kCancelled, message, payloads); -} - -// InvalidArgument -template -absl::Status InvalidArgument(Args... args) { - return absl::Status(absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} - -#if defined(PLATFORM_GOOGLE) -// Specialized overloads to capture source location for up to three arguments. -template -::absl::Status InvalidArgument( - Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2), - ::tsl::errors::internal::PrepareForStrCat(arg3), - ::tsl::errors::internal::PrepareForStrCat(arg4)), - loc); -} -template -::absl::Status InvalidArgument( - Arg1 arg1, Arg2 arg2, Arg3 arg3, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2), - ::tsl::errors::internal::PrepareForStrCat(arg3)), - loc); -} -template -::absl::Status InvalidArgument( - Arg1 arg1, Arg2 arg2, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2)), - loc); -} -template -::absl::Status InvalidArgument( - Arg1 arg1, absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1)), - loc); -} -template -::absl::Status InvalidArgumentWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return errors::Create(absl::StatusCode::kInvalidArgument, message, payloads, - loc); -} -#else -template -::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2, Arg3 arg3) { - return ::absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2), - ::tsl::errors::internal::PrepareForStrCat(arg3))); -} -template -::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2) { - return ::absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2))); -} -template -::absl::Status InvalidArgument(Arg1 arg1) { - return ::absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1))); -} -template -::absl::Status InvalidArgumentWithPayloads( - const ::tsl::StringPiece& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kInvalidArgument, message, payloads); -} -#endif - -// NotFound -template -absl::Status NotFound(Args... args) { - return absl::Status(absl::StatusCode::kNotFound, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -#if defined(PLATFORM_GOOGLE) -// Specialized overloads to capture source location for up to three arguments. -template -::absl::Status NotFound( - Arg1 arg1, Arg2 arg2, Arg3 arg3, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2), - ::tsl::errors::internal::PrepareForStrCat(arg3)), - loc); -} -template -::absl::Status NotFound( - Arg1 arg1, Arg2 arg2, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2)), - loc); -} -template -::absl::Status NotFound( - Arg1 arg1, absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1)), - loc); -} -template -::absl::Status NotFoundWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return errors::Create(absl::StatusCode::kNotFound, message, payloads, loc); -} -#else -template -::absl::Status NotFound(Arg1 arg1, Arg2 arg2, Arg3 arg3) { - return ::absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2), - ::tsl::errors::internal::PrepareForStrCat(arg3))); -} -template -::absl::Status NotFound(Arg1 arg1, Arg2 arg2) { - return ::absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2))); -} -template -::absl::Status NotFound(Arg1 arg1) { - return ::absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1))); -} -template -::absl::Status NotFoundWithPayloads( - const ::tsl::StringPiece& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kNotFound, message, payloads); -} -#endif - -// AlreadyExists -template -absl::Status AlreadyExists(Args... args) { - return absl::Status(absl::StatusCode::kAlreadyExists, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status AlreadyExistsWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kAlreadyExists, message, payloads); -} - -// ResourceExhausted -template -absl::Status ResourceExhausted(Args... args) { - return absl::Status(absl::StatusCode::kResourceExhausted, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status ResourceExhaustedWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kResourceExhausted, message, - payloads); -} - -// Unavailable -template -absl::Status Unavailable(Args... args) { - return absl::Status(absl::StatusCode::kUnavailable, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status UnavailableWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kUnavailable, message, payloads); -} - -// FailedPrecondition -template -absl::Status FailedPrecondition(Args... args) { - return absl::Status(absl::StatusCode::kFailedPrecondition, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status FailedPreconditionWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kFailedPrecondition, message, - payloads); -} - -// OutOfRange -template -absl::Status OutOfRange(Args... args) { - return absl::Status(absl::StatusCode::kOutOfRange, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status OutOfRangeWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kOutOfRange, message, payloads); -} - -// Unimplemented -template -absl::Status Unimplemented(Args... args) { - return absl::Status(absl::StatusCode::kUnimplemented, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status UnimplementedWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kUnimplemented, message, payloads); -} - -// Internal -template -absl::Status Internal(Args... args) { - return absl::Status(absl::StatusCode::kInternal, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status InternalWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kInternal, message, payloads); -} - -// Aborted -template -absl::Status Aborted(Args... args) { - return absl::Status(absl::StatusCode::kAborted, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status AbortedWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kAborted, message, payloads); -} - -// DeadlineExceeded -template -absl::Status DeadlineExceeded(Args... args) { - return absl::Status(absl::StatusCode::kDeadlineExceeded, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status DeadlineExceededWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kDeadlineExceeded, message, payloads); -} - -// DataLoss -template -absl::Status DataLoss(Args... args) { - return absl::Status(absl::StatusCode::kDataLoss, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status DataLossWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kDataLoss, message, payloads); -} - -// Unknown -template -absl::Status Unknown(Args... args) { - return absl::Status(absl::StatusCode::kUnknown, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status UnknownPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kUnknown, message, payloads); -} -// PermissionDenied -template -absl::Status PermissionDenied(Args... args) { - return absl::Status(absl::StatusCode::kPermissionDenied, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status PermissionDeniedWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kPermissionDenied, message, payloads); -} - -// Unauthenticated -template -absl::Status Unauthenticated(Args... args) { - return absl::Status(absl::StatusCode::kUnauthenticated, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status UnauthenticatedWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kUnauthenticated, message, payloads); -} - -bool IsAborted(const absl::Status& status); -bool IsAlreadyExists(const absl::Status& status); -bool IsCancelled(const absl::Status& status); -bool IsDataLoss(const absl::Status& status); -bool IsDeadlineExceeded(const absl::Status& status); -bool IsFailedPrecondition(const absl::Status& status); -bool IsInternal(const absl::Status& status); -bool IsInvalidArgument(const absl::Status& status); -bool IsNotFound(const absl::Status& status); -bool IsOutOfRange(const absl::Status& status); -bool IsPermissionDenied(const absl::Status& status); -bool IsResourceExhausted(const absl::Status& status); -bool IsUnauthenticated(const absl::Status& status); -bool IsUnavailable(const absl::Status& status); -bool IsUnimplemented(const absl::Status& status); -bool IsUnknown(const absl::Status& status); - -// Produces a formatted string pattern from the name which can uniquely identify -// this node upstream to produce an informative error message. The pattern -// followed is: {{node }} -// Note: The pattern below determines the regex _NODEDEF_NAME_RE in the file -// tensorflow/python/client/session.py -// LINT.IfChange -inline std::string FormatNodeNameForError(absl::string_view name) { - return strings::StrCat("{{node ", name, "}}"); -} -// LINT.ThenChange(//tensorflow/python/client/session.py) -template -std::string FormatNodeNamesForError(const T& names) { - return absl::StrJoin( - names, ", ", [](std::string* output, absl::string_view s) { - ::tsl::strings::StrAppend(output, FormatNodeNameForError(s)); - }); -} -// LINT.IfChange -inline std::string FormatColocationNodeForError(absl::string_view name) { - return strings::StrCat("{{colocation_node ", name, "}}"); -} -// LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py) -template >> -std::string FormatColocationNodeForError(const T& names) { - return absl::StrJoin( - names, ", ", [](std::string* output, absl::string_view s) { - ::tsl::strings::StrAppend(output, FormatColocationNodeForError(s)); - }); -} - -inline std::string FormatFunctionForError(absl::string_view name) { - return strings::StrCat("{{function_node ", name, "}}"); -} - -inline absl::Status ReplaceErrorFromNonCommunicationOps( - const absl::Status s, absl::string_view op_name) { - assert(::tsl::errors::IsUnavailable(s)); - return absl::Status( - absl::StatusCode::kInternal, - strings::StrCat( - s.message(), "\nExecuting non-communication op <", op_name, - "> originally returned UnavailableError, and was replaced by " - "InternalError to avoid invoking TF network error handling logic.")); -} - -template -std::string FormatOriginalNodeLocationForError(const T& node_names, - const T& func_names) { - std::vector error_message; - for (int i = 0; i != node_names.size(); ++i) { - if (i != 0) { - error_message.push_back(", "); - } - if (i < func_names.size()) { - error_message.push_back(FormatFunctionForError(func_names[i])); - } - error_message.push_back(FormatNodeNameForError(node_names[i])); - } - return absl::StrJoin(error_message, ""); -} - -// The CanonicalCode() for non-errors. -using ::tsl::error::OK; // NOLINT - -} // namespace errors -} // namespace tsl +#include "xla/tsl/platform/errors.h" #endif // TENSORFLOW_TSL_PLATFORM_ERRORS_H_ diff --git a/tsl/platform/errors_test.cc b/tsl/platform/errors_test.cc deleted file mode 100644 index 88a3a5a78..000000000 --- a/tsl/platform/errors_test.cc +++ /dev/null @@ -1,123 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tsl/platform/errors.h" - -#include "absl/status/status.h" -#include "tsl/platform/test.h" - -namespace tsl { - -TEST(AppendToMessageTest, PayloadsAreCopied) { - absl::Status status = errors::Aborted("Aborted Error Message"); - status.SetPayload("payload_key", absl::Cord("payload_value")); - errors::AppendToMessage(&status, "Appended Message"); - - EXPECT_EQ(status.message(), "Aborted Error Message\n\tAppended Message"); - EXPECT_EQ(status.GetPayload("payload_key"), absl::Cord("payload_value")); -} - -TEST(Status, GetAllPayloads) { - absl::Status s_error(absl::StatusCode::kInternal, "Error message"); - s_error.SetPayload("Error key", absl::Cord("foo")); - auto payloads_error_status = errors::GetPayloads(s_error); - ASSERT_EQ(payloads_error_status.size(), 1); - ASSERT_EQ(payloads_error_status["Error key"], "foo"); - - absl::Status s_ok = absl::Status(); - auto payloads_ok_status = errors::GetPayloads(s_ok); - ASSERT_TRUE(payloads_ok_status.empty()); -} - -TEST(Status, OKStatusInsertPayloadsFromErrorStatus) { - // An OK status will should not change after InsertPayloads() calls. - absl::Status s_error(absl::StatusCode::kInternal, "Error message"); - s_error.SetPayload("Error key", absl::Cord("foo")); - absl::Status s_ok = absl::Status(); - - errors::InsertPayloads(s_ok, errors::GetPayloads(s_error)); - auto payloads_ok_status = errors::GetPayloads(s_ok); - ASSERT_TRUE(payloads_ok_status.empty()); -} - -TEST(Status, ErrorStatusInsertPayloadsFromOKStatus) { - // An InsertPayloads() call should not take effect from empty inputs. - absl::Status s_error(absl::StatusCode::kInternal, "Error message"); - s_error.SetPayload("Error key", absl::Cord("foo")); - absl::Status s_ok = absl::Status(); - - errors::InsertPayloads(s_error, errors::GetPayloads(s_ok)); - ASSERT_EQ(s_error.GetPayload("Error key"), "foo"); -} - -TEST(Status, ErrorStatusInsertPayloadsFromErrorStatus) { - absl::Status s_error1(absl::StatusCode::kInternal, "Error message"); - s_error1.SetPayload("Error key 1", absl::Cord("foo")); - s_error1.SetPayload("Error key 2", absl::Cord("bar")); - absl::Status s_error2(absl::StatusCode::kInternal, "Error message"); - s_error2.SetPayload("Error key", absl::Cord("bar")); - ASSERT_EQ(s_error2.GetPayload("Error key"), "bar"); - - errors::InsertPayloads(s_error2, errors::GetPayloads(s_error1)); - ASSERT_EQ(s_error2.GetPayload("Error key 1"), "foo"); - ASSERT_EQ(s_error2.GetPayload("Error key 2"), "bar"); - auto payloads_error_status = errors::GetPayloads(s_error2); - ASSERT_EQ(payloads_error_status.size(), 3); -} - -#if defined(PLATFORM_GOOGLE) - -absl::Status GetError() { - return absl::InvalidArgumentError("An invalid argument error"); -} - -absl::Status PropagateError() { - TF_RETURN_IF_ERROR(GetError()); - return absl::OkStatus(); -} - -absl::Status PropagateError2() { - TF_RETURN_IF_ERROR(PropagateError()); - return absl::OkStatus(); -} - -TEST(Status, StackTracePropagation) { - absl::Status s = PropagateError2(); - auto sources = s.GetSourceLocations(); - ASSERT_EQ(sources.size(), 3); - - for (int i = 0; i < 3; ++i) { - ASSERT_EQ(sources[i].file_name(), - "third_party/tensorflow/tsl/platform/errors_test.cc"); - } -} - -TEST(Status, SourceLocationsPreservedByAppend) { - absl::Status s = PropagateError2(); - ASSERT_EQ(s.GetSourceLocations().size(), 3); - errors::AppendToMessage(&s, "A new message."); - ASSERT_EQ(s.GetSourceLocations().size(), 3); -} - -TEST(Status, SourceLocationsPreservedByUpdate) { - absl::Status s = PropagateError2(); - ASSERT_EQ(s.GetSourceLocations().size(), 3); - absl::Status s2 = errors::CreateWithUpdatedMessage(s, "New message."); - ASSERT_EQ(s2.GetSourceLocations().size(), 3); -} - -#endif - -} // namespace tsl diff --git a/tsl/platform/file_statistics.h b/tsl/platform/file_statistics.h index ebe50be46..07bf908ed 100644 --- a/tsl/platform/file_statistics.h +++ b/tsl/platform/file_statistics.h @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,24 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_FILE_STATISTICS_H_ #define TENSORFLOW_TSL_PLATFORM_FILE_STATISTICS_H_ -#include "tsl/platform/types.h" - -namespace tsl { - -struct FileStatistics { - // The length of the file or -1 if finding file length is not supported. - int64_t length = -1; - // The last modified time in nanoseconds. - int64_t mtime_nsec = 0; - // True if the file is a directory, otherwise false. - bool is_directory = false; - - FileStatistics() {} - FileStatistics(int64_t length, int64_t mtime_nsec, bool is_directory) - : length(length), mtime_nsec(mtime_nsec), is_directory(is_directory) {} - ~FileStatistics() {} -}; - -} // namespace tsl +#include "xla/tsl/platform/file_statistics.h" #endif // TENSORFLOW_TSL_PLATFORM_FILE_STATISTICS_H_ diff --git a/tsl/platform/file_system.cc b/tsl/platform/file_system.cc deleted file mode 100644 index 453e04b39..000000000 --- a/tsl/platform/file_system.cc +++ /dev/null @@ -1,507 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tsl/platform/file_system.h" - -#include - -#include -#include -#include -#include -#include - -#include "tsl/platform/status.h" - -#if defined(PLATFORM_POSIX) || defined(IS_MOBILE_PLATFORM) || \ - defined(PLATFORM_GOOGLE) -#include -#else -#include "tsl/platform/regexp.h" -#endif // defined(PLATFORM_POSIX) || defined(IS_MOBILE_PLATFORM) || \ - // defined(PLATFORM_GOOGLE) - -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/platform.h" -#include "tsl/platform/scanner.h" -#include "tsl/platform/str_util.h" -#include "tsl/platform/strcat.h" - -namespace tsl { - -bool FileSystem::Match(const string& filename, const string& pattern) { -#if defined(PLATFORM_POSIX) || defined(IS_MOBILE_PLATFORM) || \ - defined(PLATFORM_GOOGLE) - // We avoid relying on RE2 on mobile platforms, because it incurs a - // significant binary size increase. - // For POSIX platforms, there is no need to depend on RE2 if `fnmatch` can be - // used safely. - return fnmatch(pattern.c_str(), filename.c_str(), FNM_PATHNAME) == 0; -#else - string regexp(pattern); - regexp = str_util::StringReplace(regexp, "*", "[^/]*", true); - regexp = str_util::StringReplace(regexp, "?", ".", true); - regexp = str_util::StringReplace(regexp, "(", "\\(", true); - regexp = str_util::StringReplace(regexp, ")", "\\)", true); - return RE2::FullMatch(filename, regexp); -#endif // defined(PLATFORM_POSIX) || defined(IS_MOBILE_PLATFORM) || \ - // defined(PLATFORM_GOOGLE) -} - -string FileSystem::TranslateName(const string& name) const { - // If the name is empty, CleanPath returns "." which is incorrect and - // we should return the empty path instead. - if (name.empty()) return name; - - // Otherwise, properly separate the URI components and clean the path one - absl::string_view scheme, host, path; - this->ParseURI(name, &scheme, &host, &path); - - // If `path` becomes empty, return `/` (`file://` should be `/`), not `.`. - if (path.empty()) return "/"; - - return this->CleanPath(path); -} - -absl::Status FileSystem::IsDirectory(const string& name, - TransactionToken* token) { - // Check if path exists. - // TODO(sami):Forward token to other methods once migration is complete. - TF_RETURN_IF_ERROR(FileExists(name)); - FileStatistics stat; - TF_RETURN_IF_ERROR(Stat(name, &stat)); - if (stat.is_directory) { - return absl::OkStatus(); - } - return absl::Status(absl::StatusCode::kFailedPrecondition, "Not a directory"); -} - -absl::Status FileSystem::HasAtomicMove(const string& path, - bool* has_atomic_move) { - *has_atomic_move = true; - return absl::OkStatus(); -} - -absl::Status FileSystem::CanCreateTempFile(const std::string& fname, - bool* can_create_temp_file) { - *can_create_temp_file = true; - return absl::OkStatus(); -} - -void FileSystem::FlushCaches(TransactionToken* token) {} - -bool FileSystem::FilesExist(const std::vector& files, - TransactionToken* token, - std::vector* status) { - bool result = true; - for (const auto& file : files) { - absl::Status s = FileExists(file); - result &= s.ok(); - if (status != nullptr) { - status->push_back(s); - } else if (!result) { - // Return early since there is no need to check other files. - return false; - } - } - return result; -} - -absl::Status FileSystem::DeleteRecursively(const string& dirname, - TransactionToken* token, - int64_t* undeleted_files, - int64_t* undeleted_dirs) { - CHECK_NOTNULL(undeleted_files); - CHECK_NOTNULL(undeleted_dirs); - - *undeleted_files = 0; - *undeleted_dirs = 0; - // Make sure that dirname exists; - absl::Status exists_status = FileExists(dirname); - if (!exists_status.ok()) { - (*undeleted_dirs)++; - return exists_status; - } - - // If given path to a single file, we should just delete it. - if (!IsDirectory(dirname).ok()) { - absl::Status delete_root_status = DeleteFile(dirname); - if (!delete_root_status.ok()) (*undeleted_files)++; - return delete_root_status; - } - - std::deque dir_q; // Queue for the BFS - std::vector dir_list; // List of all dirs discovered - dir_q.push_back(dirname); - absl::Status ret; // Status to be returned. - // Do a BFS on the directory to discover all the sub-directories. Remove all - // children that are files along the way. Then cleanup and remove the - // directories in reverse order.; - while (!dir_q.empty()) { - string dir = dir_q.front(); - dir_q.pop_front(); - dir_list.push_back(dir); - std::vector children; - // GetChildren might fail if we don't have appropriate permissions. - absl::Status s = GetChildren(dir, &children); - ret.Update(s); - if (!s.ok()) { - (*undeleted_dirs)++; - continue; - } - for (const string& child : children) { - const string child_path = this->JoinPath(dir, child); - // If the child is a directory add it to the queue, otherwise delete it. - if (IsDirectory(child_path).ok()) { - dir_q.push_back(child_path); - } else { - // Delete file might fail because of permissions issues or might be - // unimplemented. - absl::Status del_status = DeleteFile(child_path); - ret.Update(del_status); - if (!del_status.ok()) { - (*undeleted_files)++; - } - } - } - } - // Now reverse the list of directories and delete them. The BFS ensures that - // we can delete the directories in this order. - std::reverse(dir_list.begin(), dir_list.end()); - for (const string& dir : dir_list) { - // Delete dir might fail because of permissions issues or might be - // unimplemented. - absl::Status s = DeleteDir(dir); - ret.Update(s); - if (!s.ok()) { - (*undeleted_dirs)++; - } - } - return ret; -} - -absl::Status FileSystem::RecursivelyCreateDir(const string& dirname, - TransactionToken* token) { - absl::string_view scheme, host, remaining_dir; - this->ParseURI(dirname, &scheme, &host, &remaining_dir); - std::vector sub_dirs; - while (!remaining_dir.empty()) { - std::string current_entry = this->CreateURI(scheme, host, remaining_dir); - absl::Status exists_status = FileExists(current_entry); - if (exists_status.ok()) { - // FileExists cannot differentiate between existence of a file or a - // directory, hence we need an additional test as we must not assume that - // a path to a file is a path to a parent directory. - absl::Status directory_status = IsDirectory(current_entry); - if (directory_status.ok()) { - break; // We need to start creating directories from here. - } else if (directory_status.code() == absl::StatusCode::kUnimplemented) { - return directory_status; - } else { - return errors::FailedPrecondition(remaining_dir, " is not a directory"); - } - } - if (exists_status.code() != error::Code::NOT_FOUND) { - return exists_status; - } - // Basename returns "" for / ending dirs. - if (!absl::EndsWith(remaining_dir, "/")) { - sub_dirs.push_back(this->Basename(remaining_dir)); - } - remaining_dir = this->Dirname(remaining_dir); - } - - // sub_dirs contains all the dirs to be created but in reverse order. - std::reverse(sub_dirs.begin(), sub_dirs.end()); - - // Now create the directories. - string built_path(remaining_dir); - for (const absl::string_view sub_dir : sub_dirs) { - built_path = this->JoinPath(built_path, sub_dir); - absl::Status status = CreateDir(this->CreateURI(scheme, host, built_path)); - if (!status.ok() && status.code() != absl::StatusCode::kAlreadyExists) { - return status; - } - } - return absl::OkStatus(); -} - -absl::Status FileSystem::CopyFile(const string& src, const string& target, - TransactionToken* token) { - return FileSystemCopyFile(this, src, this, target); -} - -char FileSystem::Separator() const { return '/'; } - -string FileSystem::JoinPathImpl( - std::initializer_list paths) { - string result; - - for (absl::string_view path : paths) { - if (path.empty()) continue; - - if (result.empty()) { - result = string(path); - continue; - } - - if (result[result.size() - 1] == '/') { - if (this->IsAbsolutePath(path)) { - strings::StrAppend(&result, path.substr(1)); - } else { - strings::StrAppend(&result, path); - } - } else { - if (this->IsAbsolutePath(path)) { - strings::StrAppend(&result, path); - } else { - strings::StrAppend(&result, "/", path); - } - } - } - - return result; -} - -std::pair FileSystem::SplitPath( - absl::string_view uri) const { - absl::string_view scheme, host, path; - ParseURI(uri, &scheme, &host, &path); - - // We have 3 cases of results from `ParseURI`: - // - // 1. `path` is empty (`uri` is something like http://google.com/) - // Here, we don't have anything to split, so return empty components - // - // 2. all 3 components are non-empty (`uri` is something like - // http://google.com/path/to/resource) - // Here, all 3 components point to elements inside the same buffer as - // `uri`. In the given example, `scheme` contains `http://`, `host` - // contains `google.com/` and `path` contains `path/to/resource`. - // Since all 3 components point to the same buffer, we can do arithmetic - // such as `host.end() - uri.begin()` because we know for sure that - // `host` starts after `uri`. - // - // 3. `scheme` and `host` are empty (`uri` is local file, like /etc/passwd) - // Here, we split `path`, but we need to be careful with pointer - // arithmetic. Here we only know that `path` and `uri` represent the - // exact same buffer. - // - // To summarize, if `path` is empty there is nothing to return, in all other - // cases we can do arithmetic involving `path` and `uri` but if - // `host`/`scheme` are involved we need to make sure these are not empty. - - // Case 1 above - if (path.empty()) { - return std::make_pair(absl::string_view(), absl::string_view()); - } - - size_t pos = path.rfind(this->Separator()); - - // Our code assumes it is written for linux too many times. So, for windows - // also check for '/' -#ifdef PLATFORM_WINDOWS - size_t pos2 = path.rfind('/'); - // Pick the max value that is not string::npos. - if (pos == string::npos) { - pos = pos2; - } else { - if (pos2 != string::npos) { - pos = pos > pos2 ? pos : pos2; - } - } -#endif - - // Handle the case with no SEP in 'path'. - if (pos == absl::string_view::npos) { - if (host.empty()) { - // Case 3 above, `uri` and `path` point to the same thing - // We are returning all of the `path` as basename here. - return std::make_pair(absl::string_view(), path); - } - - // Safe to do this arithmetic here, we are in case 2 above - return std::make_pair( - absl::string_view(uri.data(), host.end() - uri.begin()), path); - } - - // Handle the case with a single leading '/' in 'path'. - if (pos == 0) { - return std::make_pair( - absl::string_view(uri.data(), path.begin() + 1 - uri.begin()), - absl::string_view(path.data() + 1, path.size() - 1)); - } - - return std::make_pair( - absl::string_view(uri.data(), path.begin() + pos - uri.begin()), - absl::string_view(path.data() + pos + 1, path.size() - (pos + 1))); -} - -bool FileSystem::IsAbsolutePath(absl::string_view path) const { - return !path.empty() && path[0] == '/'; -} - -absl::string_view FileSystem::Dirname(absl::string_view path) const { - return this->SplitPath(path).first; -} - -absl::string_view FileSystem::Basename(absl::string_view path) const { - return this->SplitPath(path).second; -} - -absl::string_view FileSystem::Extension(absl::string_view path) const { - absl::string_view basename = this->Basename(path); - - size_t pos = basename.rfind('.'); - if (pos == absl::string_view::npos) { - return absl::string_view(path.data() + path.size(), 0); - } else { - return absl::string_view(path.data() + pos + 1, path.size() - (pos + 1)); - } -} - -string FileSystem::CleanPath(absl::string_view unclean_path) const { - string path(unclean_path); - const char* src = path.c_str(); - string::iterator dst = path.begin(); - - // Check for absolute path and determine initial backtrack limit. - const bool is_absolute_path = *src == '/'; - if (is_absolute_path) { - *dst++ = *src++; - while (*src == '/') ++src; - } - string::const_iterator backtrack_limit = dst; - - // Process all parts - while (*src) { - bool parsed = false; - - if (src[0] == '.') { - // 1dot ".", check for END or SEP. - if (src[1] == '/' || !src[1]) { - if (*++src) { - ++src; - } - parsed = true; - } else if (src[1] == '.' && (src[2] == '/' || !src[2])) { - // 2dot END or SEP (".." | "../"). - src += 2; - if (dst != backtrack_limit) { - // We can backtrack the previous part - for (--dst; dst != backtrack_limit && dst[-1] != '/'; --dst) { - // Empty. - } - } else if (!is_absolute_path) { - // Failed to backtrack and we can't skip it either. Rewind and copy. - src -= 2; - *dst++ = *src++; - *dst++ = *src++; - if (*src) { - *dst++ = *src; - } - // We can never backtrack over a copied "../" part so set new limit. - backtrack_limit = dst; - } - if (*src) { - ++src; - } - parsed = true; - } - } - - // If not parsed, copy entire part until the next SEP or EOS. - if (!parsed) { - while (*src && *src != '/') { - *dst++ = *src++; - } - if (*src) { - *dst++ = *src++; - } - } - - // Skip consecutive SEP occurrences - while (*src == '/') { - ++src; - } - } - - // Calculate and check the length of the cleaned path. - string::difference_type path_length = dst - path.begin(); - if (path_length != 0) { - // Remove trailing '/' except if it is root path ("/" ==> path_length := 1) - if (path_length > 1 && path[path_length - 1] == '/') { - --path_length; - } - path.resize(path_length); - } else { - // The cleaned path is empty; assign "." as per the spec. - path.assign(1, '.'); - } - return path; -} - -void FileSystem::ParseURI(absl::string_view remaining, - absl::string_view* scheme, absl::string_view* host, - absl::string_view* path) const { - // 0. Parse scheme - // Make sure scheme matches [a-zA-Z][0-9a-zA-Z.]* - // TODO(keveman): Allow "+" and "-" in the scheme. - // Keep URI pattern in tensorboard/backend/server.py updated accordingly - if (!strings::Scanner(remaining) - .One(strings::Scanner::LETTER) - .Many(strings::Scanner::LETTER_DIGIT_DOT) - .StopCapture() - .OneLiteral("://") - .GetResult(&remaining, scheme)) { - // If there's no scheme, assume the entire string is a path. - *scheme = absl::string_view(); - *host = absl::string_view(); - *path = remaining; - return; - } - - // 1. Parse host - if (!strings::Scanner(remaining).ScanUntil('/').GetResult(&remaining, host)) { - // No path, so the rest of the URI is the host. - *host = remaining; - *path = absl::string_view(); - return; - } - - // 2. The rest is the path - *path = remaining; -} - -string FileSystem::CreateURI(absl::string_view scheme, absl::string_view host, - absl::string_view path) const { - if (scheme.empty()) { - return string(path); - } - return strings::StrCat(scheme, "://", host, path); -} - -std::string FileSystem::DecodeTransaction(const TransactionToken* token) { - // TODO(sami): Switch using StrCat when void* is supported - if (token) { - std::stringstream oss; - oss << "Token= " << token->token << ", Owner=" << token->owner; - return oss.str(); - } - return "No Transaction"; -} - -} // namespace tsl diff --git a/tsl/platform/file_system.h b/tsl/platform/file_system.h index 8b4878826..8d55471a5 100644 --- a/tsl/platform/file_system.h +++ b/tsl/platform/file_system.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,921 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_FILE_SYSTEM_H_ #define TENSORFLOW_TSL_PLATFORM_FILE_SYSTEM_H_ -#include - -#include -#include -#include -#include -#include -#include - -#include "tsl/platform/cord.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/file_statistics.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/platform.h" -#include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" - -#ifdef PLATFORM_WINDOWS -#undef DeleteFile -#undef CopyFile -#undef TranslateName -#endif - -namespace tsl { - -class FileAcl; -class RandomAccessFile; -class ReadOnlyMemoryRegion; -class WritableFile; - -class FileSystem; -struct TransactionToken { - FileSystem* owner; - void* token; -}; - -/// A generic interface for accessing a file system. Implementations -/// of custom filesystem adapters must implement this interface, -/// RandomAccessFile, WritableFile, and ReadOnlyMemoryRegion classes. -class FileSystem { - public: - /// \brief Creates a brand new random access read-only file with the - /// specified name. - /// - /// On success, stores a pointer to the new file in - /// *result and returns OK. On failure stores NULL in *result and - /// returns non-OK. If the file does not exist, returns a non-OK - /// status. - /// - /// The returned file may be concurrently accessed by multiple threads. - /// - /// The ownership of the returned RandomAccessFile is passed to the caller - /// and the object should be deleted when is not used. - virtual absl::Status NewRandomAccessFile( - const std::string& fname, std::unique_ptr* result) { - return NewRandomAccessFile(fname, nullptr, result); - } - - virtual absl::Status NewRandomAccessFile( - const std::string& fname, TransactionToken* token, - std::unique_ptr* result) { - // We duplicate these methods due to Google internal coding style prevents - // virtual functions with default arguments. See PR #41615. - return absl::OkStatus(); - } - - /// \brief Creates an object that writes to a new file with the specified - /// name. - /// - /// Deletes any existing file with the same name and creates a - /// new file. On success, stores a pointer to the new file in - /// *result and returns OK. On failure stores NULL in *result and - /// returns non-OK. - /// - /// The returned file will only be accessed by one thread at a time. - /// - /// The ownership of the returned WritableFile is passed to the caller - /// and the object should be deleted when is not used. - virtual absl::Status NewWritableFile(const std::string& fname, - std::unique_ptr* result) { - return NewWritableFile(fname, nullptr, result); - } - - virtual absl::Status NewWritableFile(const std::string& fname, - TransactionToken* token, - std::unique_ptr* result) { - return absl::OkStatus(); - } - - /// \brief Creates an object that either appends to an existing file, or - /// writes to a new file (if the file does not exist to begin with). - /// - /// On success, stores a pointer to the new file in *result and - /// returns OK. On failure stores NULL in *result and returns - /// non-OK. - /// - /// The returned file will only be accessed by one thread at a time. - /// - /// The ownership of the returned WritableFile is passed to the caller - /// and the object should be deleted when is not used. - virtual absl::Status NewAppendableFile( - const std::string& fname, std::unique_ptr* result) { - return NewAppendableFile(fname, nullptr, result); - } - - virtual absl::Status NewAppendableFile( - const std::string& fname, TransactionToken* token, - std::unique_ptr* result) { - return absl::OkStatus(); - } - - /// \brief Creates a readonly region of memory with the file context. - /// - /// On success, it returns a pointer to read-only memory region - /// from the content of file fname. The ownership of the region is passed to - /// the caller. On failure stores nullptr in *result and returns non-OK. - /// - /// The returned memory region can be accessed from many threads in parallel. - /// - /// The ownership of the returned ReadOnlyMemoryRegion is passed to the caller - /// and the object should be deleted when is not used. - virtual absl::Status NewReadOnlyMemoryRegionFromFile( - const std::string& fname, std::unique_ptr* result) { - return NewReadOnlyMemoryRegionFromFile(fname, nullptr, result); - } - - virtual absl::Status NewReadOnlyMemoryRegionFromFile( - const std::string& fname, TransactionToken* token, - std::unique_ptr* result) { - return absl::OkStatus(); - } - - /// Returns OK if the named path exists and NOT_FOUND otherwise. - virtual absl::Status FileExists(const std::string& fname) { - return FileExists(fname, nullptr); - } - - virtual absl::Status FileExists(const std::string& fname, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// Returns true if all the listed files exist, false otherwise. - /// if status is not null, populate the vector with a detailed status - /// for each file. - virtual bool FilesExist(const std::vector& files, - std::vector* status) { - return FilesExist(files, nullptr, status); - } - - virtual bool FilesExist(const std::vector& files, - TransactionToken* token, - std::vector* status); - - /// \brief Returns the immediate children in the given directory. - /// - /// The returned paths are relative to 'dir'. - virtual absl::Status GetChildren(const std::string& dir, - std::vector* result) { - return GetChildren(dir, nullptr, result); - } - - virtual absl::Status GetChildren(const std::string& dir, - TransactionToken* token, - std::vector* result) { - return absl::OkStatus(); - } - - /// \brief Given a pattern, stores in *results the set of paths that matches - /// that pattern. *results is cleared. - /// - /// pattern must match all of a name, not just a substring. - /// - /// pattern: { term } - /// term: - /// '*': matches any sequence of non-'/' characters - /// '?': matches a single non-'/' character - /// '[' [ '^' ] { match-list } ']': - /// matches any single character (not) on the list - /// c: matches character c (c != '*', '?', '\\', '[') - /// '\\' c: matches character c - /// character-range: - /// c: matches character c (c != '\\', '-', ']') - /// '\\' c: matches character c - /// lo '-' hi: matches character c for lo <= c <= hi - /// - /// Typical return codes: - /// * OK - no errors - /// * UNIMPLEMENTED - Some underlying functions (like GetChildren) are not - /// implemented - virtual absl::Status GetMatchingPaths(const std::string& pattern, - std::vector* results) { - return GetMatchingPaths(pattern, nullptr, results); - } - - virtual absl::Status GetMatchingPaths(const std::string& pattern, - TransactionToken* token, - std::vector* results) { - return absl::OkStatus(); - } - - /// \brief Checks if the given filename matches the pattern. - /// - /// This function provides the equivalent of posix fnmatch, however it is - /// implemented without fnmatch to ensure that this can be used for cloud - /// filesystems on windows. For windows filesystems, it uses PathMatchSpec. - virtual bool Match(const std::string& filename, const std::string& pattern); - - /// \brief Obtains statistics for the given path. - virtual absl::Status Stat(const std::string& fname, FileStatistics* stat) { - return Stat(fname, nullptr, stat); - } - - virtual absl::Status Stat(const std::string& fname, TransactionToken* token, - FileStatistics* stat) { - return absl::OkStatus(); - } - - /// \brief Deletes the named file. - virtual absl::Status DeleteFile(const std::string& fname) { - return DeleteFile(fname, nullptr); - } - - virtual absl::Status DeleteFile(const std::string& fname, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Creates the specified directory. - /// Typical return codes: - /// * OK - successfully created the directory. - /// * ALREADY_EXISTS - directory with name dirname already exists. - /// * PERMISSION_DENIED - dirname is not writable. - virtual absl::Status CreateDir(const std::string& dirname) { - return CreateDir(dirname, nullptr); - } - - virtual absl::Status CreateDir(const std::string& dirname, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Creates the specified directory and all the necessary - /// subdirectories. - /// Typical return codes: - /// * OK - successfully created the directory and sub directories, even if - /// they were already created. - /// * PERMISSION_DENIED - dirname or some subdirectory is not writable. - virtual absl::Status RecursivelyCreateDir(const std::string& dirname) { - return RecursivelyCreateDir(dirname, nullptr); - } - - virtual absl::Status RecursivelyCreateDir(const std::string& dirname, - TransactionToken* token); - - /// \brief Deletes the specified directory. - virtual absl::Status DeleteDir(const std::string& dirname) { - return DeleteDir(dirname, nullptr); - } - - virtual absl::Status DeleteDir(const std::string& dirname, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Deletes the specified directory and all subdirectories and files - /// underneath it. This is accomplished by traversing the directory tree - /// rooted at dirname and deleting entries as they are encountered. - /// - /// If dirname itself is not readable or does not exist, *undeleted_dir_count - /// is set to 1, *undeleted_file_count is set to 0 and an appropriate status - /// (e.g. NOT_FOUND) is returned. - /// - /// If dirname and all its descendants were successfully deleted, TF_OK is - /// returned and both error counters are set to zero. - /// - /// Otherwise, while traversing the tree, undeleted_file_count and - /// undeleted_dir_count are updated if an entry of the corresponding type - /// could not be deleted. The returned error status represents the reason that - /// any one of these entries could not be deleted. - /// - /// REQUIRES: undeleted_files, undeleted_dirs to be not null. - /// - /// Typical return codes: - /// * OK - dirname exists and we were able to delete everything underneath. - /// * NOT_FOUND - dirname doesn't exist - /// * PERMISSION_DENIED - dirname or some descendant is not writable - /// * UNIMPLEMENTED - Some underlying functions (like Delete) are not - /// implemented - virtual absl::Status DeleteRecursively(const std::string& dirname, - int64_t* undeleted_files, - int64_t* undeleted_dirs) { - return DeleteRecursively(dirname, nullptr, undeleted_files, undeleted_dirs); - } - - virtual absl::Status DeleteRecursively(const std::string& dirname, - TransactionToken* token, - int64_t* undeleted_files, - int64_t* undeleted_dirs); - - /// \brief Stores the size of `fname` in `*file_size`. - virtual absl::Status GetFileSize(const std::string& fname, - uint64* file_size) { - return GetFileSize(fname, nullptr, file_size); - } - - virtual absl::Status GetFileSize(const std::string& fname, - TransactionToken* token, uint64* file_size) { - return absl::OkStatus(); - } - - /// \brief Overwrites the target if it exists. - virtual absl::Status RenameFile(const std::string& src, - const std::string& target) { - return RenameFile(src, target, nullptr); - } - - virtual absl::Status RenameFile(const std::string& src, - const std::string& target, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Copy the src to target. - virtual absl::Status CopyFile(const std::string& src, - const std::string& target) { - return CopyFile(src, target, nullptr); - } - - virtual absl::Status CopyFile(const std::string& src, - const std::string& target, - TransactionToken* token); - - /// \brief Translate an URI to a filename for the FileSystem implementation. - /// - /// The implementation in this class cleans up the path, removing - /// duplicate /'s, resolving .. and removing trailing '/'. - /// This respects relative vs. absolute paths, but does not - /// invoke any system calls (getcwd(2)) in order to resolve relative - /// paths with respect to the actual working directory. That is, this is - /// purely string manipulation, completely independent of process state. - virtual std::string TranslateName(const std::string& name) const; - - /// \brief Returns whether the given path is a directory or not. - /// - /// Typical return codes (not guaranteed exhaustive): - /// * OK - The path exists and is a directory. - /// * FAILED_PRECONDITION - The path exists and is not a directory. - /// * NOT_FOUND - The path entry does not exist. - /// * PERMISSION_DENIED - Insufficient permissions. - /// * UNIMPLEMENTED - The file factory doesn't support directories. - virtual absl::Status IsDirectory(const std::string& fname) { - return IsDirectory(fname, nullptr); - } - - virtual absl::Status IsDirectory(const std::string& fname, - TransactionToken* token); - - /// \brief Returns whether the given path is on a file system - /// that has atomic move capabilities. This can be used - /// to determine if there needs to be a temp location to safely write objects. - /// The second boolean argument has_atomic_move contains this information. - /// - /// Returns one of the following status codes (not guaranteed exhaustive): - /// * OK - The path is on a recognized file system, - /// so has_atomic_move holds the above information. - /// * UNIMPLEMENTED - The file system of the path hasn't been implemented in - /// TF - virtual absl::Status HasAtomicMove(const std::string& path, - bool* has_atomic_move); - - /// Returns whether the give path is on a file system - /// that has ability to create a new temp file. This can be used - /// to determine if there needs to be a temp location to safely write objects. - /// If the file system cannot create a temp file, it's possibile that - /// uncomplete result may appear in the given file. - virtual absl::Status CanCreateTempFile(const std::string& fname, - bool* can_create_temp_file); - - /// \brief Flushes any cached filesystem objects from memory. - virtual void FlushCaches() { FlushCaches(nullptr); } - - virtual void FlushCaches(TransactionToken* token); - - /// \brief The separator this filesystem uses. - /// - /// This is implemented as a part of the filesystem, because even on windows, - /// a user may need access to filesystems with '/' separators, such as cloud - /// filesystems. - virtual char Separator() const; - - /// \brief Split a path to its basename and dirname. - /// - /// Helper function for Basename and Dirname. - std::pair SplitPath( - absl::string_view uri) const; - - /// \brief returns the final file name in the given path. - /// - /// Returns the part of the path after the final "/". If there is no - /// "/" in the path, the result is the same as the input. - virtual absl::string_view Basename(absl::string_view path) const; - - /// \brief Returns the part of the path before the final "/". - /// - /// If there is a single leading "/" in the path, the result will be the - /// leading "/". If there is no "/" in the path, the result is the empty - /// prefix of the input. - absl::string_view Dirname(absl::string_view path) const; - - /// \brief Returns the part of the basename of path after the final ".". - /// - /// If there is no "." in the basename, the result is empty. - absl::string_view Extension(absl::string_view path) const; - - /// \brief Clean duplicate and trailing, "/"s, and resolve ".." and ".". - /// - /// NOTE: This respects relative vs. absolute paths, but does not - /// invoke any system calls (getcwd(2)) in order to resolve relative - /// paths with respect to the actual working directory. That is, this is - /// purely string manipulation, completely independent of process state. - std::string CleanPath(absl::string_view path) const; - - /// \brief Creates a URI from a scheme, host, and path. - /// - /// If the scheme is empty, we just return the path. - std::string CreateURI(absl::string_view scheme, absl::string_view host, - absl::string_view path) const; - - /// \brief Return true if path is absolute. - bool IsAbsolutePath(absl::string_view path) const; - -#ifndef SWIG // variadic templates - /// \brief Join multiple paths together. - /// - /// This function also removes the unnecessary path separators. - /// For example: - /// - /// Arguments | JoinPath - /// ---------------------------+---------- - /// '/foo', 'bar' | /foo/bar - /// '/foo/', 'bar' | /foo/bar - /// '/foo', '/bar' | /foo/bar - /// - /// Usage: - /// string path = io::JoinPath("/mydir", filename); - /// string path = io::JoinPath(FLAGS_test_srcdir, filename); - /// string path = io::JoinPath("/full", "path", "to", "filename"); - template - std::string JoinPath(const T&... args) { - return JoinPathImpl({args...}); - } -#endif /* SWIG */ - - std::string JoinPathImpl(std::initializer_list paths); - - /// \brief Populates the scheme, host, and path from a URI. - /// - /// scheme, host, and path are guaranteed by this function to point into the - /// contents of uri, even if empty. - /// - /// Corner cases: - /// - If the URI is invalid, scheme and host are set to empty strings and the - /// passed string is assumed to be a path - /// - If the URI omits the path (e.g. file://host), then the path is left - /// empty. - void ParseURI(absl::string_view remaining, absl::string_view* scheme, - absl::string_view* host, absl::string_view* path) const; - - // Transaction related API - - /// \brief Starts a new transaction - virtual absl::Status StartTransaction(TransactionToken** token) { - *token = nullptr; - return absl::OkStatus(); - } - - /// \brief Adds `path` to transaction in `token` - virtual absl::Status AddToTransaction(const std::string& path, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Ends transaction - virtual absl::Status EndTransaction(TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Get token for `path` or start a new transaction and add `path` to - /// it. - virtual absl::Status GetTokenOrStartTransaction(const std::string& path, - TransactionToken** token) { - *token = nullptr; - return absl::OkStatus(); - } - - /// \brief Return transaction for `path` or nullptr in `token` - virtual absl::Status GetTransactionForPath(const std::string& path, - TransactionToken** token) { - *token = nullptr; - return absl::OkStatus(); - } - - /// \brief Decode transaction to human readable string. - virtual std::string DecodeTransaction(const TransactionToken* token); - - /// \brief Set File System Configuration Options - virtual absl::Status SetOption(const string& key, const string& value) { - return errors::Unimplemented("SetOption"); - } - - /// \brief Set File System Configuration Option - virtual absl::Status SetOption(const std::string& name, - const std::vector& values) { - return errors::Unimplemented("SetOption"); - } - - /// \brief Set File System Configuration Option - virtual absl::Status SetOption(const std::string& name, - const std::vector& values) { - return errors::Unimplemented("SetOption"); - } - - /// \brief Set File System Configuration Option - virtual absl::Status SetOption(const std::string& name, - const std::vector& values) { - return errors::Unimplemented("SetOption"); - } - - /// \brief Set File System ACL checker. - /// - /// No checks are enforced if a FileAcl is never set. - virtual absl::Status SetFileAcl(std::shared_ptr file_acl) { - return errors::Unimplemented("SetFileAcl"); - } - - FileSystem() {} - - virtual ~FileSystem() = default; -}; -/// This macro adds forwarding methods from FileSystem class to -/// used class since name hiding will prevent these to be accessed from -/// derived classes and would require all use locations to migrate to -/// Transactional API. This is an interim solution until ModularFileSystem class -/// becomes a singleton. -// TODO(sami): Remove this macro when filesystem plugins migration is complete. -#define TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT \ - using FileSystem::NewRandomAccessFile; \ - using FileSystem::NewWritableFile; \ - using FileSystem::NewAppendableFile; \ - using FileSystem::NewReadOnlyMemoryRegionFromFile; \ - using FileSystem::FileExists; \ - using FileSystem::GetChildren; \ - using FileSystem::GetMatchingPaths; \ - using FileSystem::Stat; \ - using FileSystem::DeleteFile; \ - using FileSystem::RecursivelyCreateDir; \ - using FileSystem::DeleteDir; \ - using FileSystem::DeleteRecursively; \ - using FileSystem::GetFileSize; \ - using FileSystem::RenameFile; \ - using FileSystem::CopyFile; \ - using FileSystem::IsDirectory; \ - using FileSystem::FlushCaches - -/// A Wrapper class for Transactional FileSystem support. -/// This provides means to make use of the transactions with minimal code change -/// Any operations that are done through this interface will be through the -/// transaction created at the time of construction of this instance. -/// See FileSystem documentation for method descriptions. -/// This class simply forwards all calls to wrapped filesystem either with given -/// transaction token or with token used in its construction. This allows doing -/// transactional filesystem access with minimal code change. -class WrappedFileSystem : public FileSystem { - public: - TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT; - - absl::Status NewRandomAccessFile( - const std::string& fname, TransactionToken* token, - std::unique_ptr* result) override { - return fs_->NewRandomAccessFile(fname, (token ? token : token_), result); - } - - absl::Status NewWritableFile(const std::string& fname, - TransactionToken* token, - std::unique_ptr* result) override { - return fs_->NewWritableFile(fname, (token ? token : token_), result); - } - - absl::Status NewAppendableFile( - const std::string& fname, TransactionToken* token, - std::unique_ptr* result) override { - return fs_->NewAppendableFile(fname, (token ? token : token_), result); - } - - absl::Status NewReadOnlyMemoryRegionFromFile( - const std::string& fname, TransactionToken* token, - std::unique_ptr* result) override { - return fs_->NewReadOnlyMemoryRegionFromFile(fname, (token ? token : token_), - result); - } - - absl::Status FileExists(const std::string& fname, - TransactionToken* token) override { - return fs_->FileExists(fname, (token ? token : token_)); - } - - bool FilesExist(const std::vector& files, TransactionToken* token, - std::vector* status) override { - return fs_->FilesExist(files, (token ? token : token_), status); - } - - absl::Status GetChildren(const std::string& dir, TransactionToken* token, - std::vector* result) override { - return fs_->GetChildren(dir, (token ? token : token_), result); - } - - absl::Status GetMatchingPaths(const std::string& pattern, - TransactionToken* token, - std::vector* results) override { - return fs_->GetMatchingPaths(pattern, (token ? token : token_), results); - } - - bool Match(const std::string& filename, const std::string& pattern) override { - return fs_->Match(filename, pattern); - } - - absl::Status Stat(const std::string& fname, TransactionToken* token, - FileStatistics* stat) override { - return fs_->Stat(fname, (token ? token : token_), stat); - } - - absl::Status DeleteFile(const std::string& fname, - TransactionToken* token) override { - return fs_->DeleteFile(fname, (token ? token : token_)); - } - - absl::Status CreateDir(const std::string& dirname, - TransactionToken* token) override { - return fs_->CreateDir(dirname, (token ? token : token_)); - } - - absl::Status RecursivelyCreateDir(const std::string& dirname, - TransactionToken* token) override { - return fs_->RecursivelyCreateDir(dirname, (token ? token : token_)); - } - - absl::Status DeleteDir(const std::string& dirname, - TransactionToken* token) override { - return fs_->DeleteDir(dirname, (token ? token : token_)); - } - - absl::Status DeleteRecursively(const std::string& dirname, - TransactionToken* token, - int64_t* undeleted_files, - int64_t* undeleted_dirs) override { - return fs_->DeleteRecursively(dirname, (token ? token : token_), - undeleted_files, undeleted_dirs); - } - - absl::Status GetFileSize(const std::string& fname, TransactionToken* token, - uint64* file_size) override { - return fs_->GetFileSize(fname, (token ? token : token_), file_size); - } - - absl::Status RenameFile(const std::string& src, const std::string& target, - TransactionToken* token) override { - return fs_->RenameFile(src, target, (token ? token : token_)); - } - - absl::Status CopyFile(const std::string& src, const std::string& target, - TransactionToken* token) override { - return fs_->CopyFile(src, target, (token ? token : token_)); - } - - std::string TranslateName(const std::string& name) const override { - return fs_->TranslateName(name); - } - - absl::Status IsDirectory(const std::string& fname, - TransactionToken* token) override { - return fs_->IsDirectory(fname, (token ? token : token_)); - } - - absl::Status HasAtomicMove(const std::string& path, - bool* has_atomic_move) override { - return fs_->HasAtomicMove(path, has_atomic_move); - } - - void FlushCaches(TransactionToken* token) override { - return fs_->FlushCaches((token ? token : token_)); - } - - char Separator() const override { return fs_->Separator(); } - - absl::string_view Basename(absl::string_view path) const override { - return fs_->Basename(path); - } - - absl::Status StartTransaction(TransactionToken** token) override { - return fs_->StartTransaction(token); - } - - absl::Status AddToTransaction(const std::string& path, - TransactionToken* token) override { - return fs_->AddToTransaction(path, (token ? token : token_)); - } - - absl::Status EndTransaction(TransactionToken* token) override { - return fs_->EndTransaction(token); - } - - absl::Status GetTransactionForPath(const std::string& path, - TransactionToken** token) override { - return fs_->GetTransactionForPath(path, token); - } - - absl::Status GetTokenOrStartTransaction(const std::string& path, - TransactionToken** token) override { - return fs_->GetTokenOrStartTransaction(path, token); - } - - std::string DecodeTransaction(const TransactionToken* token) override { - return fs_->DecodeTransaction((token ? token : token_)); - } - - WrappedFileSystem(FileSystem* file_system, TransactionToken* token) - : fs_(file_system), token_(token) {} - - ~WrappedFileSystem() override = default; - - private: - FileSystem* fs_; - TransactionToken* token_; -}; - -/// A file abstraction for randomly reading the contents of a file. -class RandomAccessFile { - public: - RandomAccessFile() {} - virtual ~RandomAccessFile() = default; - - /// \brief Returns the name of the file. - /// - /// This is an optional operation that may not be implemented by every - /// filesystem. - virtual absl::Status Name(absl::string_view* result) const { - return errors::Unimplemented("This filesystem does not support Name()"); - } - - /// \brief Reads up to `n` bytes from the file starting at `offset`. - /// - /// `scratch[0..n-1]` may be written by this routine. Sets `*result` - /// to the data that was read (including if fewer than `n` bytes were - /// successfully read). May set `*result` to point at data in - /// `scratch[0..n-1]`, so `scratch[0..n-1]` must be live when - /// `*result` is used. - /// - /// On OK returned status: `n` bytes have been stored in `*result`. - /// On non-OK returned status: `[0..n]` bytes have been stored in `*result`. - /// - /// Returns `OUT_OF_RANGE` if fewer than n bytes were stored in `*result` - /// because of EOF. - /// - /// Safe for concurrent use by multiple threads. - virtual absl::Status Read(uint64 offset, size_t n, absl::string_view* result, - char* scratch) const = 0; - -#if defined(TF_CORD_SUPPORT) - /// \brief Read up to `n` bytes from the file starting at `offset`. - virtual absl::Status Read(uint64 offset, size_t n, absl::Cord* cord) const { - return errors::Unimplemented( - "Read(uint64, size_t, absl::Cord*) is not " - "implemented"); - } -#endif - - private: - RandomAccessFile(const RandomAccessFile&) = delete; - void operator=(const RandomAccessFile&) = delete; -}; - -/// \brief A file abstraction for sequential writing. -/// -/// The implementation must provide buffering since callers may append -/// small fragments at a time to the file. -class WritableFile { - public: - WritableFile() {} - virtual ~WritableFile() = default; - - /// \brief Append 'data' to the file. - virtual absl::Status Append(absl::string_view data) = 0; - -#if defined(TF_CORD_SUPPORT) - // \brief Append 'data' to the file. - virtual absl::Status Append(const absl::Cord& cord) { - for (absl::string_view chunk : cord.Chunks()) { - TF_RETURN_IF_ERROR(Append(chunk)); - } - return absl::OkStatus(); - } -#endif - - /// \brief Close the file. - /// - /// Flush() and de-allocate resources associated with this file - /// - /// Typical return codes (not guaranteed to be exhaustive): - /// * OK - /// * Other codes, as returned from Flush() - virtual absl::Status Close() = 0; - - /// \brief Flushes the file and optionally syncs contents to filesystem. - /// - /// This should flush any local buffers whose contents have not been - /// delivered to the filesystem. - /// - /// If the process terminates after a successful flush, the contents - /// may still be persisted, since the underlying filesystem may - /// eventually flush the contents. If the OS or machine crashes - /// after a successful flush, the contents may or may not be - /// persisted, depending on the implementation. - virtual absl::Status Flush() = 0; - - // \brief Returns the name of the file. - /// - /// This is an optional operation that may not be implemented by every - /// filesystem. - virtual absl::Status Name(absl::string_view* result) const { - return errors::Unimplemented("This filesystem does not support Name()"); - } - - /// \brief Syncs contents of file to filesystem. - /// - /// This waits for confirmation from the filesystem that the contents - /// of the file have been persisted to the filesystem; if the OS - /// or machine crashes after a successful Sync, the contents should - /// be properly saved. - virtual absl::Status Sync() = 0; - - /// \brief Retrieves the current write position in the file, or -1 on - /// error. - /// - /// This is an optional operation, subclasses may choose to return - /// errors::Unimplemented. - virtual absl::Status Tell(int64_t* position) { - *position = -1; - return errors::Unimplemented("This filesystem does not support Tell()"); - } - - private: - WritableFile(const WritableFile&) = delete; - void operator=(const WritableFile&) = delete; -}; - -/// \brief A readonly memmapped file abstraction. -/// -/// The implementation must guarantee that all memory is accessible when the -/// object exists, independently from the Env that created it. -class ReadOnlyMemoryRegion { - public: - ReadOnlyMemoryRegion() {} - virtual ~ReadOnlyMemoryRegion() = default; - - /// \brief Returns a pointer to the memory region. - virtual const void* data() = 0; - - /// \brief Returns the length of the memory region in bytes. - virtual uint64 length() = 0; -}; - -/// \brief A registry for file system implementations. -/// -/// Filenames are specified as an URI, which is of the form -/// [scheme://]. -/// File system implementations are registered using the REGISTER_FILE_SYSTEM -/// macro, providing the 'scheme' as the key. -/// -/// There are two `Register` methods: one using `Factory` for legacy filesystems -/// (deprecated mechanism of subclassing `FileSystem` and using -/// `REGISTER_FILE_SYSTEM` macro), and one using `std::unique_ptr` -/// for the new modular approach. -/// -/// Note that the new API expects a pointer to `ModularFileSystem` but this is -/// not checked as there should be exactly one caller to the API and doing the -/// check results in a circular dependency between `BUILD` targets. -/// -/// Plan is to completely remove the filesystem registration from `Env` and -/// incorporate it into `ModularFileSystem` class (which will be renamed to be -/// the only `FileSystem` class and marked as `final`). But this will happen at -/// a later time, after we convert all filesystems to the new API. -/// -/// TODO(b/139060984): After all filesystems are converted, remove old -/// registration and update comment. -class FileSystemRegistry { - public: - typedef std::function Factory; - - virtual ~FileSystemRegistry() = default; - virtual absl::Status Register(const std::string& scheme, Factory factory) = 0; - virtual absl::Status Register(const std::string& scheme, - std::unique_ptr filesystem) = 0; - virtual FileSystem* Lookup(const std::string& scheme) = 0; - virtual absl::Status GetRegisteredFileSystemSchemes( - std::vector* schemes) = 0; -}; - -/// \brief An abstraction for enforcing ACL checks in FileSystem. -class FileAcl { - public: - virtual absl::Status CheckAccess(std::string_view path) = 0; - virtual ~FileAcl() = default; -}; - -} // namespace tsl +#include "xla/tsl/platform/file_system.h" #endif // TENSORFLOW_TSL_PLATFORM_FILE_SYSTEM_H_ diff --git a/tsl/platform/file_system_helper.cc b/tsl/platform/file_system_helper.cc deleted file mode 100644 index bfbea9808..000000000 --- a/tsl/platform/file_system_helper.cc +++ /dev/null @@ -1,280 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tsl/platform/file_system_helper.h" - -#include -#include -#include - -#include "tsl/platform/cpu_info.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/file_system.h" -#include "tsl/platform/mutex.h" -#include "tsl/platform/path.h" -#include "tsl/platform/platform.h" -#include "tsl/platform/status.h" -#include "tsl/platform/str_util.h" -#include "tsl/platform/threadpool.h" - -namespace tsl { -namespace internal { - -namespace { - -const int kNumThreads = port::NumSchedulableCPUs(); - -// Run a function in parallel using a ThreadPool, but skip the ThreadPool -// on the iOS platform due to its problems with more than a few threads. -void ForEach(int first, int last, const std::function& f) { -#if TARGET_OS_IPHONE - for (int i = first; i < last; i++) { - f(i); - } -#else - int num_threads = std::min(kNumThreads, last - first); - thread::ThreadPool threads(Env::Default(), "ForEach", num_threads); - for (int i = first; i < last; i++) { - threads.Schedule([f, i] { f(i); }); - } -#endif -} - -// A globbing pattern can only start with these characters: -static const char kGlobbingChars[] = "*?[\\"; - -static inline bool IsGlobbingPattern(const std::string& pattern) { - return (pattern.find_first_of(kGlobbingChars) != std::string::npos); -} - -// Make sure that the first entry in `dirs` during glob expansion does not -// contain a glob pattern. This is to prevent a corner-case bug where -// `` would be treated differently than `./`. -static std::string PatchPattern(const std::string& pattern) { - const std::string fixed_prefix = - pattern.substr(0, pattern.find_first_of(kGlobbingChars)); - - // Patching is needed when there is no directory part in `prefix` - if (io::Dirname(fixed_prefix).empty()) { - return io::JoinPath(".", pattern); - } - - // No patching needed - return pattern; -} - -static std::vector AllDirectoryPrefixes(const std::string& d) { - std::vector dirs; - const std::string patched = PatchPattern(d); - absl::string_view dir(patched); - - // If the pattern ends with a `/` (or `\\` on Windows), we need to strip it - // otherwise we would have one additional matching step and the result set - // would be empty. - bool is_directory = d[d.size() - 1] == '/'; -#ifdef PLATFORM_WINDOWS - is_directory = is_directory || (d[d.size() - 1] == '\\'); -#endif - if (is_directory) { - dir = io::Dirname(dir); - } - - while (!dir.empty()) { - dirs.emplace_back(dir); - absl::string_view new_dir(io::Dirname(dir)); - // io::Dirname("/") returns "/" so we need to break the loop. - // On Windows, io::Dirname("C:\\") would return "C:\\", so we check for - // identity of the result instead of checking for dir[0] == `/`. - if (dir == new_dir) break; - dir = new_dir; - } - - // Order the array from parent to ancestor (reverse order). - std::reverse(dirs.begin(), dirs.end()); - - return dirs; -} - -static inline int GetFirstGlobbingEntry(const std::vector& dirs) { - int i = 0; - for (const auto& d : dirs) { - if (IsGlobbingPattern(d)) { - break; - } - i++; - } - return i; -} - -} // namespace - -absl::Status GetMatchingPaths(FileSystem* fs, Env* env, const string& pattern, - std::vector* results) { - // Check that `fs`, `env` and `results` are non-null. - if (fs == nullptr || env == nullptr || results == nullptr) { - return absl::Status( - absl::StatusCode::kInvalidArgument, - "Filesystem calls GetMatchingPaths with nullptr arguments"); - } - - // By design, we don't match anything on empty pattern - results->clear(); - if (pattern.empty()) { - return absl::OkStatus(); - } - - // The pattern can contain globbing characters at multiple levels, e.g.: - // - // foo/ba?/baz/f*r - // - // To match the full pattern, we must match every prefix subpattern and then - // operate on the children for each match. Thus, we separate all subpatterns - // in the `dirs` vector below. - std::vector dirs = AllDirectoryPrefixes(pattern); - - // We can have patterns that have several parents where no globbing is being - // done, for example, `foo/bar/baz/*`. We don't need to expand the directories - // which don't contain the globbing characters. - int matching_index = GetFirstGlobbingEntry(dirs); - - // If we don't have globbing characters in the pattern then it specifies a - // path in the filesystem. We add it to the result set if it exists. - if (matching_index == dirs.size()) { - if (fs->FileExists(pattern).ok()) { - results->emplace_back(pattern); - } - return absl::OkStatus(); - } - - // To expand the globbing, we do a BFS from `dirs[matching_index-1]`. - // At every step, we work on a pair `{dir, ix}` such that `dir` is a real - // directory, `ix < dirs.size() - 1` and `dirs[ix+1]` is a globbing pattern. - // To expand the pattern, we select from all the children of `dir` only those - // that match against `dirs[ix+1]`. - // If there are more entries in `dirs` after `dirs[ix+1]` this mean we have - // more patterns to match. So, we add to the queue only those children that - // are also directories, paired with `ix+1`. - // If there are no more entries in `dirs`, we return all children as part of - // the answer. - // Since we can get into a combinatorial explosion issue (e.g., pattern - // `/*/*/*`), we process the queue in parallel. Each parallel processing takes - // elements from `expand_queue` and adds them to `next_expand_queue`, after - // which we swap these two queues (similar to double buffering algorithms). - // PRECONDITION: `IsGlobbingPattern(dirs[0]) == false` - // PRECONDITION: `matching_index > 0` - // INVARIANT: If `{d, ix}` is in queue, then `d` and `dirs[ix]` are at the - // same level in the filesystem tree. - // INVARIANT: If `{d, _}` is in queue, then `IsGlobbingPattern(d) == false`. - // INVARIANT: If `{d, _}` is in queue, then `d` is a real directory. - // INVARIANT: If `{_, ix}` is in queue, then `ix < dirs.size() - 1`. - // INVARIANT: If `{_, ix}` is in queue, `IsGlobbingPattern(dirs[ix + 1])`. - std::deque> expand_queue; - std::deque> next_expand_queue; - expand_queue.emplace_back(dirs[matching_index - 1], matching_index - 1); - - // Adding to `result` or `new_expand_queue` need to be protected by mutexes - // since there are multiple threads writing to these. - mutex result_mutex; - mutex queue_mutex; - - while (!expand_queue.empty()) { - next_expand_queue.clear(); - - // The work item for every item in `expand_queue`. - // pattern, we process them in parallel. - auto handle_level = [&fs, &results, &dirs, &expand_queue, - &next_expand_queue, &result_mutex, - &queue_mutex](int i) { - // See invariants above, all of these are valid accesses. - const auto& queue_item = expand_queue.at(i); - const std::string& parent = queue_item.first; - const int index = queue_item.second + 1; - const std::string& match_pattern = dirs[index]; - - // Get all children of `parent`. If this fails, return early. - std::vector children; - absl::Status s = fs->GetChildren(parent, &children); - if (s.code() == absl::StatusCode::kPermissionDenied) { - return; - } - - // Also return early if we don't have any children - if (children.empty()) { - return; - } - - // Since we can get extremely many children here and on some filesystems - // `IsDirectory` is expensive, we process the children in parallel. - // We also check that children match the pattern in parallel, for speedup. - // We store the status of the match and `IsDirectory` in - // `children_status` array, one element for each children. - std::vector children_status(children.size()); - auto handle_children = [&fs, &match_pattern, &parent, &children, - &children_status](int j) { - const std::string path = io::JoinPath(parent, children[j]); - if (!fs->Match(path, match_pattern)) { - children_status[j] = absl::Status(absl::StatusCode::kCancelled, - "Operation not needed"); - } else { - children_status[j] = fs->IsDirectory(path); - } - }; - ForEach(0, children.size(), handle_children); - - // At this point, pairing `children` with `children_status` will tell us - // if a children: - // * does not match the pattern - // * matches the pattern and is a directory - // * matches the pattern and is not a directory - // We fully ignore the first case. - // If we matched the last pattern (`index == dirs.size() - 1`) then all - // remaining children get added to the result. - // Otherwise, only the directories get added to the next queue. - for (size_t j = 0; j < children.size(); j++) { - if (children_status[j].code() == absl::StatusCode::kCancelled) { - continue; - } - - const std::string path = io::JoinPath(parent, children[j]); - if (index == dirs.size() - 1) { - mutex_lock l(result_mutex); - results->emplace_back(path); - } else if (children_status[j].ok()) { - mutex_lock l(queue_mutex); - next_expand_queue.emplace_back(path, index); - } - } - }; - ForEach(0, expand_queue.size(), handle_level); - - // After evaluating one level, swap the "buffers" - std::swap(expand_queue, next_expand_queue); - } - - return absl::OkStatus(); -} - -absl::StatusOr FileExists(Env* env, const string& fname) { - absl::Status status = env->FileExists(fname); - if (errors::IsNotFound(status)) { - return false; - } - TF_RETURN_IF_ERROR(status); - return true; -} - -} // namespace internal -} // namespace tsl diff --git a/tsl/platform/file_system_helper.h b/tsl/platform/file_system_helper.h index e9e7df6aa..49a0bd1c2 100644 --- a/tsl/platform/file_system_helper.h +++ b/tsl/platform/file_system_helper.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,49 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_FILE_SYSTEM_HELPER_H_ #define TENSORFLOW_TSL_PLATFORM_FILE_SYSTEM_HELPER_H_ -#include -#include - -#include "tsl/platform/env.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" - -namespace tsl { - -class FileSystem; -class Env; - -namespace internal { - -// Given a pattern, stores in 'results' the set of paths (in the given file -// system) that match that pattern. -// -// This helper may be used by implementations of FileSystem::GetMatchingPaths() -// in order to provide parallel scanning of subdirectories (except on iOS). -// -// Arguments: -// fs: may not be null and will be used to identify directories and list -// their contents. -// env: may not be null and will be used to check if a match has been found. -// pattern: see FileSystem::GetMatchingPaths() for details. -// results: will be cleared and may not be null. -// -// Returns an error status if any call to 'fs' failed. -absl::Status GetMatchingPaths(FileSystem* fs, Env* env, const string& pattern, - std::vector* results); - -// Given a file path, determines whether the file exists. This helper simplifies -// the use of Env::FileExists. -// -// Arguments: -// env: may not be null. -// fname: the file path to look up -// -// Returns true if the file exists, false if it does not exist, or an error -// Status. -absl::StatusOr FileExists(Env* env, const string& fname); - -} // namespace internal -} // namespace tsl +#include "xla/tsl/platform/file_system_helper.h" #endif // TENSORFLOW_TSL_PLATFORM_FILE_SYSTEM_HELPER_H_ diff --git a/tsl/platform/logging.h b/tsl/platform/logging.h index 939398882..193cb9b51 100644 --- a/tsl/platform/logging.h +++ b/tsl/platform/logging.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,14 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_LOGGING_H_ #define TENSORFLOW_TSL_PLATFORM_LOGGING_H_ -#include "tsl/platform/platform.h" - -#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) || \ - defined(PLATFORM_GOOGLE_IOS) || defined(GOOGLE_LOGGING) || \ - defined(__EMSCRIPTEN__) || defined(PLATFORM_CHROMIUMOS) -#include "xla/tsl/platform/google/logging.h" // IWYU pragma: export -#else -#include "xla/tsl/platform/default/logging.h" // IWYU pragma: export -#endif +#include "xla/tsl/platform/logging.h" #endif // TENSORFLOW_TSL_PLATFORM_LOGGING_H_ diff --git a/tsl/platform/logging_test.cc b/tsl/platform/logging_test.cc deleted file mode 100644 index 070696f19..000000000 --- a/tsl/platform/logging_test.cc +++ /dev/null @@ -1,352 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tsl/platform/logging.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/log_severity.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "tsl/platform/path.h" -#include "tsl/platform/stacktrace_handler.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" - -// Make sure popen and pclose are available on Windows. -#ifdef PLATFORM_WINDOWS -#define popen _popen -#define pclose _pclose -#endif - -static char* program_name; - -namespace tsl { -namespace { - -using ::testing::HasSubstr; -using ::testing::Not; - -TEST(Logging, Log) { - LOG(INFO) << "Hello"; - LOG(INFO) << "Another log message"; - LOG(ERROR) << "Error message"; - VLOG(1) << "A VLOG message"; - VLOG(2) << "A higher VLOG message"; - DVLOG(1) << "A DVLOG message"; - DVLOG(2) << "A higher DVLOG message"; -} - -TEST(Logging, CheckChecks) { - CHECK(true); - CHECK(7 > 5); - string a("abc"); - string b("xyz"); - CHECK_EQ(a, a); - CHECK_NE(a, b); - CHECK_EQ(3, 3); - CHECK_NE(4, 3); - CHECK_GT(4, 3); - CHECK_GE(3, 3); - CHECK_LT(2, 3); - CHECK_LE(2, 3); - - DCHECK(true); - DCHECK(7 > 5); - DCHECK_EQ(a, a); - DCHECK_NE(a, b); - DCHECK_EQ(3, 3); - DCHECK_NE(4, 3); - DCHECK_GT(4, 3); - DCHECK_GE(3, 3); - DCHECK_LT(2, 3); - DCHECK_LE(2, 3); -} - -TEST(LoggingDeathTest, FailedChecks) { - string a("abc"); - string b("xyz"); - const char* p_const = "hello there"; - const char* p_null_const = nullptr; - char mybuf[10]; - char* p_non_const = mybuf; - char* p_null = nullptr; - CHECK_NOTNULL(p_const); - CHECK_NOTNULL(p_non_const); - - ASSERT_DEATH(CHECK(false), "false"); - ASSERT_DEATH(CHECK(9 < 7), "9 < 7"); - ASSERT_DEATH(CHECK_EQ(a, b), "a == b"); - ASSERT_DEATH(CHECK_EQ(3, 4), "3 == 4"); - ASSERT_DEATH(CHECK_NE(3, 3), "3 != 3"); - ASSERT_DEATH(CHECK_GT(2, 3), "2 > 3"); - ASSERT_DEATH(CHECK_GE(2, 3), "2 >= 3"); - ASSERT_DEATH(CHECK_LT(3, 2), "3 < 2"); - ASSERT_DEATH(CHECK_LE(3, 2), "3 <= 2"); - ASSERT_DEATH(CHECK(false), "false"); - ASSERT_DEATH(printf("%s", CHECK_NOTNULL(p_null)), "Must be non NULL"); - ASSERT_DEATH(printf("%s", CHECK_NOTNULL(p_null_const)), "Must be non NULL"); -#ifndef NDEBUG - ASSERT_DEATH(DCHECK(9 < 7), "9 < 7"); - ASSERT_DEATH(DCHECK(9 < 7), "9 < 7"); - ASSERT_DEATH(DCHECK_EQ(a, b), "a == b"); - ASSERT_DEATH(DCHECK_EQ(3, 4), "3 == 4"); - ASSERT_DEATH(DCHECK_NE(3, 3), "3 != 3"); - ASSERT_DEATH(DCHECK_GT(2, 3), "2 > 3"); - ASSERT_DEATH(DCHECK_GE(2, 3), "2 >= 3"); - ASSERT_DEATH(DCHECK_LT(3, 2), "3 < 2"); - ASSERT_DEATH(DCHECK_LE(3, 2), "3 <= 2"); -#endif -} - -TEST(InternalLogString, Basic) { - // Just make sure that this code compiles (we don't actually verify - // the output) - internal::LogString(__FILE__, __LINE__, absl::LogSeverity::kInfo, - "Hello there"); -} - -class TestSink : public TFLogSink { - public: - void Send(const TFLogEntry& entry) override { - ss_ << entry.text_message() << std::endl; - } - - std::string Get() const { return ss_.str(); } - - private: - std::stringstream ss_; -}; - -TEST(LogSinkTest, testLogSinks) { - const int sinks_initial_size = TFGetLogSinks().size(); - TestSink sink; - - TFAddLogSink(&sink); - - EXPECT_EQ(TFGetLogSinks().size(), sinks_initial_size + 1); - - LOG(INFO) << "Foo"; - LOG(INFO) << "Bar"; - EXPECT_EQ(sink.Get(), "Foo\nBar\n"); - - TFRemoveLogSink(&sink); - - EXPECT_EQ(TFGetLogSinks().size(), sinks_initial_size); -} - -std::string ReadFromFilePointer(FILE* fp) { - std::string result; - while (!feof(fp)) { - char buf[512]; - size_t len = fread(buf, sizeof(buf[0]), 512, fp); - result.append(buf, len); - } - return result; -} - -absl::StatusOr ReadFromFile(const std::string& filename) { - std::shared_ptr fp(fopen(filename.c_str(), "r"), fclose); - if (fp == nullptr) { - return absl::ErrnoToStatus(errno, - absl::StrFormat("Cannot fopen '%s'", filename)); - } - return ReadFromFilePointer(fp.get()); -} - -class SubcommandTest : public ::testing::Test { - public: - static constexpr absl::string_view kLogVLog = "log_and_vlog"; - - static bool IsSubcommand(absl::string_view subcommand) { - return subcommand == kLogVLog; - } - - static int Run(absl::string_view subcommand) { - CHECK_EQ(subcommand, kLogVLog); - LOG(INFO) << "LOG INFO"; - LOG(WARNING) << "LOG WARNING"; - LOG(ERROR) << "LOG ERROR"; - LOG(INFO) << absl::StrFormat("VLOG_IS_ON(1)? %d", VLOG_IS_ON(1)); - LOG(INFO) << absl::StrFormat("VLOG_IS_ON(2)? %d", VLOG_IS_ON(2)); - LOG(INFO) << absl::StrFormat("VLOG_IS_ON(3)? %d", VLOG_IS_ON(3)); - VLOG(1) << "VLevel 1"; - VLOG(2) << "VLevel 2"; - VLOG(3) << "VLevel 3"; - return EXIT_SUCCESS; - } - - protected: - absl::StatusOr CaptureOutput(const char* invocation) { - std::shared_ptr fp(popen(invocation, "r"), pclose); - if (fp == nullptr) { - return absl::ErrnoToStatus( - errno, absl::StrFormat("Cannot popen '%s'", invocation)); - } - return ReadFromFilePointer(fp.get()); - } -}; - -// By default, messages with severity >= INFO should be printed. -TEST_F(SubcommandTest, LogDefaultTest) { - std::string command = absl::StrFormat("%s %s", program_name, kLogVLog); -#if defined(PLATFORM_GOOGLE) - command += " --alsologtostderr"; -#endif - command += " 2>&1"; - TF_ASSERT_OK_AND_ASSIGN(std::string out, CaptureOutput(command.c_str())); - EXPECT_THAT(out, HasSubstr("LOG INFO")); - EXPECT_THAT(out, HasSubstr("LOG WARNING")); - EXPECT_THAT(out, HasSubstr("LOG ERROR")); - EXPECT_THAT(out, HasSubstr("VLOG_IS_ON(1)? 0")); - EXPECT_THAT(out, HasSubstr("VLOG_IS_ON(2)? 0")); - EXPECT_THAT(out, HasSubstr("VLOG_IS_ON(3)? 0")); -} - -TEST_F(SubcommandTest, MinLogLevelTest) { - std::string command = absl::StrFormat("%s %s", program_name, kLogVLog); -#if defined(PLATFORM_GOOGLE) - command += " --minloglevel=1 --alsologtostderr"; -#elif defined(PLATFORM_WINDOWS) - command = absl::StrFormat("set TF_CPP_MIN_LOG_LEVEL=1 && %s", command); -#else - command = absl::StrFormat("TF_CPP_MIN_LOG_LEVEL=1 %s", command); -#endif - command += " 2>&1"; - TF_ASSERT_OK_AND_ASSIGN(std::string out, CaptureOutput(command.c_str())); - EXPECT_THAT(out, Not(HasSubstr("LOG INFO"))); - EXPECT_THAT(out, HasSubstr("LOG WARNING")); - EXPECT_THAT(out, HasSubstr("LOG ERROR")); -} - -// By default, no VLOG messages should be printed. -TEST_F(SubcommandTest, VLogDefaultTest) { - std::string command = absl::StrFormat("%s %s", program_name, kLogVLog); -#if defined(PLATFORM_GOOGLE) - command += " --alsologtostderr"; -#endif - command += " 2>&1"; - TF_ASSERT_OK_AND_ASSIGN(std::string out, CaptureOutput(command.c_str())); - EXPECT_THAT(out, Not(HasSubstr("VLevel 1"))); - EXPECT_THAT(out, Not(HasSubstr("VLevel 2"))); - EXPECT_THAT(out, Not(HasSubstr("VLevel 3"))); -} - -TEST_F(SubcommandTest, MaxVLogLevelTest) { - std::string command = absl::StrFormat("%s %s", program_name, kLogVLog); -#if defined(PLATFORM_GOOGLE) - command += " --v=2 --alsologtostderr"; -#elif defined(PLATFORM_WINDOWS) - command = absl::StrFormat("set TF_CPP_MAX_VLOG_LEVEL=2 && %s", command); -#else - command = absl::StrFormat("TF_CPP_MAX_VLOG_LEVEL=2 %s", command); -#endif - command += " 2>&1"; - TF_ASSERT_OK_AND_ASSIGN(std::string out, CaptureOutput(command.c_str())); - EXPECT_THAT(out, HasSubstr("VLevel 1")); - EXPECT_THAT(out, HasSubstr("VLevel 2")); - EXPECT_THAT(out, Not(HasSubstr("VLevel 3"))); - EXPECT_THAT(out, HasSubstr("VLOG_IS_ON(1)? 1")); - EXPECT_THAT(out, HasSubstr("VLOG_IS_ON(2)? 1")); - EXPECT_THAT(out, HasSubstr("VLOG_IS_ON(3)? 0")); -} - -TEST_F(SubcommandTest, VModuleTest) { - std::string command = absl::StrFormat("%s %s", program_name, kLogVLog); -#if defined(PLATFORM_GOOGLE) - command += " --vmodule=logging_test=2,shoobadooba=3 --alsologtostderr"; -#elif defined(PLATFORM_WINDOWS) - command = absl::StrFormat( - "set TF_CPP_VMODULE=logging_test=2,shoobadooba=3 && %s", command); -#else - command = absl::StrFormat("TF_CPP_VMODULE=logging_test=2,shoobadooba=3 %s", - command); -#endif - command += " 2>&1"; - TF_ASSERT_OK_AND_ASSIGN(std::string out, CaptureOutput(command.c_str())); - EXPECT_THAT(out, HasSubstr("VLevel 1")); - EXPECT_THAT(out, HasSubstr("VLevel 2")); - EXPECT_THAT(out, Not(HasSubstr("VLevel 3"))); - EXPECT_THAT(out, HasSubstr("VLOG_IS_ON(1)? 1")); - EXPECT_THAT(out, HasSubstr("VLOG_IS_ON(2)? 1")); - EXPECT_THAT(out, HasSubstr("VLOG_IS_ON(3)? 0")); -} - -TEST_F(SubcommandTest, VLogFilenameTest) { -#if defined(PLATFORM_GOOGLE) - constexpr bool kVLogFilenameEnvVarIsSupported = false; -#else - constexpr bool kVLogFilenameEnvVarIsSupported = true; -#endif - if (!kVLogFilenameEnvVarIsSupported) { - GTEST_SKIP() << "Not supported on this platform"; - } - - std::string command = absl::StrFormat("%s %s", program_name, kLogVLog); - std::string filename = io::GetTempFilename("logging_test"); -#if defined(PLATFORM_WINDOWS) - command = absl::StrFormat( - "set TF_CPP_VLOG_FILENAME=%s && set TF_CPP_MAX_VLOG_LEVEL=1 && %s", - filename, command); -#else - command = absl::StrFormat( - "TF_CPP_VLOG_FILENAME=%s TF_CPP_MAX_VLOG_LEVEL=1 %s", filename, command); -#endif - command += " 2>&1"; - - // All output should be in the file, not in stderr. - TF_ASSERT_OK_AND_ASSIGN(std::string out, CaptureOutput(command.c_str())); - EXPECT_THAT(out, Not(HasSubstr("LOG INFO"))); - EXPECT_THAT(out, Not(HasSubstr("LOG WARNING"))); - EXPECT_THAT(out, Not(HasSubstr("LOG ERROR"))); - EXPECT_THAT(out, Not(HasSubstr("VLOG_IS_ON(1)?"))); - EXPECT_THAT(out, Not(HasSubstr("VLOG_IS_ON(2)?"))); - EXPECT_THAT(out, Not(HasSubstr("VLOG_IS_ON(3)?"))); - EXPECT_THAT(out, Not(HasSubstr("VLevel 1"))); - EXPECT_THAT(out, Not(HasSubstr("VLevel 2"))); - EXPECT_THAT(out, Not(HasSubstr("VLevel 3"))); - - TF_ASSERT_OK_AND_ASSIGN(std::string log_file, ReadFromFile(filename)); - EXPECT_THAT(log_file, HasSubstr("LOG INFO")); - EXPECT_THAT(log_file, HasSubstr("LOG WARNING")); - EXPECT_THAT(log_file, HasSubstr("LOG ERROR")); - EXPECT_THAT(log_file, HasSubstr("VLOG_IS_ON(1)")); - EXPECT_THAT(log_file, HasSubstr("VLOG_IS_ON(2)")); - EXPECT_THAT(log_file, HasSubstr("VLOG_IS_ON(3)")); - EXPECT_THAT(log_file, HasSubstr("VLevel 1")); - EXPECT_THAT(log_file, Not(HasSubstr("VLevel 2"))); - EXPECT_THAT(log_file, Not(HasSubstr("VLevel 3"))); -} - -} // namespace -} // namespace tsl - -GTEST_API_ int main(int argc, char** argv) { - tsl::testing::InstallStacktraceHandler(); - testing::InitGoogleTest(&argc, argv); - program_name = argv[0]; - if (argc >= 2 && tsl::SubcommandTest::IsSubcommand(argv[1])) { - return tsl::SubcommandTest::Run(argv[1]); - } - return RUN_ALL_TESTS(); -} diff --git a/tsl/platform/macros.h b/tsl/platform/macros.h index cb91c4ff6..960d7ed2e 100644 --- a/tsl/platform/macros.h +++ b/tsl/platform/macros.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,147 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_MACROS_H_ #define TENSORFLOW_TSL_PLATFORM_MACROS_H_ -// Compiler attributes -#if (defined(__GNUC__) || defined(__APPLE__)) && !defined(SWIG) -// Compiler supports GCC-style attributes -#define TF_ATTRIBUTE_NORETURN __attribute__((noreturn)) -#define TF_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline)) -#define TF_ATTRIBUTE_NOINLINE __attribute__((noinline)) -#define TF_ATTRIBUTE_UNUSED __attribute__((unused)) -#define TF_ATTRIBUTE_COLD __attribute__((cold)) -#define TF_ATTRIBUTE_WEAK __attribute__((weak)) -#define TF_PACKED __attribute__((packed)) -#define TF_MUST_USE_RESULT __attribute__((warn_unused_result)) -#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) \ - __attribute__((__format__(__printf__, string_index, first_to_check))) -#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) \ - __attribute__((__format__(__scanf__, string_index, first_to_check))) -#elif defined(_MSC_VER) -// Non-GCC equivalents -#define TF_ATTRIBUTE_NORETURN __declspec(noreturn) -#define TF_ATTRIBUTE_ALWAYS_INLINE __forceinline -#define TF_ATTRIBUTE_NOINLINE -#define TF_ATTRIBUTE_UNUSED -#define TF_ATTRIBUTE_COLD -#define TF_ATTRIBUTE_WEAK -#define TF_MUST_USE_RESULT -#define TF_PACKED -#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) -#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) -#else -// Non-GCC equivalents -#define TF_ATTRIBUTE_NORETURN -#define TF_ATTRIBUTE_ALWAYS_INLINE -#define TF_ATTRIBUTE_NOINLINE -#define TF_ATTRIBUTE_UNUSED -#define TF_ATTRIBUTE_COLD -#define TF_ATTRIBUTE_WEAK -#define TF_MUST_USE_RESULT -#define TF_PACKED -#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) -#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) -#endif - -// Control visibility outside .so -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_EXPORT __declspec(dllexport) -#else -#define TF_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 - -#ifdef __has_builtin -#define TF_HAS_BUILTIN(x) __has_builtin(x) -#else -#define TF_HAS_BUILTIN(x) 0 -#endif - -// C++11-style attributes (N2761) -#if defined(__has_cpp_attribute) -// Safely checks if an attribute is supported. Equivalent to -// ABSL_HAVE_CPP_ATTRIBUTE. -#define TF_HAS_CPP_ATTRIBUTE(n) __has_cpp_attribute(n) -#else -#define TF_HAS_CPP_ATTRIBUTE(n) 0 -#endif - -// [[clang::annotate("x")]] allows attaching custom strings (e.g. "x") to -// declarations (variables, functions, fields, etc.) for use by tools. They are -// represented in the Clang AST (as AnnotateAttr nodes) and in LLVM IR, but not -// in final output. -#if TF_HAS_CPP_ATTRIBUTE(clang::annotate) -#define TF_ATTRIBUTE_ANNOTATE(str) [[clang::annotate(str)]] -#else -#define TF_ATTRIBUTE_ANNOTATE(str) -#endif - -// A variable declaration annotated with the `TF_CONST_INIT` attribute will -// not compile (on supported platforms) unless the variable has a constant -// initializer. -#if TF_HAS_CPP_ATTRIBUTE(clang::require_constant_initialization) -#define TF_CONST_INIT [[clang::require_constant_initialization]] -#else -#define TF_CONST_INIT -#endif - -// Compilers can be told that a certain branch is not likely to be taken -// (for instance, a CHECK failure), and use that information in static -// analysis. Giving it this information can help it optimize for the -// common case in the absence of better information (ie. -// -fprofile-arcs). -#if TF_HAS_BUILTIN(__builtin_expect) || (defined(__GNUC__) && __GNUC__ >= 3) -#define TF_PREDICT_FALSE(x) (__builtin_expect(x, 0)) -#define TF_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) -#else -#define TF_PREDICT_FALSE(x) (x) -#define TF_PREDICT_TRUE(x) (x) -#endif - -// DEPRECATED: directly use the macro implementation instead. -// A macro to disallow the copy constructor and operator= functions -// This is usually placed in the private: declarations for a class. -#define TF_DISALLOW_COPY_AND_ASSIGN(TypeName) \ - TypeName(const TypeName&) = delete; \ - void operator=(const TypeName&) = delete - -// The TF_ARRAYSIZE(arr) macro returns the # of elements in an array arr. -// -// The expression TF_ARRAYSIZE(a) is a compile-time constant of type -// size_t. -#define TF_ARRAYSIZE(a) \ - ((sizeof(a) / sizeof(*(a))) / \ - static_cast(!(sizeof(a) % sizeof(*(a))))) - -#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L || \ - (defined(_MSC_VER) && _MSC_VER >= 1900) -// Define this to 1 if the code is compiled in C++11 mode; leave it -// undefined otherwise. Do NOT define it to 0 -- that causes -// '#ifdef LANG_CXX11' to behave differently from '#if LANG_CXX11'. -#define LANG_CXX11 1 -#endif - -#if defined(__clang__) && defined(LANG_CXX11) && defined(__has_warning) -#if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough") -#define TF_FALLTHROUGH_INTENDED [[clang::fallthrough]] // NOLINT -#endif -#endif - -#ifndef TF_FALLTHROUGH_INTENDED -#define TF_FALLTHROUGH_INTENDED \ - do { \ - } while (0) -#endif - -namespace tsl { -namespace internal { -template -void remove_unused_variable_compiler_warning(const T&){}; -} // namespace internal -} // namespace tsl -#define TF_UNUSED_VARIABLE(x) \ - tensorflow::internal::remove_unused_variable_compiler_warning(x) +#include "xla/tsl/platform/macros.h" #endif // TENSORFLOW_TSL_PLATFORM_MACROS_H_ diff --git a/tsl/platform/ram_file_system.h b/tsl/platform/ram_file_system.h index 861b06666..64d04a9a6 100644 --- a/tsl/platform/ram_file_system.h +++ b/tsl/platform/ram_file_system.h @@ -29,8 +29,8 @@ limitations under the License. #include #include "absl/strings/match.h" -#include "tsl/platform/env.h" -#include "tsl/platform/file_system.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/file_system.h" #include "tsl/platform/mutex.h" #include "tsl/platform/stringpiece.h" #include "tsl/platform/types.h" diff --git a/tsl/platform/status.cc b/tsl/platform/status.cc deleted file mode 100644 index f6d4aed1d..000000000 --- a/tsl/platform/status.cc +++ /dev/null @@ -1,361 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tsl/platform/status.h" - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/call_once.h" -#include "absl/functional/function_ref.h" -#include "absl/status/status.h" -#include "absl/strings/cord.h" -#include "absl/strings/escaping.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_replace.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "xla/tsl/protobuf/error_codes.pb.h" -#include "tsl/platform/mutex.h" -#include "tsl/platform/stack_frame.h" -#include "tsl/platform/stacktrace.h" -#include "tsl/platform/str_util.h" -#include "tsl/platform/strcat.h" -#include "tsl/platform/stringprintf.h" - -namespace tsl { - -namespace { - -// Log sink is used to collect recent warning and error log messages to be -// attached to the error status. -class StatusLogSink : public TFLogSink { - public: - static StatusLogSink* GetInstance() { - static StatusLogSink* sink = new StatusLogSink(); - return sink; - } - - void enable() { - absl::call_once(flag_, [this] { - num_messages_ = 5; // default to 5 messages - - if (const char* num_msgs_str = - getenv("TF_WORKER_NUM_FORWARDED_LOG_MESSAGES")) { - if (!absl::SimpleAtoi(num_msgs_str, &num_messages_)) { - LOG(WARNING) << "Failed to parse env variable " - "TF_WORKER_NUM_WARNING_ERROR_LOG_IN_STATUS=" - << num_msgs_str << " as int. Using the default value " - << num_messages_ << "."; - } - } - - if (num_messages_ > 0) { - TFAddLogSink(this); - } - }); - } - - void GetMessages(std::vector* logs) TF_LOCKS_EXCLUDED(mu_) { - mutex_lock lock(mu_); - - for (auto& msg : messages_) { - logs->push_back(msg); - } - } - - void Send(const TFLogEntry& entry) override TF_LOCKS_EXCLUDED(mu_) { - if (entry.log_severity() < absl::LogSeverity::kWarning) return; - - mutex_lock lock(mu_); - messages_.emplace_back(entry.ToString()); - if (messages_.size() > static_cast(num_messages_)) { - messages_.pop_front(); - } - } - - private: - mutex mu_; - // for allowing repeated/concurrent calls to enable() - absl::once_flag flag_; - int num_messages_ = 0; - std::deque messages_ TF_GUARDED_BY(mu_); -}; - -} // namespace - -// TODO(b/197552541) Move this namespace to errors.h after absl migration. -namespace errors { -static constexpr const char kStackTraceProtoUrl[] = - "type.googleapis.com/tensorflow.StackTracePayload"; - -void SetStackTrace(absl::Status& status, std::vector stack_trace) { - // Given the StackFrame fields are (a) line number (b) filename (c) function - // name, we can safely assume that there is no `\n` in there. - // Thus, we can serialize as strings using a simple new line delimiter. - // - // This has the benefit that we don't need to depend on protobuf. Note that - // we do this only the serialization of the StackFrame is an implementation - // detail and that we don't not need persistent storage or wire serialization. - std::vector items; - items.reserve(stack_trace.size()); - for (StackFrame& frame : stack_trace) { - // We are extra safe and remove any new line in the filename and function - // name. - items.push_back( - absl::StrCat(absl::StrReplaceAll(frame.file_name, {{"\n", ""}}), "\n", - frame.line_number, "\n", - absl::StrReplaceAll(frame.function_name, {{"\n", ""}}))); - } - status.SetPayload(kStackTraceProtoUrl, - absl::Cord(absl::StrJoin(items, "\n"))); -} - -std::vector GetStackTrace(const absl::Status& status) { - std::vector stack_trace; - absl::optional maybe_serialized_payload = - status.GetPayload(kStackTraceProtoUrl); - if (maybe_serialized_payload.has_value()) { - std::vector split = - absl::StrSplit(maybe_serialized_payload.value().Flatten(), '\n'); - assert(split.size() % 3 == 0); - for (int i = 0; i < split.size() / 3; ++i) { - const int idx = 3 * i; - int line_number = -1; - CHECK(absl::SimpleAtoi(split[idx + 1], &line_number)); // Crash OK - stack_trace.emplace_back(std::move(split[idx]), line_number, - std::move(split[idx + 2])); - } - } - return stack_trace; -} - -} // namespace errors - -// NB: This Windows-only implementation is exists only to avoid a linker error. -// Remove if this is resolved. -#ifdef _WIN32 -const char* NullTerminatedMessage(const absl::Status& status) { - return absl::StatusMessageAsCStr(status); -} -#endif - -std::string* TfCheckOpHelperOutOfLine(const absl::Status& v, const char* msg) { - std::stringstream ss; - ss << "Non-OK-status: " << msg << "\nStatus: " << v; - - // Leaks string but this is only to be used in a fatal error message - return new std::string(ss.str()); -} - -StatusGroup::StatusGroup() {} - -StatusGroup::StatusGroup(std::initializer_list statuses) { - for (const absl::Status& s : statuses) { - Update(s); - } -} - -static constexpr const char kDerivedStatusProtoUrl[] = - "type.googleapis.com/tensorflow.DerivedStatus"; - -absl::Status StatusGroup::MakeDerived(const absl::Status& s) { - if (IsDerived(s)) { - return s; - } else { - absl::Status derived(s); - // TODO(b/200167936): Serialize an instance of DerivedStatus proto instead - // of using the string directly. The string is never used so it is not - // causing any issues at the moment. - derived.SetPayload(kDerivedStatusProtoUrl, absl::Cord("")); - return derived; - } -} - -bool StatusGroup::IsDerived(const absl::Status& s) { - return s.GetPayload(kDerivedStatusProtoUrl).has_value(); -} - -void StatusGroup::ConfigureLogHistory() { - StatusLogSink::GetInstance()->enable(); -} - -void StatusGroup::Update(const absl::Status& s) { - if (s.ok()) { - ++num_ok_; - } else { - ok_ = false; - if (IsDerived(s)) { - derived_.insert(s); - } else { - non_derived_.insert(s); - } - } -} - -static constexpr int kMaxAggregatedStatusMessageSize = 8 * 1024; -static constexpr int kMaxAttachedLogMessageSize = 512; - -std::unordered_map StatusGroup::GetPayloads() const { - std::unordered_map payloads; - auto capture_payload = [&payloads](absl::string_view key, - const absl::Cord& value) { - payloads[std::string(key)] = value; - }; - for (const auto& status : derived_) { - status.ForEachPayload(capture_payload); - } - - // If a key appears in both derived_ and non_derived_ payloads, then the - // non_derived_ payload receives priority. - for (const auto& status : non_derived_) { - status.ForEachPayload(capture_payload); - } - - payloads.erase(kDerivedStatusProtoUrl); - - return payloads; -} - -absl::Status MakeStatus( - absl::StatusCode code, absl::string_view message, - const std::unordered_map& payloads) { - absl::Status status(code, message); - for (const auto& payload : payloads) { - status.SetPayload(payload.first, payload.second); - } - return status; -} - -std::string MakeString(const absl::Status& status) { - return absl::StrCat(absl::StatusCodeToString(status.code()), ": ", - status.message()); -} - -// Summarize all the status objects in the StatusGroup. This is used when -// individual Status objects in the StatusGroup are not already summarized. -absl::Status StatusGroup::as_summary_status() const { - if (ok_) { - return absl::OkStatus(); - } - - // Gather recent logs as a string - auto get_recent_logs = [this]() -> std::string { - if (!recent_logs_.empty()) { - std::vector fmt; - fmt.push_back("\nRecent warning and error logs:"); - for (auto& log : recent_logs_) { - // Add an indentation to make it look nicer. - fmt.push_back(" " + log.substr(0, kMaxAttachedLogMessageSize)); - } - return absl::StrJoin(fmt, "\n"); - } else { - return ""; - } - }; - - // If only one root status is found, do not add summary header and footer. - if (non_derived_.size() == 1) { - return MakeStatus( - non_derived_.begin()->code(), - strings::StrCat(non_derived_.begin()->message(), get_recent_logs()), - GetPayloads()); - } - - if (!non_derived_.empty()) { - std::vector fmt; - - fmt.push_back( - strings::Printf("%zu root error(s) found.", non_derived_.size())); - - int index = 0; - auto code = absl::StatusCode::kCancelled; - for (const auto& s : non_derived_) { - // NOTE: Avoid using CANCELLED as the code of summary status if the group - // contains other error code. - if (code == absl::StatusCode::kCancelled && - s.code() != absl::StatusCode::kCancelled) { - code = s.code(); - } - fmt.emplace_back(strings::StrCat(" (", index, ") ", MakeString(s))); - ++index; - } - - fmt.push_back(strings::Printf("%zu successful operations.", num_ok_)); - fmt.push_back( - strings::Printf("%zu derived errors ignored.", derived_.size())); - - std::string error_msg = - absl::StrJoin(fmt, "\n").substr(0, kMaxAggregatedStatusMessageSize); - - return MakeStatus(code, strings::StrCat(error_msg, get_recent_logs()), - GetPayloads()); - } else { - // All statuses are derived. Pick the first available status to return. - return MakeDerived(MakeStatus(derived_.begin()->code(), - derived_.begin()->message(), GetPayloads())); - } -} - -// Concatenate all the status objects in the StatusGroup. This is used when -// individual Status objects in the StatusGroup are already summarized Status. -absl::Status StatusGroup::as_concatenated_status() const { - if (ok_) { - return absl::OkStatus(); - } - - // If only one root status is found, return it directly. - if (non_derived_.size() == 1) { - return MakeStatus(non_derived_.begin()->code(), - non_derived_.begin()->message(), GetPayloads()); - } - - if (!non_derived_.empty()) { - std::vector fmt; - fmt.emplace_back("\n====================="); - for (const auto& s : non_derived_) { - fmt.emplace_back(MakeString(s)); - } - fmt.emplace_back("=====================\n"); - return MakeStatus( - non_derived_.begin()->code(), - absl::StrJoin(fmt, "\n").substr(0, kMaxAggregatedStatusMessageSize), - GetPayloads()); - } else { - // All statuses are derived. Pick the first available status to return. - // This should not happen in normal execution. - return MakeDerived(MakeStatus(derived_.begin()->code(), - derived_.begin()->message(), GetPayloads())); - } -} - -void StatusGroup::AttachLogMessages() { - recent_logs_.clear(); - StatusLogSink::GetInstance()->GetMessages(&recent_logs_); -} - -} // namespace tsl diff --git a/tsl/platform/status.h b/tsl/platform/status.h index 61238a13f..fdd9343ac 100644 --- a/tsl/platform/status.h +++ b/tsl/platform/status.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,211 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_STATUS_H_ #define TENSORFLOW_TSL_PLATFORM_STATUS_H_ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/base/macros.h" -#include "absl/functional/function_ref.h" -#include "absl/status/status.h" -#include "absl/strings/cord.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "xla/tsl/protobuf/error_codes.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/platform.h" -#include "tsl/platform/stack_frame.h" -#include "tsl/platform/types.h" - -// Include appropriate platform-dependent parts of status. -#if defined(PLATFORM_GOOGLE) -#include "xla/tsl/platform/google/status.h" // IWYU pragma: export -#else -#include "xla/tsl/platform/default/status.h" // IWYU pragma: export -#endif - -// TODO: b/323943471 - This macro should eventually be provided by Abseil. -#ifndef ABSL_DEPRECATE_AND_INLINE -#define ABSL_DEPRECATE_AND_INLINE() -#endif - -namespace tsl { - -// Since April 2023, tensorflow::Status is an alias to absl::Status. The first -// TF release including this change will be TF 2.14 (the latest release in -// April 2023 is 2.13). -// At the same time `tsl::errors::Code` aliases `absl::StatusCode`. -// -// Here is a set of correspondences: -// - Use `absl::OkStatus()` instead of `tsl::OkStatus()`. -typedef absl::Status Status ABSL_DEPRECATE_AND_INLINE(); - -namespace errors { -typedef absl::StatusCode Code ABSL_DEPRECATE_AND_INLINE(); -} // namespace errors -namespace error { -typedef ::tensorflow::error::Code Code; -} // namespace error -} // namespace tsl - -// Transparent comparison between tensorflow::error::Code protobuf enum and -// absl::Status. -// -// The longer term objective is to delete these when we have done the transition -// to absl::Status. -namespace tensorflow::error { -inline bool operator==(const ::tensorflow::error::Code& c1, - const absl::StatusCode& c2) { - return static_cast(c1) == static_cast(c2); -} - -inline bool operator!=(const ::tensorflow::error::Code& c1, - const absl::StatusCode& c2) { - return static_cast(c1) != static_cast(c2); -} -} // namespace tensorflow::error - -namespace absl { -inline bool operator==(const ::absl::StatusCode& c1, - const ::tensorflow::error::Code& c2) { - return static_cast(c1) == static_cast(c2); -} - -inline bool operator!=(const ::absl::StatusCode& c1, - const ::tensorflow::error::Code& c2) { - return static_cast(c1) != static_cast(c2); -} -} // namespace absl - -namespace tsl { - -// OkStatus() -// -// Returns an OK status, equivalent to a default constructed instance. Prefer -// usage of `OkStatus()` when constructing such an OK status. -ABSL_DEPRECATE_AND_INLINE() inline absl::Status OkStatus() { - return absl::OkStatus(); -}; - -ABSL_DEPRECATE_AND_INLINE() -inline absl::Status FromAbslStatus(const absl::Status& s) { return s; } -ABSL_DEPRECATE_AND_INLINE() -inline absl::Status ToAbslStatus(const ::absl::Status& s) { return s; } - -// Given `Status.message()` does not guarantee to be always backed by a -// null-terminated string, we have this utility function when it's needed for -// the Tensorflow C-API. -// A more robust API would be to get both a `char*` of the beginning of the -// string, plus the size (see e.g. `XlaCustomCallStatusSetFailure`). -// NB: This Windows-only implementation is exists only to avoid a linker error. -// Remove if this is resolved. -#ifdef _WIN32 -const char* NullTerminatedMessage(const absl::Status& status); -#else -ABSL_DEPRECATE_AND_INLINE() -inline const char* NullTerminatedMessage(const absl::Status& status) { - return absl::StatusMessageAsCStr(status); -} -#endif - -// TODO(b/197552541) Move this namespace to errors.h. -namespace errors { - -void SetStackTrace(absl::Status& status, std::vector stack_trace); - -std::vector GetStackTrace(const absl::Status& status); -} // namespace errors - -// Helper class to manage multiple child status values. -class StatusGroup { - public: - StatusGroup(); - // Constructor to form a StatusGroup from any N set of Status arguments. - // Usage: StatusGroup({status_a, status_b, status_c}); - StatusGroup(std::initializer_list statuses); - - // Utility function to mark a Status as derived. By marking derived status, - // Derived status messages are ignored when reporting errors to end users. - static absl::Status MakeDerived(const absl::Status& s); - static bool IsDerived(const absl::Status& s); - - // Enable warning and error log collection for appending to the aggregated - // status. This function may be called more than once. - static void ConfigureLogHistory(); - - // Returns merged payloads of all statuses. In case multiple statuses have the - // same payload key, non-derived statuses have priority over derived ones, - // otherwise one payload value will be chosen in an unspecified but - // deterministic order. - // NOTE: The payload marking derived statuses as derived will not be returned. - std::unordered_map GetPayloads() const; - - // Return a merged status with combined child status messages with a summary. - absl::Status as_summary_status() const; - // Return a merged status with combined child status messages with - // concatenation. - absl::Status as_concatenated_status() const; - - bool ok() const { return ok_; } - - // Augment this group with the child status `status`. - void Update(const absl::Status& status); - - // Attach recent warning and error log messages - void AttachLogMessages(); - bool HasLogMessages() const { return !recent_logs_.empty(); } - - private: - bool ok_ = true; - size_t num_ok_ = 0; - - // Maintain a sorted collection of statuses. - struct CompareStatus { - bool operator()(const absl::Status& a, const absl::Status& b) const { - return a.ToString() > b.ToString(); - } - }; - // Using std::set instead of absl::btree_set to keep size for certain - // dependent libraries under the limit. - std::set derived_; - std::set non_derived_; - - std::vector recent_logs_; // recent warning and error logs -}; - -typedef std::function StatusCallback; - -extern ::tsl::string* TfCheckOpHelperOutOfLine(const absl::Status& v, - const char* msg); - -inline ::tsl::string* TfCheckOpHelper(absl::Status v, const char* msg) { - if (v.ok()) return nullptr; - return TfCheckOpHelperOutOfLine(v, msg); -} - -#define TF_DO_CHECK_OK(val, level) \ - while (auto* _result = ::tsl::TfCheckOpHelper(val, #val)) \ - LOG(level) << *(_result) - -#define TF_CHECK_OK(val) TF_DO_CHECK_OK(val, FATAL) -#define TF_QCHECK_OK(val) TF_DO_CHECK_OK(val, QFATAL) - -// DEBUG only version of TF_CHECK_OK. Compiler still parses 'val' even in opt -// mode. -#ifndef NDEBUG -#define TF_DCHECK_OK(val) TF_CHECK_OK(val) -#else -#define TF_DCHECK_OK(val) \ - while (false && (::tsl::OkStatus() == (val))) LOG(FATAL) -#endif - -} // namespace tsl +#include "xla/tsl/platform/status.h" #endif // TENSORFLOW_TSL_PLATFORM_STATUS_H_ diff --git a/tsl/platform/status_matchers.cc b/tsl/platform/status_matchers.cc deleted file mode 100644 index bcb04018d..000000000 --- a/tsl/platform/status_matchers.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tsl/platform/status_matchers.h" - -#include -#include - -#include "xla/tsl/protobuf/error_codes.pb.h" -#include "tsl/platform/status.h" -#include "tsl/platform/test.h" - -namespace tsl { -namespace testing { -namespace internal_status { - -void StatusIsMatcherCommonImpl::DescribeTo(std::ostream* os) const { - *os << "has a status code that "; - code_matcher_.DescribeTo(os); - *os << ", and has an error message that "; - message_matcher_.DescribeTo(os); -} - -void StatusIsMatcherCommonImpl::DescribeNegationTo(std::ostream* os) const { - *os << "has a status code that "; - code_matcher_.DescribeNegationTo(os); - *os << ", or has an error message that "; - message_matcher_.DescribeNegationTo(os); -} - -bool StatusIsMatcherCommonImpl::MatchAndExplain( - const absl::Status& status, - ::testing::MatchResultListener* result_listener) const { - ::testing::StringMatchResultListener inner_listener; - - inner_listener.Clear(); - if (!code_matcher_.MatchAndExplain( - static_cast(status.code()), &inner_listener)) { - *result_listener << (inner_listener.str().empty() - ? "whose status code is wrong" - : "which has a status code " + - inner_listener.str()); - return false; - } - - if (!message_matcher_.Matches(std::string(status.message()))) { - *result_listener << "whose error message is wrong"; - return false; - } - - return true; -} - -} // namespace internal_status -} // namespace testing -} // namespace tsl diff --git a/tsl/platform/status_matchers.h b/tsl/platform/status_matchers.h index e7e12c269..e9a559860 100644 --- a/tsl/platform/status_matchers.h +++ b/tsl/platform/status_matchers.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,332 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #ifndef TENSORFLOW_TSL_PLATFORM_STATUS_MATCHERS_H_ #define TENSORFLOW_TSL_PLATFORM_STATUS_MATCHERS_H_ -#include -#include -#include - -#include "xla/tsl/protobuf/error_codes.pb.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" - -// Defines the following utilities: -// -// =============== -// IsOkAndHolds(m) -// =============== -// -// This matcher matches a StatusOr value whose status is OK and whose inner -// value matches matcher m. Example: -// -// using ::tsl::testing::IsOkAndHolds; -// using ::testing::HasSubstr; -// ... -// StatusOr status_or_message("Hello, world"); -// EXPECT_THAT(status_or_message, IsOkAndHolds("Hello, world"))); -// EXPECT_THAT(status_or_message, IsOkAndHolds(HasSubstr("Hello,"))); -// -// =============================== -// StatusIs(status_code_matcher, -// error_message_matcher) -// =============================== -// -// This matcher matches a Status or StatusOr if the following are true: -// -// - the status's code() matches status_code_matcher, and -// - the status's error_message() matches error_message_matcher. -// -// Example: -// -// using ::tsl::testing::StatusIs; -// using ::testing::HasSubstr; -// using ::testing::MatchesRegex; -// using ::testing::Ne; -// using ::testing::_; -// StatusOr GetMessage(int id); -// ... -// -// // The status code must be CANCELLED; the error message can be anything. -// EXPECT_THAT(GetName(42), -// StatusIs(tsl::error::CANCELLED, _)); -// -// // The status code can be anything; the error message must match the regex. -// EXPECT_THAT(GetName(43), -// StatusIs(_, MatchesRegex("server.*time-out"))); -// -// // The status code should not be CANCELLED; the error message can be -// // anything with "Cancelled" in it. -// EXPECT_THAT(GetName(44), -// StatusIs(Ne(tsl::error::CANCELLED), -// HasSubstr("Cancelled")))); -// -// ============================= -// StatusIs(status_code_matcher) -// ============================= -// -// This is a shorthand for -// StatusIs(status_code_matcher, ::testing::_) -// -// In other words, it's like the two-argument StatusIs(), except that it ignores -// error messages. -// -// ====== -// IsOk() -// ====== -// -// Matches a Status or StatusOr whose status value is OK. -// Equivalent to 'StatusIs(error::OK)'. -// -// Example: -// ... -// StatusOr message("Hello, world"); -// EXPECT_THAT(message, IsOk()); -// Status status = OkStatus(); -// EXPECT_THAT(status, IsOk()); - -namespace tsl { - -inline void PrintTo(const tsl::error::Code code, std::ostream* os) { - *os << Code_Name(code); -} - -template -void PrintTo(const StatusOr& status_or, std::ostream* os) { - *os << ::testing::PrintToString(status_or.status()); - if (status_or.ok()) { - *os << ": " << ::testing::PrintToString(status_or.value()); - } -} - -namespace testing { -namespace internal_status { - -inline const absl::Status& GetStatus(const absl::Status& status) { - return status; -} - -template -inline const absl::Status& GetStatus(const StatusOr& status) { - return status.status(); -} - -//////////////////////////////////////////////////////////// -// Implementation of IsOkAndHolds(). -// -// Monomorphic implementation of matcher IsOkAndHolds(m). StatusOrType is a -// reference to StatusOr. -template -class IsOkAndHoldsMatcherImpl - : public ::testing::MatcherInterface { - public: - typedef - typename std::remove_reference::type::value_type value_type; - - template - explicit IsOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher) - : inner_matcher_(::testing::SafeMatcherCast( - std::forward(inner_matcher))) {} - - void DescribeTo(std::ostream* os) const override { - *os << "is OK and has a value that "; - inner_matcher_.DescribeTo(os); - } - - void DescribeNegationTo(std::ostream* os) const override { - *os << "isn't OK or has a value that "; - inner_matcher_.DescribeNegationTo(os); - } - - bool MatchAndExplain( - StatusOrType actual_value, - ::testing::MatchResultListener* result_listener) const override { - if (!actual_value.ok()) { - *result_listener << "which has status " << actual_value.status(); - return false; - } - - ::testing::StringMatchResultListener inner_listener; - const bool matches = - inner_matcher_.MatchAndExplain(*actual_value, &inner_listener); - const std::string inner_explanation = inner_listener.str(); - if (!inner_explanation.empty()) { - *result_listener << "which contains value " - << ::testing::PrintToString(*actual_value) << ", " - << inner_explanation; - } - return matches; - } - - private: - const ::testing::Matcher inner_matcher_; -}; - -// Implements IsOkAndHolds(m) as a polymorphic matcher. -template -class IsOkAndHoldsMatcher { - public: - explicit IsOkAndHoldsMatcher(InnerMatcher inner_matcher) - : inner_matcher_(std::move(inner_matcher)) {} - - // Converts this polymorphic matcher to a monomorphic matcher of the given - // type. StatusOrType can be either StatusOr or a reference to StatusOr. - template - operator ::testing::Matcher() const { // NOLINT - return ::testing::Matcher( - new IsOkAndHoldsMatcherImpl(inner_matcher_)); - } - - private: - const InnerMatcher inner_matcher_; -}; - -//////////////////////////////////////////////////////////// -// Implementation of StatusIs(). -// -// StatusIs() is a polymorphic matcher. This class is the common -// implementation of it shared by all types T where StatusIs() can be used as -// a Matcher. - -class StatusIsMatcherCommonImpl { - public: - StatusIsMatcherCommonImpl( - ::testing::Matcher code_matcher, - ::testing::Matcher message_matcher) - : code_matcher_(std::move(code_matcher)), - message_matcher_(std::move(message_matcher)) {} - - void DescribeTo(std::ostream* os) const; - - void DescribeNegationTo(std::ostream* os) const; - - bool MatchAndExplain(const absl::Status& status, - ::testing::MatchResultListener* result_listener) const; - - private: - const ::testing::Matcher code_matcher_; - const ::testing::Matcher message_matcher_; -}; - -// Monomorphic implementation of matcher StatusIs() for a given type T. T can -// be Status, StatusOr<>, or a reference to either of them. -template -class MonoStatusIsMatcherImpl : public ::testing::MatcherInterface { - public: - explicit MonoStatusIsMatcherImpl(StatusIsMatcherCommonImpl common_impl) - : common_impl_(std::move(common_impl)) {} - - void DescribeTo(std::ostream* os) const override { - common_impl_.DescribeTo(os); - } - - void DescribeNegationTo(std::ostream* os) const override { - common_impl_.DescribeNegationTo(os); - } - - bool MatchAndExplain( - T actual_value, - ::testing::MatchResultListener* result_listener) const override { - return common_impl_.MatchAndExplain(GetStatus(actual_value), - result_listener); - } - - private: - StatusIsMatcherCommonImpl common_impl_; -}; - -// Implements StatusIs() as a polymorphic matcher. -class StatusIsMatcher { - public: - StatusIsMatcher(::testing::Matcher code_matcher, - ::testing::Matcher message_matcher) - : common_impl_( - ::testing::MatcherCast(code_matcher), - ::testing::MatcherCast(message_matcher)) {} - - // Converts this polymorphic matcher to a monomorphic matcher of the given - // type. T can be StatusOr<>, Status, or a reference to either of them. - template - operator ::testing::Matcher() const { // NOLINT - return ::testing::MakeMatcher(new MonoStatusIsMatcherImpl(common_impl_)); - } - - private: - const StatusIsMatcherCommonImpl common_impl_; -}; - -// Monomorphic implementation of matcher IsOk() for a given type T. -// T can be Status, StatusOr<>, or a reference to either of them. -template -class MonoIsOkMatcherImpl : public ::testing::MatcherInterface { - public: - void DescribeTo(std::ostream* os) const override { *os << "is OK"; } - void DescribeNegationTo(std::ostream* os) const override { - *os << "is not OK"; - } - bool MatchAndExplain(T actual_value, - ::testing::MatchResultListener*) const override { - return GetStatus(actual_value).ok(); - } -}; - -// Implements IsOk() as a polymorphic matcher. -class IsOkMatcher { - public: - template - operator ::testing::Matcher() const { // NOLINT - return ::testing::Matcher(new MonoIsOkMatcherImpl()); - } -}; -} // namespace internal_status - -// Returns a matcher that matches a StatusOr<> whose status is OK and whose -// value matches the inner matcher. -template -internal_status::IsOkAndHoldsMatcher::type> -IsOkAndHolds(InnerMatcher&& inner_matcher) { - return internal_status::IsOkAndHoldsMatcher< - typename std::decay::type>( - std::forward(inner_matcher)); -} - -// Returns a matcher that matches a Status or StatusOr<> whose status code -// matches code_matcher, and whose error message matches message_matcher. -template -internal_status::StatusIsMatcher StatusIs(CodeMatcher code_matcher, - MessageMatcher message_matcher) { - return internal_status::StatusIsMatcher(std::move(code_matcher), - std::move(message_matcher)); -} -// Remove this specialization when tensorflow::Status is absl::Status -template -internal_status::StatusIsMatcher StatusIs(tensorflow::error::Code code_matcher, - MessageMatcher message_matcher) { - return internal_status::StatusIsMatcher( - static_cast(code_matcher), std::move(message_matcher)); -} - -// Returns a matcher that matches a Status or StatusOr<> whose status code -// matches code_matcher. -template -internal_status::StatusIsMatcher StatusIs(CodeMatcher code_matcher) { - return StatusIs(std::move(code_matcher), ::testing::_); -} -// Remove this specialization when tensorflow::Status is absl::Status -template <> -inline internal_status::StatusIsMatcher StatusIs( - tensorflow::error::Code code_matcher) { - return StatusIs(static_cast(code_matcher), ::testing::_); -} - -// Returns a matcher that matches a Status or StatusOr<> which is OK. -inline internal_status::IsOkMatcher IsOk() { - return internal_status::IsOkMatcher(); -} - -} // namespace testing -} // namespace tsl +#include "xla/tsl/platform/status_matchers.h" #endif // TENSORFLOW_TSL_PLATFORM_STATUS_MATCHERS_H_ diff --git a/tsl/platform/status_matchers_test.cc b/tsl/platform/status_matchers_test.cc deleted file mode 100644 index 3a681f6f3..000000000 --- a/tsl/platform/status_matchers_test.cc +++ /dev/null @@ -1,269 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tsl/platform/status_matchers.h" - -#include -#include -#include - -#include "xla/tsl/protobuf/error_codes.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" - -namespace tsl { -namespace testing { -namespace { - -using ::testing::_; -using ::testing::ElementsAre; -using ::testing::HasSubstr; -using ::testing::Matcher; -using ::testing::MatchesRegex; -using ::testing::Ne; -using ::testing::Not; -using ::testing::PrintToString; - -// Matches a value less than the given upper bound. This matcher is chatty (it -// always explains the match result with some detail), and thus is useful for -// testing that an outer matcher correctly incorporates an inner matcher's -// explanation. -MATCHER_P(LessThan, upper, "") { - if (arg < upper) { - *result_listener << "which is " << (upper - arg) << " less than " << upper; - return true; - } - *result_listener << "which is " << (arg - upper) << " more than " << upper; - return false; -} - -// Returns the description of the given matcher. -template -std::string Describe(const Matcher& matcher) { - std::stringstream ss; - matcher.DescribeTo(&ss); - return ss.str(); -} - -// Returns the description of the negation of the given matcher. -template -std::string DescribeNegation(const Matcher& matcher) { - std::stringstream ss; - matcher.DescribeNegationTo(&ss); - return ss.str(); -} - -// Returns the explanation on the result of using the given matcher to -// match the given value. -template -std::string ExplainMatch(const Matcher& matcher, const V& value) { - ::testing::StringMatchResultListener listener; - matcher.MatchAndExplain(value, &listener); - return listener.str(); -} - -TEST(IsOkAndHoldsTest, MatchesValue) { - absl::StatusOr status_or_message("Hello, world"); - EXPECT_THAT(status_or_message, IsOkAndHolds("Hello, world")); - EXPECT_THAT(status_or_message, IsOkAndHolds(HasSubstr("Hello,"))); -} - -TEST(IsOkAndHoldsTest, MatchesContainer) { - absl::StatusOr> status_or_messages = - std::vector{"Hello, world", "Hello, tf"}; - EXPECT_THAT(status_or_messages, - IsOkAndHolds(ElementsAre("Hello, world", "Hello, tf"))); - EXPECT_THAT(status_or_messages, - IsOkAndHolds(ElementsAre(HasSubstr("world"), HasSubstr("tf")))); -} - -TEST(IsOkAndHoldsTest, DoesNotMatchStatus) { - absl::StatusOr status_or_message = - errors::InvalidArgument("Invalid argument"); - EXPECT_THAT(status_or_message, Not(IsOkAndHolds("Hello, world"))); -} - -TEST(IsOkAndHoldsTest, DoesNotMatchValue) { - absl::StatusOr status_or_message("Hello, tf"); - EXPECT_THAT(status_or_message, Not(IsOkAndHolds("Hello, world"))); -} - -TEST(IsOkAndHoldsTest, DoesNotMatchContainer) { - absl::StatusOr> status_or_container({1, 2, 3}); - EXPECT_THAT(status_or_container, Not(IsOkAndHolds(ElementsAre(4, 5, 6)))); -} - -TEST(IsOkAndHoldsTest, DescribeExpectedValue) { - Matcher> is_ok_and_has_substr = - IsOkAndHolds(HasSubstr("Hello")); - EXPECT_EQ(Describe(is_ok_and_has_substr), - "is OK and has a value that has substring \"Hello\""); - EXPECT_EQ(DescribeNegation(is_ok_and_has_substr), - "isn't OK or has a value that has no substring \"Hello\""); -} - -TEST(IsOkAndHoldsTest, ExplainNotMatchingStatus) { - Matcher> is_ok_and_less_than = - IsOkAndHolds(LessThan(100)); - absl::StatusOr status = errors::Unknown("Unknown"); - EXPECT_THAT(ExplainMatch(is_ok_and_less_than, status), - HasSubstr("which has status UNKNOWN: Unknown")); -} - -TEST(IsOkAndHoldsTest, ExplainNotMatchingValue) { - Matcher> is_ok_and_less_than = - IsOkAndHolds(LessThan(100)); - EXPECT_EQ(ExplainMatch(is_ok_and_less_than, 120), - "which contains value 120, which is 20 more than 100"); -} - -TEST(IsOkAndHoldsTest, ExplainNotMatchingContainer) { - Matcher>> is_ok_and_less_than = - IsOkAndHolds(ElementsAre(1, 2, 3)); - std::vector actual{4, 5, 6}; - EXPECT_THAT(ExplainMatch(is_ok_and_less_than, actual), - HasSubstr("which contains value " + PrintToString(actual))); -} - -TEST(StatusIsTest, MatchesOK) { - EXPECT_THAT(absl::OkStatus(), StatusIs(error::OK)); - absl::StatusOr message("Hello, world"); - EXPECT_THAT(message, StatusIs(error::OK)); -} - -TEST(StatusIsTest, DoesNotMatchOk) { - EXPECT_THAT(errors::DeadlineExceeded("Deadline exceeded"), - Not(StatusIs(error::OK))); - absl::StatusOr status = errors::NotFound("Not found"); - EXPECT_THAT(status, Not(StatusIs(error::OK))); -} - -TEST(StatusIsTest, MatchesStatus) { - absl::Status s = errors::Cancelled("Cancelled"); - EXPECT_THAT(s, StatusIs(error::CANCELLED)); - EXPECT_THAT(s, StatusIs(error::CANCELLED, "Cancelled")); - EXPECT_THAT(s, StatusIs(_, "Cancelled")); - EXPECT_THAT(s, StatusIs(error::CANCELLED, _)); - EXPECT_THAT(s, StatusIs(Ne(error::INVALID_ARGUMENT), _)); - EXPECT_THAT(s, StatusIs(error::CANCELLED, HasSubstr("Can"))); - EXPECT_THAT(s, StatusIs(error::CANCELLED, MatchesRegex("Can.*"))); -} - -TEST(StatusIsTest, StatusOrMatchesStatus) { - absl::StatusOr s = errors::InvalidArgument("Invalid Argument"); - EXPECT_THAT(s, StatusIs(error::INVALID_ARGUMENT)); - EXPECT_THAT(s, StatusIs(error::INVALID_ARGUMENT, "Invalid Argument")); - EXPECT_THAT(s, StatusIs(_, "Invalid Argument")); - EXPECT_THAT(s, StatusIs(error::INVALID_ARGUMENT, _)); - EXPECT_THAT(s, StatusIs(Ne(error::CANCELLED), _)); - EXPECT_THAT(s, StatusIs(error::INVALID_ARGUMENT, HasSubstr("Argument"))); - EXPECT_THAT(s, StatusIs(error::INVALID_ARGUMENT, MatchesRegex(".*Argument"))); -} - -TEST(StatusIsTest, DoesNotMatchStatus) { - absl::Status s = errors::Internal("Internal"); - EXPECT_THAT(s, Not(StatusIs(error::FAILED_PRECONDITION))); - EXPECT_THAT(s, Not(StatusIs(error::INTERNAL, "Failed Precondition"))); - EXPECT_THAT(s, Not(StatusIs(_, "Failed Precondition"))); - EXPECT_THAT(s, Not(StatusIs(error::FAILED_PRECONDITION, _))); -} - -TEST(StatusIsTest, StatusOrDoesNotMatchStatus) { - absl::StatusOr s = errors::FailedPrecondition("Failed Precondition"); - EXPECT_THAT(s, Not(StatusIs(error::INTERNAL))); - EXPECT_THAT(s, Not(StatusIs(error::FAILED_PRECONDITION, "Internal"))); - EXPECT_THAT(s, Not(StatusIs(_, "Internal"))); - EXPECT_THAT(s, Not(StatusIs(error::INTERNAL, _))); -} - -TEST(StatusIsTest, DescribeExpectedValue) { - Matcher status_is = - StatusIs(error::UNAVAILABLE, std::string("Unavailable")); - EXPECT_EQ(Describe(status_is), - "has a status code that is equal to UNAVAILABLE, " - "and has an error message that is equal to \"Unavailable\""); -} - -TEST(StatusIsTest, DescribeNegatedExpectedValue) { - Matcher> status_is = - StatusIs(error::ABORTED, std::string("Aborted")); - EXPECT_EQ(DescribeNegation(status_is), - "has a status code that isn't equal to ABORTED, " - "or has an error message that isn't equal to \"Aborted\""); -} - -TEST(StatusIsTest, ExplainNotMatchingErrorCode) { - Matcher status_is = StatusIs(error::NOT_FOUND, _); - const absl::Status status = errors::AlreadyExists("Already exists"); - EXPECT_EQ(ExplainMatch(status_is, status), "whose status code is wrong"); -} - -TEST(StatusIsTest, ExplainNotMatchingErrorMessage) { - Matcher status_is = StatusIs(error::NOT_FOUND, "Not found"); - const absl::Status status = errors::NotFound("Already exists"); - EXPECT_EQ(ExplainMatch(status_is, status), "whose error message is wrong"); -} - -TEST(StatusIsTest, ExplainStatusOrNotMatchingErrorCode) { - Matcher> status_is = StatusIs(error::ALREADY_EXISTS, _); - const absl::StatusOr status_or = errors::NotFound("Not found"); - EXPECT_EQ(ExplainMatch(status_is, status_or), "whose status code is wrong"); -} - -TEST(StatusIsTest, ExplainStatusOrNotMatchingErrorMessage) { - Matcher> status_is = - StatusIs(error::ALREADY_EXISTS, "Already exists"); - const absl::StatusOr status_or = errors::AlreadyExists("Not found"); - EXPECT_EQ(ExplainMatch(status_is, status_or), "whose error message is wrong"); -} - -TEST(StatusIsTest, ExplainStatusOrHasValue) { - Matcher> status_is = - StatusIs(error::RESOURCE_EXHAUSTED, "Resource exhausted"); - const absl::StatusOr value = -1; - EXPECT_EQ(ExplainMatch(status_is, value), "whose status code is wrong"); -} - -TEST(IsOkTest, MatchesOK) { - EXPECT_THAT(absl::OkStatus(), IsOk()); - absl::StatusOr message = std::string("Hello, world"); - EXPECT_THAT(message, IsOk()); -} - -TEST(IsOkTest, DoesNotMatchOK) { - EXPECT_THAT(errors::PermissionDenied("Permission denied"), Not(IsOk())); - absl::StatusOr status = - errors::Unauthenticated("Unauthenticated"); - EXPECT_THAT(status, Not(IsOk())); -} - -TEST(IsOkTest, DescribeExpectedValue) { - Matcher status_is_ok = IsOk(); - EXPECT_EQ(Describe(status_is_ok), "is OK"); - Matcher> status_or_is_ok = IsOk(); - EXPECT_EQ(Describe(status_or_is_ok), "is OK"); -} - -TEST(IsOkTest, DescribeNegatedExpectedValue) { - Matcher status_is_ok = IsOk(); - EXPECT_EQ(DescribeNegation(status_is_ok), "is not OK"); - Matcher> status_or_is_ok = IsOk(); - EXPECT_EQ(DescribeNegation(status_or_is_ok), "is not OK"); -} - -} // namespace -} // namespace testing -} // namespace tsl diff --git a/tsl/platform/status_test.cc b/tsl/platform/status_test.cc deleted file mode 100644 index e716a15b9..000000000 --- a/tsl/platform/status_test.cc +++ /dev/null @@ -1,229 +0,0 @@ - -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tsl/platform/status.h" - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/cord.h" -#include "absl/strings/str_format.h" -#include "xla/tsl/protobuf/error_codes.pb.h" -#include "xla/tsl/protobuf/status.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/stack_frame.h" -#include "tsl/platform/status_matchers.h" -#include "tsl/platform/status_to_from_proto.h" -#include "tsl/platform/test.h" - -namespace tsl { -namespace { - -using ::testing::ElementsAre; -using ::testing::IsEmpty; -using ::testing::Pair; -using ::tsl::testing::IsOk; -using ::tsl::testing::StatusIs; - -TEST(ToStringTest, PayloadsArePrinted) { - absl::Status status = errors::Aborted("Aborted Error Message"); - status.SetPayload("payload_key", absl::Cord(absl::StrFormat( - "payload_value %c%c%c", 1, 2, 3))); - - EXPECT_EQ(status.ToString(), - "ABORTED: Aborted Error Message [payload_key='payload_value " - "\\x01\\x02\\x03']"); -} - -TEST(ToStringTest, MatchesAbslStatus) { - absl::Status status = errors::Aborted("Aborted Error Message"); - status.SetPayload("payload_key", absl::Cord(absl::StrFormat( - "payload_value %c%c%c", 1, 2, 3))); - - absl::Status absl_status = - absl::Status(absl::StatusCode::kAborted, status.message()); - absl_status.SetPayload("payload_key", absl::Cord(absl::StrFormat( - "payload_value %c%c%c", 1, 2, 3))); - - EXPECT_EQ(status.ToString(), absl_status.ToString()); -} - -TEST(StackTrace, SerializeAndDeserializeCorrectly) { - absl::Status status = errors::Aborted("Aborted Error Message"); - std::vector stack_trace; - stack_trace.push_back(StackFrame("filename_1", 33, "func_name_1")); - stack_trace.push_back(StackFrame("filename_2", 66, "func_name_2")); - errors::SetStackTrace(status, stack_trace); - - std::vector deserialized = errors::GetStackTrace(status); - - EXPECT_EQ(stack_trace.size(), deserialized.size()); - for (size_t i = 0; i < stack_trace.size(); ++i) { - EXPECT_EQ(stack_trace[i], deserialized[i]); - } -} - -TEST(StatusGroupTest, DeterministicOrderWithoutPayloads) { - absl::Status status_a = errors::Aborted("Status A"); - absl::Status status_b = errors::Aborted("Status B"); - absl::Status status_c = errors::Aborted("Status C"); - - absl::Status combined = - StatusGroup({status_a, status_b, status_c}).as_summary_status(); - - EXPECT_EQ(combined, - StatusGroup({status_a, status_b, status_c}).as_summary_status()); - EXPECT_EQ(combined, - StatusGroup({status_a, status_c, status_b}).as_summary_status()); - EXPECT_EQ(combined, - StatusGroup({status_b, status_a, status_c}).as_summary_status()); - EXPECT_EQ(combined, - StatusGroup({status_b, status_c, status_a}).as_summary_status()); - EXPECT_EQ(combined, - StatusGroup({status_c, status_a, status_b}).as_summary_status()); - EXPECT_EQ(combined, - StatusGroup({status_c, status_b, status_a}).as_summary_status()); -} - -TEST(StatusGroupTest, DeterministicOrderWithPayloads) { - absl::Status status_a = errors::Aborted("Status A"); - status_a.SetPayload("payload_key", absl::Cord("payload_value_a")); - absl::Status status_b = errors::Aborted("Status B"); - status_b.SetPayload("payload_key", absl::Cord("payload_value_b")); - absl::Status status_c = errors::Aborted("Status C"); - status_c.SetPayload("payload_key", absl::Cord("payload_value_c")); - - absl::Status combined = - StatusGroup({status_a, status_b, status_c}).as_summary_status(); - ASSERT_TRUE(combined.GetPayload("payload_key").has_value()); - std::string payload(combined.GetPayload("payload_key").value()); - - EXPECT_EQ(payload, StatusGroup({status_a, status_b, status_c}) - .as_summary_status() - .GetPayload("payload_key")); - EXPECT_EQ(payload, StatusGroup({status_a, status_c, status_b}) - .as_summary_status() - .GetPayload("payload_key")); - EXPECT_EQ(payload, StatusGroup({status_b, status_a, status_c}) - .as_summary_status() - .GetPayload("payload_key")); - EXPECT_EQ(payload, StatusGroup({status_b, status_c, status_a}) - .as_summary_status() - .GetPayload("payload_key")); - EXPECT_EQ(payload, StatusGroup({status_c, status_a, status_b}) - .as_summary_status() - .GetPayload("payload_key")); - EXPECT_EQ(payload, StatusGroup({status_c, status_b, status_a}) - .as_summary_status() - .GetPayload("payload_key")); -} - -TEST(StatusGroupTest, PayloadsMergedProperly) { - absl::Status status_a = errors::Aborted("Status A"); - status_a.SetPayload("payload_key_a", - absl::Cord(std::string("payload_value_a"))); - absl::Status status_b = errors::Aborted("Status B"); - status_b.SetPayload("payload_key_b", - absl::Cord(std::string("payload_value_b"))); - absl::Status status_c = errors::Aborted("Status C"); - status_c.SetPayload("payload_key_c", - absl::Cord(std::string("payload_value_c"))); - absl::Status derived_status_c = - StatusGroup::MakeDerived(errors::Aborted("Status C")); - derived_status_c.SetPayload( - "payload_key_c", absl::Cord(std::string("derived_payload_value_c"))); - - StatusGroup status_group({status_a, status_b, status_c, derived_status_c}); - EXPECT_THAT(status_group.GetPayloads(), ::testing::SizeIs(3)); - - absl::Status combined = status_group.as_summary_status(); - EXPECT_EQ(combined.GetPayload("payload_key_a"), "payload_value_a"); - EXPECT_EQ(combined.GetPayload("payload_key_b"), "payload_value_b"); - EXPECT_EQ(combined.GetPayload("payload_key_c"), "payload_value_c"); -} - -TEST(Status, ErrorStatusForEachPayloadIteratesOverAll) { - absl::Status s(absl::StatusCode::kInternal, "Error message"); - s.SetPayload("key1", absl::Cord("value1")); - s.SetPayload("key2", absl::Cord("value2")); - s.SetPayload("key3", absl::Cord("value3")); - - std::unordered_map payloads; - s.ForEachPayload([&payloads](absl::string_view key, const absl::Cord& value) { - payloads[std::string(key)] = value; - }); - - EXPECT_EQ(payloads.size(), 3); - EXPECT_EQ(payloads["key1"], "value1"); - EXPECT_EQ(payloads["key2"], "value2"); - EXPECT_EQ(payloads["key3"], "value3"); -} - -TEST(Status, OkStatusForEachPayloadNoIteration) { - absl::Status s = absl::OkStatus(); - s.SetPayload("key1", absl::Cord("value1")); - s.SetPayload("key2", absl::Cord("value2")); - s.SetPayload("key3", absl::Cord("value3")); - - std::unordered_map payloads; - s.ForEachPayload([&payloads](absl::string_view key, const absl::Cord& value) { - payloads[std::string(key)] = value; - }); - - EXPECT_EQ(payloads.size(), 0); -} - -TEST(Status, SaveOKStatusToProto) { - tensorflow::StatusProto status_proto = StatusToProto(absl::OkStatus()); - EXPECT_EQ(status_proto.code(), error::OK); - EXPECT_THAT(status_proto.message(), IsEmpty()); -} - -TEST(Status, SaveErrorStatusToProto) { - tensorflow::StatusProto status_proto = StatusToProto(errors::Create( - absl::StatusCode::kNotFound, "Not found", {{"foo", "bar"}})); - EXPECT_EQ(status_proto.code(), error::NOT_FOUND); - EXPECT_EQ(status_proto.message(), "Not found"); - EXPECT_THAT(status_proto.payload(), ElementsAre(Pair("foo", "bar"))); -} - -TEST(Status, SaveEmptyStatusToProto) { - tensorflow::StatusProto status_proto = StatusToProto(absl::Status()); - EXPECT_EQ(status_proto.code(), error::OK); - EXPECT_THAT(status_proto.message(), IsEmpty()); - EXPECT_THAT(status_proto.payload(), IsEmpty()); -} - -TEST(Status, MakeOKStatusFromProto) { - tensorflow::StatusProto status_proto; - status_proto.set_code(error::OK); - EXPECT_THAT(StatusFromProto(status_proto), IsOk()); -} - -TEST(Status, MakeErrorStatusFromProto) { - tensorflow::StatusProto status_proto; - status_proto.set_code(error::INVALID_ARGUMENT); - status_proto.set_message("Invalid argument"); - status_proto.mutable_payload()->insert({"foo", "bar"}); - absl::Status s = StatusFromProto(status_proto); - EXPECT_THAT(s, StatusIs(error::INVALID_ARGUMENT, "Invalid argument")); - EXPECT_EQ(s.GetPayload("foo"), "bar"); -} - -TEST(Status, MakeStatusFromEmptyProto) { - EXPECT_THAT(StatusFromProto(tensorflow::StatusProto()), IsOk()); -} - -} // namespace -} // namespace tsl diff --git a/tsl/platform/status_to_from_proto.cc b/tsl/platform/status_to_from_proto.cc deleted file mode 100644 index 54e2b2ef3..000000000 --- a/tsl/platform/status_to_from_proto.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tsl/platform/status_to_from_proto.h" - -#include - -#include "absl/strings/cord.h" -#include "absl/strings/string_view.h" -#include "xla/tsl/protobuf/error_codes.pb.h" -#include "xla/tsl/protobuf/status.pb.h" -#include "tsl/platform/status.h" - -namespace tsl { - -tensorflow::StatusProto StatusToProto(const absl::Status& s) { - tensorflow::StatusProto status_proto; - if (s.ok()) { - return status_proto; - } - - status_proto.set_code(static_cast(s.code())); - if (!s.message().empty()) { - status_proto.set_message(std::string(s.message())); - } - - s.ForEachPayload( - [&status_proto](absl::string_view type_url, absl::Cord value) { - status_proto.mutable_payload()->insert( - {std::string(type_url), std::string(value)}); - }); - return status_proto; -} - -#if defined(PLATFORM_GOOGLE) -absl::Status StatusFromProto(const tensorflow::StatusProto& proto, - absl::SourceLocation loc) { - if (proto.code() == tensorflow::error::OK) { - return absl::OkStatus(); - } - absl::Status s(static_cast(proto.code()), proto.message(), - loc); - for (const auto& [key, payload] : proto.payload()) { - s.SetPayload(key, absl::Cord(payload)); - } - return s; -} -#else -Status StatusFromProto(const tensorflow::StatusProto& proto) { - if (proto.code() == tensorflow::error::OK) { - return OkStatus(); - } - Status s(static_cast(proto.code()), proto.message()); - for (const auto& [key, payload] : proto.payload()) { - s.SetPayload(key, absl::Cord(payload)); - } - return s; -} -#endif - -} // namespace tsl diff --git a/tsl/platform/status_to_from_proto.h b/tsl/platform/status_to_from_proto.h index 021e002ae..89b0de803 100644 --- a/tsl/platform/status_to_from_proto.h +++ b/tsl/platform/status_to_from_proto.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,32 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #ifndef TENSORFLOW_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ #define TENSORFLOW_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ -#include "xla/tsl/protobuf/status.pb.h" -#include "tsl/platform/status.h" - -namespace tsl { - -// TODO(b/250921378): Merge this file with `status.h` once we figure out how to -// fix the following error with the MacOS build: -// -// ImportError: -// dlopen(/org_tensorflow/tensorflow/python/platform/_pywrap_tf2.so, 2): -// Symbol not found: tensorflow11StatusProtoC1EPN6protobuf5ArenaEb - -// Converts a `Status` to a `StatusProto`. -tensorflow::StatusProto StatusToProto(const absl::Status& s); - -#if defined(PLATFORM_GOOGLE) -// Constructs a `Status` from a `StatusProto`. -absl::Status StatusFromProto( - const tensorflow::StatusProto& proto, - absl::SourceLocation loc = absl::SourceLocation::current()); -#else -Status StatusFromProto(const tensorflow::StatusProto& proto); -#endif -} // namespace tsl +#include "xla/tsl/platform/status_to_from_proto.h" #endif // TENSORFLOW_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ diff --git a/tsl/platform/statusor.h b/tsl/platform/statusor.h index ac27ede31..c4e6da372 100644 --- a/tsl/platform/statusor.h +++ b/tsl/platform/statusor.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,99 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// StatusOr is the union of a Status object and a T object. StatusOr models -// the concept of an object that is either a value, or an error Status -// explaining why such a value is not present. To this end, StatusOr does not -// allow its Status value to be Status::OK. -// -// The primary use-case for StatusOr is as the return value of a -// function which may fail. -// -// Example client usage for a StatusOr, where T is not a pointer: -// -// StatusOr result = DoBigCalculationThatCouldFail(); -// if (result.ok()) { -// float answer = result.value(); -// printf("Big calculation yielded: %f", answer); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example client usage for a StatusOr: -// -// StatusOr result = FooFactory::MakeNewFoo(arg); -// if (result.ok()) { -// std::unique_ptr foo(result.value()); -// foo->DoSomethingCool(); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example client usage for a StatusOr>: -// -// StatusOr> result = FooFactory::MakeNewFoo(arg); -// if (result.ok()) { -// std::unique_ptr foo = std::move(result.value()); -// foo->DoSomethingCool(); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example factory implementation returning StatusOr: -// -// StatusOr FooFactory::MakeNewFoo(int arg) { -// if (arg <= 0) { -// return tsl::InvalidArgument("Arg must be positive"); -// } else { -// return new Foo(arg); -// } -// } -// -// Note that the assignment operators require that destroying the currently -// stored value cannot invalidate the argument; in other words, the argument -// cannot be an alias for the current value, or anything owned by the current -// value. #ifndef TENSORFLOW_TSL_PLATFORM_STATUSOR_H_ #define TENSORFLOW_TSL_PLATFORM_STATUSOR_H_ -#include "absl/base/attributes.h" -#include "absl/base/macros.h" -#include "absl/status/statusor.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/platform.h" -#include "tsl/platform/status.h" - -// Include appropriate platform-dependent `TF_ASSIGN_OR_RETURN`. -#if defined(PLATFORM_GOOGLE) -#include "xla/tsl/platform/google/statusor.h" // IWYU pragma: export -#else -#include "xla/tsl/platform/default/statusor.h" // IWYU pragma: export -#endif - -// TODO: b/323943471 - This macro should eventually be provided by Abseil. -#ifndef ABSL_DEPRECATE_AND_INLINE -#define ABSL_DEPRECATE_AND_INLINE() -#endif - -namespace tsl { - -template -using StatusOr ABSL_DEPRECATE_AND_INLINE() = absl::StatusOr; - -} // namespace tsl - -#define TF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ - TF_ASSERT_OK_AND_ASSIGN_IMPL( \ - TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \ - rexpr); - -#define TF_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \ - auto statusor = (rexpr); \ - ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \ - lhs = std::move(statusor).value() - -#define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y) -#define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y +#include "xla/tsl/platform/statusor.h" #endif // TENSORFLOW_TSL_PLATFORM_STATUSOR_H_ diff --git a/tsl/platform/statusor_test.cc b/tsl/platform/statusor_test.cc deleted file mode 100644 index fd0ee7886..000000000 --- a/tsl/platform/statusor_test.cc +++ /dev/null @@ -1,742 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Unit tests for StatusOr - -#include "tsl/platform/statusor.h" - -#include -#include -#include -#include - -#include "absl/base/config.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" - -namespace tsl { -namespace { - -class Base1 { - public: - virtual ~Base1() {} - int pad_; -}; - -class Base2 { - public: - virtual ~Base2() {} - int yetotherpad_; -}; - -class Derived : public Base1, public Base2 { - public: - ~Derived() override {} - int evenmorepad_; -}; - -class CopyNoAssign { - public: - explicit CopyNoAssign(int value) : foo_(value) {} - CopyNoAssign(const CopyNoAssign& other) : foo_(other.foo_) {} - int foo_; - - private: - const CopyNoAssign& operator=(const CopyNoAssign&); -}; - -class NoDefaultConstructor { - public: - explicit NoDefaultConstructor(int foo); -}; - -static_assert(!std::is_default_constructible(), - "Should not be default-constructible."); - -absl::StatusOr> ReturnUniquePtr() { - // Uses implicit constructor from T&& - return std::unique_ptr(new int(0)); -} - -TEST(StatusOr, NullPointerStatusOr) { - // As a very special case, null-plain-pointer StatusOr used to be an - // error. Test that it no longer is. - absl::StatusOr null_status(nullptr); - EXPECT_TRUE(null_status.ok()); - EXPECT_EQ(null_status.value(), nullptr); -} - -TEST(StatusOr, TestNoDefaultConstructorInitialization) { - // Explicitly initialize it with an error code. - absl::StatusOr statusor(errors::Cancelled("")); - EXPECT_FALSE(statusor.ok()); - EXPECT_EQ(statusor.status().code(), absl::StatusCode::kCancelled); - - // Default construction of StatusOr initializes it with an UNKNOWN error code. - absl::StatusOr statusor2; - EXPECT_FALSE(statusor2.ok()); - EXPECT_EQ(statusor2.status().code(), absl::StatusCode::kUnknown); -} - -TEST(StatusOr, TestMoveOnlyInitialization) { - absl::StatusOr> thing(ReturnUniquePtr()); - ASSERT_TRUE(thing.ok()); - EXPECT_EQ(0, *thing.value()); - int* previous = thing.value().get(); - - thing = ReturnUniquePtr(); - EXPECT_TRUE(thing.ok()); - EXPECT_EQ(0, *thing.value()); - EXPECT_NE(previous, thing.value().get()); -} - -TEST(StatusOr, TestMoveOnlyStatusCtr) { - absl::StatusOr> thing(errors::Cancelled("")); - ASSERT_FALSE(thing.ok()); -} - -TEST(StatusOr, TestMoveOnlyValueExtraction) { - absl::StatusOr> thing(ReturnUniquePtr()); - ASSERT_TRUE(thing.ok()); - std::unique_ptr ptr = std::move(thing).value(); - EXPECT_EQ(0, *ptr); - - thing = std::move(ptr); - ptr = std::move(thing.value()); - EXPECT_EQ(0, *ptr); -} - -TEST(StatusOr, TestMoveOnlyConversion) { - absl::StatusOr> const_thing(ReturnUniquePtr()); - EXPECT_TRUE(const_thing.ok()); - EXPECT_EQ(0, *const_thing.value()); - - // Test rvalue converting assignment - const int* const_previous = const_thing.value().get(); - const_thing = ReturnUniquePtr(); - EXPECT_TRUE(const_thing.ok()); - EXPECT_EQ(0, *const_thing.value()); - EXPECT_NE(const_previous, const_thing.value().get()); -} - -TEST(StatusOr, TestMoveOnlyVector) { - // Sanity check that StatusOr works in vector. - std::vector>> vec; - vec.push_back(ReturnUniquePtr()); - vec.resize(2); - auto another_vec = std::move(vec); - EXPECT_EQ(0, *another_vec[0].value()); - EXPECT_EQ(absl::StatusCode::kUnknown, another_vec[1].status().code()); -} - -TEST(StatusOr, TestMoveWithValuesAndErrors) { - absl::StatusOr status_or(std::string(1000, '0')); - absl::StatusOr value1(std::string(1000, '1')); - absl::StatusOr value2(std::string(1000, '2')); - absl::StatusOr error1( - absl::Status(absl::StatusCode::kUnknown, "error1")); - absl::StatusOr error2( - absl::Status(absl::StatusCode::kUnknown, "error2")); - - ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(std::string(1000, '0'), status_or.value()); - - // Overwrite the value in status_or with another value. - status_or = std::move(value1); - ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(std::string(1000, '1'), status_or.value()); - - // Overwrite the value in status_or with an error. - status_or = std::move(error1); - ASSERT_FALSE(status_or.ok()); - EXPECT_EQ("error1", status_or.status().message()); - - // Overwrite the error in status_or with another error. - status_or = std::move(error2); - ASSERT_FALSE(status_or.ok()); - EXPECT_EQ("error2", status_or.status().message()); - - // Overwrite the error with a value. - status_or = std::move(value2); - ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(std::string(1000, '2'), status_or.value()); -} - -TEST(StatusOr, TestCopyWithValuesAndErrors) { - absl::StatusOr status_or(std::string(1000, '0')); - absl::StatusOr value1(std::string(1000, '1')); - absl::StatusOr value2(std::string(1000, '2')); - absl::StatusOr error1( - absl::Status(absl::StatusCode::kUnknown, "error1")); - absl::StatusOr error2( - absl::Status(absl::StatusCode::kUnknown, "error2")); - - ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(std::string(1000, '0'), status_or.value()); - - // Overwrite the value in status_or with another value. - status_or = value1; - ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(std::string(1000, '1'), status_or.value()); - - // Overwrite the value in status_or with an error. - status_or = error1; - ASSERT_FALSE(status_or.ok()); - EXPECT_EQ("error1", status_or.status().message()); - - // Overwrite the error in status_or with another error. - status_or = error2; - ASSERT_FALSE(status_or.ok()); - EXPECT_EQ("error2", status_or.status().message()); - - // Overwrite the error with a value. - status_or = value2; - ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(std::string(1000, '2'), status_or.value()); - - // Verify original values unchanged. - EXPECT_EQ(std::string(1000, '1'), value1.value()); - EXPECT_EQ("error1", error1.status().message()); - EXPECT_EQ("error2", error2.status().message()); - EXPECT_EQ(std::string(1000, '2'), value2.value()); -} - -TEST(StatusOr, TestDefaultCtor) { - absl::StatusOr thing; - EXPECT_FALSE(thing.ok()); - EXPECT_EQ(thing.status().code(), absl::StatusCode::kUnknown); -} - -TEST(StatusOrDeathTest, TestDefaultCtorValue) { - absl::StatusOr thing; -#ifdef ABSL_HAVE_EXCEPTIONS - try { - thing.value(); - ADD_FAILURE() - << "value() returned successfully while the access is illegal"; - } catch (absl::BadStatusOrAccess& ex) { - } -#else - EXPECT_DEATH(thing.value(), ""); -#endif - - const absl::StatusOr thing2; -#ifdef ABSL_HAVE_EXCEPTIONS - try { - thing.value(); - ADD_FAILURE() - << "value() returned successfully while the access is illegal"; - } catch (absl::BadStatusOrAccess& ex) { - } -#else - EXPECT_DEATH(thing.value(), ""); -#endif -} - -TEST(StatusOr, TestStatusCtor) { - absl::StatusOr thing(absl::Status(absl::StatusCode::kCancelled, "")); - EXPECT_FALSE(thing.ok()); - EXPECT_EQ(thing.status().code(), absl::StatusCode::kCancelled); -} - -TEST(StatusOr, TestValueCtor) { - const int kI = 4; - const absl::StatusOr thing(kI); - EXPECT_TRUE(thing.ok()); - EXPECT_EQ(kI, thing.value()); -} - -TEST(StatusOr, TestCopyCtorStatusOk) { - const int kI = 4; - const absl::StatusOr original(kI); - const absl::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); - EXPECT_EQ(original.value(), copy.value()); -} - -TEST(StatusOr, TestCopyCtorStatusNotOk) { - absl::StatusOr original(absl::Status(absl::StatusCode::kCancelled, "")); - absl::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); -} - -TEST(StatusOr, TestCopyCtorNonAssignable) { - const int kI = 4; - CopyNoAssign value(kI); - absl::StatusOr original(value); - absl::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); - EXPECT_EQ(original.value().foo_, copy.value().foo_); -} - -TEST(StatusOr, TestCopyCtorStatusOKConverting) { - const int kI = 4; - absl::StatusOr original(kI); - absl::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); - EXPECT_DOUBLE_EQ(original.value(), copy.value()); -} - -TEST(StatusOr, TestCopyCtorStatusNotOkConverting) { - absl::StatusOr original(absl::Status(absl::StatusCode::kCancelled, "")); - absl::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); -} - -TEST(StatusOr, TestAssignmentStatusOk) { - const int kI = 4; - absl::StatusOr source(kI); - absl::StatusOr target; - target = source; - EXPECT_EQ(target.status(), source.status()); - EXPECT_EQ(source.value(), target.value()); -} - -TEST(StatusOr, TestAssignmentStatusNotOk) { - absl::StatusOr source(absl::Status(absl::StatusCode::kCancelled, "")); - absl::StatusOr target; - target = source; - EXPECT_EQ(target.status(), source.status()); -} - -TEST(StatusOr, TestStatus) { - absl::StatusOr good(4); - EXPECT_TRUE(good.ok()); - absl::StatusOr bad(absl::Status(absl::StatusCode::kCancelled, "")); - EXPECT_FALSE(bad.ok()); - EXPECT_EQ(bad.status(), absl::Status(absl::StatusCode::kCancelled, "")); -} - -TEST(StatusOr, TestValue) { - const int kI = 4; - absl::StatusOr thing(kI); - EXPECT_EQ(kI, thing.value()); -} - -TEST(StatusOr, TestValueConst) { - const int kI = 4; - const absl::StatusOr thing(kI); - EXPECT_EQ(kI, thing.value()); -} - -TEST(StatusOrDeathTest, TestValueNotOk) { - absl::StatusOr thing( - absl::Status(absl::StatusCode::kCancelled, "cancelled")); -#ifdef ABSL_HAVE_EXCEPTIONS - try { - thing.value(); - ADD_FAILURE() - << "value() returned successfully while the access is illegal"; - } catch (absl::BadStatusOrAccess& ex) { - } -#else - EXPECT_DEATH(thing.value(), "cancelled"); -#endif -} - -TEST(StatusOrDeathTest, TestValueNotOkConst) { - const absl::StatusOr thing(absl::Status(absl::StatusCode::kUnknown, "")); -#ifdef ABSL_HAVE_EXCEPTIONS - try { - thing.value(); - ADD_FAILURE() - << "value() returned successfully while the access is illegal"; - } catch (absl::BadStatusOrAccess& ex) { - } -#else - EXPECT_DEATH(thing.value(), ""); -#endif -} - -TEST(StatusOr, TestPointerDefaultCtor) { - absl::StatusOr thing; - EXPECT_FALSE(thing.ok()); - EXPECT_EQ(thing.status().code(), absl::StatusCode::kUnknown); -} - -TEST(StatusOrDeathTest, TestPointerDefaultCtorValue) { - absl::StatusOr thing; -#ifdef ABSL_HAVE_EXCEPTIONS - try { - thing.value(); - ADD_FAILURE() - << "value() returned successfully while the access is illegal"; - } catch (absl::BadStatusOrAccess& ex) { - } -#else - EXPECT_DEATH(thing.value(), ""); -#endif -} - -TEST(StatusOr, TestPointerStatusCtor) { - absl::StatusOr thing(absl::Status(absl::StatusCode::kCancelled, "")); - EXPECT_FALSE(thing.ok()); - EXPECT_EQ(thing.status(), absl::Status(absl::StatusCode::kCancelled, "")); -} - -TEST(StatusOr, TestPointerValueCtor) { - const int kI = 4; - absl::StatusOr thing(&kI); - EXPECT_TRUE(thing.ok()); - EXPECT_EQ(&kI, thing.value()); -} - -TEST(StatusOr, TestPointerCopyCtorStatusOk) { - const int kI = 0; - absl::StatusOr original(&kI); - absl::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); - EXPECT_EQ(original.value(), copy.value()); -} - -TEST(StatusOr, TestPointerCopyCtorStatusNotOk) { - absl::StatusOr original(absl::Status(absl::StatusCode::kCancelled, "")); - absl::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); -} - -TEST(StatusOr, TestPointerCopyCtorStatusOKConverting) { - Derived derived; - absl::StatusOr original(&derived); - absl::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); - EXPECT_EQ(static_cast(original.value()), copy.value()); -} - -TEST(StatusOr, TestPointerCopyCtorStatusNotOkConverting) { - absl::StatusOr original( - absl::Status(absl::StatusCode::kCancelled, "")); - absl::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); -} - -TEST(StatusOr, TestPointerAssignmentStatusOk) { - const int kI = 0; - absl::StatusOr source(&kI); - absl::StatusOr target; - target = source; - EXPECT_EQ(target.status(), source.status()); - EXPECT_EQ(source.value(), target.value()); -} - -TEST(StatusOr, TestPointerAssignmentStatusNotOk) { - absl::StatusOr source(absl::Status(absl::StatusCode::kCancelled, "")); - absl::StatusOr target; - target = source; - EXPECT_EQ(target.status(), source.status()); -} - -TEST(StatusOr, TestPointerStatus) { - const int kI = 0; - absl::StatusOr good(&kI); - EXPECT_TRUE(good.ok()); - absl::StatusOr bad( - absl::Status(absl::StatusCode::kCancelled, "")); - EXPECT_EQ(bad.status(), absl::Status(absl::StatusCode::kCancelled, "")); -} - -TEST(StatusOr, TestPointerValue) { - const int kI = 0; - absl::StatusOr thing(&kI); - EXPECT_EQ(&kI, thing.value()); -} - -TEST(StatusOr, TestPointerValueConst) { - const int kI = 0; - const absl::StatusOr thing(&kI); - EXPECT_EQ(&kI, thing.value()); -} - -TEST(StatusOr, TestArrowOperator) { - absl::StatusOr> uptr = ReturnUniquePtr(); - EXPECT_EQ(*uptr->get(), 0); -} - -TEST(StatusOr, TestStarOperator) { - absl::StatusOr> uptr = ReturnUniquePtr(); - EXPECT_EQ(**uptr, 0); -} - -TEST(StatusOr, TestStarOperatorDeath) { - absl::StatusOr error( - absl::Status(absl::StatusCode::kCancelled, "cancelled")); - EXPECT_DEATH(*error, "cancelled"); -} - -// NOTE(tucker): StatusOr does not support this kind -// of resize op. -// TEST(StatusOr, StatusOrVectorOfUniquePointerCanResize) { -// using EvilType = std::vector>; -// static_assert(std::is_copy_constructible::value, ""); -// std::vector> v(5); -// v.reserve(v.capacity() + 10); -// } - -static absl::StatusOr MakeStatus() { return 100; } -// A factory to help us benchmark the various factory styles. All of -// the factory methods are marked as non-inlineable so as to more -// accurately simulate calling a factory for which you do not have -// visibility of implementation. Similarly, the value_ variable is -// marked volatile to prevent the compiler from getting too clever -// about detecting that the same value is used in all loop iterations. -template -class BenchmarkFactory { - public: - // Construct a new factory. Allocate an object which will always - // be the result of the factory methods. - BenchmarkFactory() : value_(new T) {} - - // Destroy this factory, including the result value. - ~BenchmarkFactory() { delete value_; } - - // A trivial factory that just returns the value. There is no status - // object that could be returned to encapsulate an error - T* TrivialFactory() TF_ATTRIBUTE_NOINLINE { return value_; } - - // A more sophisticated factory, which returns a status to indicate - // the result of the operation. The factory result is populated into - // the user provided pointer result. - absl::Status ArgumentFactory(T** result) TF_ATTRIBUTE_NOINLINE { - *result = value_; - return absl::OkStatus(); - } - - absl::Status ArgumentFactoryFail(T** result) TF_ATTRIBUTE_NOINLINE { - *result = nullptr; - return absl::Status(absl::StatusCode::kCancelled, ""); - } - - absl::Status ArgumentFactoryFailShortMsg(T** result) TF_ATTRIBUTE_NOINLINE { - *result = nullptr; - return absl::Status(absl::StatusCode::kInternal, ""); - } - - absl::Status ArgumentFactoryFailLongMsg(T** result) TF_ATTRIBUTE_NOINLINE { - *result = nullptr; - return absl::Status(absl::StatusCode::kInternal, - "a big string of message junk that will never be read"); - } - - // A factory that returns a StatusOr. If the factory operation - // is OK, then the StatusOr will hold a T*. Otherwise, it will - // hold a status explaining the error. - StatusOr StatusOrFactory() TF_ATTRIBUTE_NOINLINE { - return static_cast(value_); - } - - StatusOr StatusOrFactoryFail() TF_ATTRIBUTE_NOINLINE { - return absl::Status(absl::StatusCode::kCancelled, ""); - } - - StatusOr StatusOrFactoryFailShortMsg() TF_ATTRIBUTE_NOINLINE { - return absl::Status(absl::StatusCode::kInternal, ""); - } - - StatusOr StatusOrFactoryFailLongMsg() TF_ATTRIBUTE_NOINLINE { - return absl::Status(absl::StatusCode::kInternal, - "a big string of message junk that will never be read"); - } - - private: - T* volatile value_; - BenchmarkFactory(const BenchmarkFactory&) = delete; - void operator=(const BenchmarkFactory&) = delete; -}; - -// A simple type we use with the factory. -class BenchmarkType { - public: - BenchmarkType() {} - virtual ~BenchmarkType() {} - virtual void DoWork() TF_ATTRIBUTE_NOINLINE {} - - private: - BenchmarkType(const BenchmarkType&) = delete; - void operator=(const BenchmarkType&) = delete; -}; - -// Calibrate the amount of time spent just calling DoWork, since each of our -// tests will do this, we can subtract this out of benchmark results. -void BM_CalibrateWorkLoop(::testing::benchmark::State& state) { - BenchmarkFactory factory; - BenchmarkType* result = factory.TrivialFactory(); - for (auto s : state) { - if (result != nullptr) { - result->DoWork(); - } - } -} -BENCHMARK(BM_CalibrateWorkLoop); - -// Measure the time taken to call into the factory, return the value, -// determine that it is OK, and invoke a trivial function. -void BM_TrivialFactory(::testing::benchmark::State& state) { - BenchmarkFactory factory; - for (auto s : state) { - BenchmarkType* result = factory.TrivialFactory(); - if (result != nullptr) { - result->DoWork(); - } - } -} -BENCHMARK(BM_TrivialFactory); - -// Measure the time taken to call into the factory, providing an -// out-param for the result, evaluating the status result and the -// result pointer, and invoking the trivial function. -void BM_ArgumentFactory(::testing::benchmark::State& state) { - BenchmarkFactory factory; - for (auto s : state) { - BenchmarkType* result = nullptr; - absl::Status status = factory.ArgumentFactory(&result); - if (status.ok() && result != nullptr) { - result->DoWork(); - } - } -} -BENCHMARK(BM_ArgumentFactory); - -// Measure the time to use the StatusOr factory, evaluate the result, -// and invoke the trivial function. -void BM_StatusOrFactory(::testing::benchmark::State& state) { - BenchmarkFactory factory; - for (auto s : state) { - absl::StatusOr result = factory.StatusOrFactory(); - if (result.ok()) { - result.value()->DoWork(); - } - } -} -BENCHMARK(BM_StatusOrFactory); - -// Measure the time taken to call into the factory, providing an -// out-param for the result, evaluating the status result and the -// result pointer, and invoking the trivial function. -void BM_ArgumentFactoryFail(::testing::benchmark::State& state) { - BenchmarkFactory factory; - for (auto s : state) { - BenchmarkType* result = nullptr; - absl::Status status = factory.ArgumentFactoryFail(&result); - if (status.ok() && result != nullptr) { - result->DoWork(); - } - } -} -BENCHMARK(BM_ArgumentFactoryFail); - -// Measure the time to use the StatusOr factory, evaluate the result, -// and invoke the trivial function. -void BM_StatusOrFactoryFail(::testing::benchmark::State& state) { - BenchmarkFactory factory; - for (auto s : state) { - absl::StatusOr result = factory.StatusOrFactoryFail(); - if (result.ok()) { - result.value()->DoWork(); - } - } -} -BENCHMARK(BM_StatusOrFactoryFail); - -// Measure the time taken to call into the factory, providing an -// out-param for the result, evaluating the status result and the -// result pointer, and invoking the trivial function. -void BM_ArgumentFactoryFailShortMsg(::testing::benchmark::State& state) { - BenchmarkFactory factory; - for (auto s : state) { - BenchmarkType* result = nullptr; - absl::Status status = factory.ArgumentFactoryFailShortMsg(&result); - if (status.ok() && result != nullptr) { - result->DoWork(); - } - } -} -BENCHMARK(BM_ArgumentFactoryFailShortMsg); - -// Measure the time to use the StatusOr factory, evaluate the result, -// and invoke the trivial function. -void BM_StatusOrFactoryFailShortMsg(::testing::benchmark::State& state) { - BenchmarkFactory factory; - for (auto s : state) { - absl::StatusOr result = - factory.StatusOrFactoryFailShortMsg(); - if (result.ok()) { - result.value()->DoWork(); - } - } -} -BENCHMARK(BM_StatusOrFactoryFailShortMsg); - -// Measure the time taken to call into the factory, providing an -// out-param for the result, evaluating the status result and the -// result pointer, and invoking the trivial function. -void BM_ArgumentFactoryFailLongMsg(::testing::benchmark::State& state) { - BenchmarkFactory factory; - for (auto s : state) { - BenchmarkType* result = nullptr; - absl::Status status = factory.ArgumentFactoryFailLongMsg(&result); - if (status.ok() && result != nullptr) { - result->DoWork(); - } - } -} -BENCHMARK(BM_ArgumentFactoryFailLongMsg); - -// Measure the time to use the StatusOr factory, evaluate the result, -// and invoke the trivial function. -void BM_StatusOrFactoryFailLongMsg(::testing::benchmark::State& state) { - BenchmarkFactory factory; - for (auto s : state) { - absl::StatusOr result = - factory.StatusOrFactoryFailLongMsg(); - if (result.ok()) { - result.value()->DoWork(); - } - } -} -BENCHMARK(BM_StatusOrFactoryFailLongMsg); - -#if defined(PLATFORM_GOOGLE) - -absl::StatusOr GetError() { - return absl::InvalidArgumentError("An invalid argument error"); -} - -absl::StatusOr PropagateError() { - TF_ASSIGN_OR_RETURN(int a, GetError()); - return a; -} - -absl::StatusOr PropagateError2() { - TF_ASSIGN_OR_RETURN(int a, PropagateError()); - return a; -} - -TEST(Status, StackTracePropagation) { - absl::StatusOr s = PropagateError2(); - auto sources = s.status().GetSourceLocations(); - ASSERT_EQ(sources.size(), 3); - - for (int i = 0; i < 3; ++i) { - ASSERT_EQ(sources[i].file_name(), - "third_party/tensorflow/tsl/platform/statusor_test.cc"); - } -} - -#endif - -} // namespace -} // namespace tsl diff --git a/tsl/platform/test.cc b/tsl/platform/test.cc deleted file mode 100644 index b2b2a8936..000000000 --- a/tsl/platform/test.cc +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tsl/platform/test.h" - -#include -#include -#include - -#include "tsl/platform/logging.h" -#include "tsl/platform/net.h" -#include "tsl/platform/path.h" - -namespace tsl { -namespace testing { -namespace { - -std::string GetEnvVarOrDie(const char* env_var) { - const char* value = std::getenv(env_var); - if (!value) { - LOG(FATAL) << "Failed to find environment variable:" << env_var; - } - return value; -} - -} // namespace - -std::string TmpDir() { - const char* tmp_dir = std::getenv("TEST_TMPDIR"); - if (!tmp_dir) { - tmp_dir = std::getenv("TMPDIR"); - } - if (tmp_dir) { - return tmp_dir; - } - LOG(FATAL) // Crash OK - << "Failed to find environment variables: TEST_TMPDIR, TMPDIR"; - - return tmp_dir; -} - -int PickUnusedPortOrDie() { return internal::PickUnusedPortOrDie(); } - -int RandomSeed() { - const char* random_seed_str = std::getenv("TEST_RANDOM_SEED"); - int seed; - if (random_seed_str && std::sscanf(random_seed_str, "%d", &seed) == 1) { - return seed; - } - return 301; -} - -std::string TensorFlowSrcRoot() { - std::string workspace = GetEnvVarOrDie("TEST_WORKSPACE"); - std::string srcdir = GetEnvVarOrDie("TEST_SRCDIR"); - - return kIsOpenSource - ? io::JoinPath(srcdir, workspace, "tensorflow") - : io::JoinPath(srcdir, workspace, "third_party/tensorflow"); -} - -std::string XlaSrcRoot() { - std::string workspace = GetEnvVarOrDie("TEST_WORKSPACE"); - std::string srcdir = GetEnvVarOrDie("TEST_SRCDIR"); - - return kIsOpenSource ? io::JoinPath(srcdir, workspace, "xla") - : io::JoinPath(srcdir, workspace, - "third_party/tensorflow/compiler/xla"); -} - -std::string TslSrcRoot() { - std::string workspace = GetEnvVarOrDie("TEST_WORKSPACE"); - std::string srcdir = GetEnvVarOrDie("TEST_SRCDIR"); - const char* tsl_path = "tsl"; - - return kIsOpenSource - ? io::JoinPath(srcdir, workspace, tsl_path) - : io::JoinPath(srcdir, workspace, "third_party", tsl_path); -} - -} // namespace testing -} // namespace tsl diff --git a/tsl/platform/test.h b/tsl/platform/test.h index 77591d8c0..31ca87536 100644 --- a/tsl/platform/test.h +++ b/tsl/platform/test.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,71 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_TEST_H_ #define TENSORFLOW_TSL_PLATFORM_TEST_H_ -#include -#include -#include - -#include // IWYU pragma: export -#include "tsl/platform/macros.h" -#include "tsl/platform/platform.h" -#include "tsl/platform/types.h" - -// Includes gmock.h and enables the use of gmock matchers in tensorflow tests. -// -// Test including this header can use the macros EXPECT_THAT(...) and -// ASSERT_THAT(...) in combination with gmock matchers. -// Example: -// std::vector vec = Foo(); -// EXPECT_THAT(vec, ::testing::ElementsAre(1,2,3)); -// EXPECT_THAT(vec, ::testing::UnorderedElementsAre(2,3,1)); -// -// For more details on gmock matchers see: -// https://github.com/google/googletest/blob/master/googlemock/docs/CheatSheet.md#matchers -// -// The advantages of using gmock matchers instead of self defined matchers are -// better error messages, more maintainable tests and more test coverage. -#if !defined(PLATFORM_GOOGLE) && !defined(PLATFORM_GOOGLE_ANDROID) && \ - !defined(PLATFORM_CHROMIUMOS) -#include -#include // IWYU pragma: export -#include // IWYU pragma: export -#endif -#include // IWYU pragma: export - -namespace tsl { -namespace testing { - -// Return a temporary directory suitable for temporary testing files. -// -// Where possible, consider using Env::LocalTempFilename over this function. -std::string TmpDir(); - -// Returns the path to TensorFlow in the directory containing data -// dependencies. -// -// A better alternative would be making use if -// tensorflow/tsl/platform/resource_loader.h:GetDataDependencyFilepath. That -// function should do the right thing both within and outside of tests allowing -// avoiding test specific APIs. -std::string TensorFlowSrcRoot(); - -// Returns the path to XLA in the directory containing data -// dependencies. -std::string XlaSrcRoot(); - -// Returns the path to TSL in the directory containing data -// dependencies. -std::string TslSrcRoot(); - -// Return a random number generator seed to use in randomized tests. -// Returns the same value for the lifetime of the process. -int RandomSeed(); - -// Returns an unused port number, for use in multi-process testing. -// NOTE: This function is not thread-safe. -int PickUnusedPortOrDie(); - -} // namespace testing -} // namespace tsl +#include "xla/tsl/platform/test.h" #endif // TENSORFLOW_TSL_PLATFORM_TEST_H_ diff --git a/tsl/platform/test_benchmark.h b/tsl/platform/test_benchmark.h index d1ce3cdac..6772a5f12 100644 --- a/tsl/platform/test_benchmark.h +++ b/tsl/platform/test_benchmark.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,36 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Simple benchmarking facility. #ifndef TENSORFLOW_TSL_PLATFORM_TEST_BENCHMARK_H_ #define TENSORFLOW_TSL_PLATFORM_TEST_BENCHMARK_H_ -#include "benchmark/benchmark.h" // from @com_google_benchmark // IWYU pragma: export -#include "tsl/platform/platform.h" - -// FIXME(vyng): Remove this. -// Background: During the benchmark-migration projects, all benchmarks were made -// to use "testing::benchmark::" prefix because that is what the internal -// Google benchmark library use. -namespace testing { -namespace benchmark { -using ::benchmark::State; // NOLINT -} // namespace benchmark -} // namespace testing - -namespace tsl { -namespace testing { - -inline void RunBenchmarks() { benchmark::RunSpecifiedBenchmarks(); } -inline void InitializeBenchmarks(int* argc, char** argv) { - benchmark::Initialize(argc, argv); -} - -template -void DoNotOptimize(const T& var) { - ::benchmark::DoNotOptimize(var); -} -} // namespace testing -} // namespace tsl +#include "xla/tsl/platform/test_benchmark.h" #endif // TENSORFLOW_TSL_PLATFORM_TEST_BENCHMARK_H_ diff --git a/tsl/platform/test_main.cc b/tsl/platform/test_main.cc deleted file mode 100644 index fb9265618..000000000 --- a/tsl/platform/test_main.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// A program with a main that is suitable for unittests, including those -// that also define microbenchmarks. Based on whether the user specified -// the --benchmark_filter flag which specifies which benchmarks to run, -// we will either run benchmarks or run the gtest tests in the program. - -#include - -#include "absl/strings/match.h" -#include "tsl/platform/platform.h" -#include "tsl/platform/stacktrace_handler.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" - -GTEST_API_ int main(int argc, char** argv) { - tsl::testing::InstallStacktraceHandler(); - - for (int i = 1; i < argc; i++) { - if (absl::StartsWith(argv[i], "--benchmark_filter=")) { - ::benchmark::Initialize(&argc, argv); - - // XXX: Must be called after benchmark's init because - // InitGoogleTest eventually calls absl::ParseCommandLine() which would - // complain that benchmark_filter flag is not known because that flag is - // defined by the benchmark library via its own command-line flag - // facility, which is not known to absl flags. - // FIXME(vyng): Fix this mess once we make benchmark use absl flags - testing::InitGoogleTest(&argc, argv); - ::benchmark::RunSpecifiedBenchmarks(); - return 0; - } - } - - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tsl/platform/threadpool.cc b/tsl/platform/threadpool.cc deleted file mode 100644 index 8b2c85033..000000000 --- a/tsl/platform/threadpool.cc +++ /dev/null @@ -1,300 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tsl/platform/threadpool.h" - -#define EIGEN_USE_THREADS - -#include "absl/types/optional.h" -#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive -#include "tsl/platform/blocking_counter.h" -#include "tsl/platform/context.h" -#include "tsl/platform/denormal.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/mutex.h" -#include "tsl/platform/numa.h" -#include "tsl/platform/setround.h" -#include "tsl/platform/tracing.h" - -#ifdef DNNL_AARCH64_USE_ACL -#include "tsl/platform/cpu_info.h" -#endif // DNNL_AARCH64_USE_ACL - -#ifdef TENSORFLOW_THREADSCALING_EXPERIMENTAL -ABSL_FLAG(float, tensorflow_num_threads_scale_factor, 1.0, - "Allows to scale all Tensorflow ThreadPools. Total number of threads " - "in a given ThreadPool equals to num_threads * " - "tensorflow_num_threads_scale_factor. Default scale factor of 1 is a " - "no-op."); -#endif // TENSORFLOW_THREADSCALING_EXPERIMENTAL - -namespace tsl { - -namespace thread { - -struct EigenEnvironment { - typedef Thread EnvThread; - struct TaskImpl { - std::function f; - Context context; - uint64 trace_id; - }; - struct Task { - std::unique_ptr f; - }; - - Env* const env_; - const ThreadOptions thread_options_; - const string name_; - - EigenEnvironment(Env* env, const ThreadOptions& thread_options, - const string& name) - : env_(env), thread_options_(thread_options), name_(name) {} - - EnvThread* CreateThread(std::function f) { - return env_->StartThread(thread_options_, name_, [=]() { - // Set the processor flag to flush denormals to zero. - port::ScopedFlushDenormal flush; - // Set the processor rounding mode to ROUND TO NEAREST. - tsl::port::ScopedSetRound round(FE_TONEAREST); - if (thread_options_.numa_node != port::kNUMANoAffinity) { - port::NUMASetThreadNodeAffinity(thread_options_.numa_node); - } - f(); - }); - } - - Task CreateTask(std::function f) { - uint64 id = 0; - if (tracing::EventCollector::IsEnabled()) { - id = tracing::GetUniqueArg(); - tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id); - } - return Task{ - std::unique_ptr(new TaskImpl{ - std::move(f), - Context(ContextKind::kThread), - id, - }), - }; - } - - void ExecuteTask(const Task& t) { - WithContext wc(t.f->context); - tracing::ScopedRegion region(tracing::EventCategory::kRunClosure, - t.f->trace_id); - t.f->f(); - } -}; - -ThreadPool::ThreadPool(Env* env, const string& name, int num_threads) - : ThreadPool(env, ThreadOptions(), name, num_threads, true, nullptr) {} - -ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options, - const string& name, int num_threads) - : ThreadPool(env, thread_options, name, num_threads, true, nullptr) {} - -ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options, - const string& name, int num_threads, - bool low_latency_hint, Eigen::Allocator* allocator) { - CHECK_GE(num_threads, 1); - -#ifdef DNNL_AARCH64_USE_ACL - // To avoid cost of swapping in and out threads from running processes - // we do not use all available cores to parallelise TF operations. - if (num_threads == tsl::port::NumTotalCPUs() && num_threads >= 16) { - num_threads = num_threads - 1; - } -#endif // DNNL_AARCH64_USE_ACL - -#ifdef TENSORFLOW_THREADSCALING_EXPERIMENTAL - CHECK_GT(absl::GetFlag(FLAGS_tensorflow_num_threads_scale_factor), 0); - num_threads *= absl::GetFlag(FLAGS_tensorflow_num_threads_scale_factor); - if (num_threads < 1) num_threads = 1; -#endif // TENSORFLOW_THREADSCALING_EXPERIMENTAL - - eigen_threadpool_.reset(new Eigen::ThreadPoolTempl( - num_threads, low_latency_hint, - EigenEnvironment(env, thread_options, "tf_" + name))); - underlying_threadpool_ = eigen_threadpool_.get(); - threadpool_device_.reset(new Eigen::ThreadPoolDevice(underlying_threadpool_, - num_threads, allocator)); -} - -ThreadPool::ThreadPool(thread::ThreadPoolInterface* user_threadpool) { - underlying_threadpool_ = user_threadpool; - threadpool_device_.reset(new Eigen::ThreadPoolDevice( - underlying_threadpool_, underlying_threadpool_->NumThreads(), nullptr)); -} - -ThreadPool::~ThreadPool() {} - -void ThreadPool::Schedule(std::function fn) { - CHECK(fn != nullptr); - underlying_threadpool_->Schedule(std::move(fn)); -} - -int ThreadPool::NumShardsUsedByFixedBlockSizeScheduling( - const int64_t total, const int64_t block_size) { - if (block_size <= 0 || total <= 1 || total <= block_size || - NumThreads() == 1) { - return 1; - } - return (total + block_size - 1) / block_size; -} - -int ThreadPool::NumShardsUsedByTransformRangeConcurrently( - const int64_t block_size, const int64_t total) { - return NumShardsUsedByFixedBlockSizeScheduling(total, block_size); -} - -void ThreadPool::ParallelFor(int64_t total, - const SchedulingParams& scheduling_params, - const std::function& fn) { - switch (scheduling_params.strategy()) { - case SchedulingStrategy::kAdaptive: { - if (scheduling_params.cost_per_unit().has_value()) { - ParallelFor(total, *scheduling_params.cost_per_unit(), fn); - } - break; - } - case SchedulingStrategy::kFixedBlockSize: { - if (scheduling_params.block_size().has_value()) { - ParallelForFixedBlockSizeScheduling( - total, *scheduling_params.block_size(), fn); - } - break; - } - } -} - -void ThreadPool::TransformRangeConcurrently( - const int64_t block_size, const int64_t total, - const std::function& fn) { - ParallelFor(total, - SchedulingParams(SchedulingStrategy::kFixedBlockSize, - absl::nullopt /* cost_per_unit */, block_size), - fn); -} - -// This functionality is similar to parallelFor, except that reasoning about -// the number of shards used is significantly easier. -void ThreadPool::ParallelForFixedBlockSizeScheduling( - const int64_t total, const int64_t block_size, - const std::function& fn) { - const int num_shards_used = - NumShardsUsedByFixedBlockSizeScheduling(total, block_size); - if (num_shards_used == 1) { - fn(0, total); - return; - } - - // Adapted from Eigen's parallelFor implementation. - BlockingCounter counter(num_shards_used); - std::function handle_range = - [=, &handle_range, &counter, &fn](int64_t first, int64_t last) { - while (last - first > block_size) { - // Find something near the midpoint which is a multiple of block size. - const int64_t mid = first + ((last - first) / 2 + block_size - 1) / - block_size * block_size; - Schedule([=, &handle_range]() { handle_range(mid, last); }); - last = mid; - } - // Single block or less, execute directly. - fn(first, last); - counter.DecrementCount(); // The shard is done. - }; - if (num_shards_used <= NumThreads()) { - // Avoid a thread hop by running the root of the tree and one block on the - // main thread. - handle_range(0, total); - } else { - // Execute the root in the thread pool to avoid running work on more than - // numThreads() threads. - Schedule([=, &handle_range]() { handle_range(0, total); }); - } - counter.Wait(); -} - -void ThreadPool::ParallelFor(int64_t total, int64_t cost_per_unit, - const std::function& fn) { - CHECK_GE(total, 0); - CHECK_EQ(total, (int64_t)(Eigen::Index)total); - threadpool_device_->parallelFor( - total, Eigen::TensorOpCost(0, 0, cost_per_unit), - [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); }); -} - -void ThreadPool::ParallelForWithWorkerId( - int64_t total, int64_t cost_per_unit, - const std::function& fn) { - CHECK_GE(total, 0); - CHECK_EQ(total, (int64_t)(Eigen::Index)total); - - threadpool_device_->parallelFor(total, - Eigen::TensorOpCost(0, 0, cost_per_unit), - [this, &fn](int64_t start, int64_t limit) { - // ParallelFor may use the current thread to - // do some work synchronously. When calling - // CurrentThreadId() from outside of the - // thread pool, we get -1, so we can shift - // every id up by 1. - int id = CurrentThreadId() + 1; - fn(start, limit, id); - }); -} - -void ThreadPool::ParallelForWithWorkerId( - int64_t total, const SchedulingParams& scheduling_params, - const std::function& fn) { - ParallelFor(total, scheduling_params, - [this, &fn](int64_t start, int64_t limit) { - // We may use the current thread to do some work synchronously. - // When calling CurrentThreadId() from outside of the thread - // pool, we get -1, so we can shift every id up by 1. - int id = CurrentThreadId() + 1; - fn(start, limit, id); - }); -} - -int ThreadPool::NumThreads() const { - return underlying_threadpool_->NumThreads(); -} - -int ThreadPool::CurrentThreadId() const { - return underlying_threadpool_->CurrentThreadId(); -} - -void ThreadPool::ScheduleWithHint(std::function fn, int start, - int limit) { - underlying_threadpool_->ScheduleWithHint(std::move(fn), start, limit); -} - -void ThreadPool::SetStealPartitions( - const std::vector>& partitions) { - // ThreadPool::SetStealPartitions is only called in the constructor of - // RunHandlerPool::Impl, which currently instantiates ThreadPool using a - // constructor that does not take user_threadpool. Thus we assume - // eigen_threadpool_ is not null here. - DCHECK(eigen_threadpool_ != nullptr); - eigen_threadpool_->SetStealPartitions(partitions); -} - -Eigen::ThreadPoolInterface* ThreadPool::AsEigenThreadPool() const { - DCHECK(underlying_threadpool_ != nullptr); - return underlying_threadpool_; -} -} // namespace thread -} // namespace tsl diff --git a/tsl/platform/threadpool.h b/tsl/platform/threadpool.h index df650f6ec..3ab00c4d4 100644 --- a/tsl/platform/threadpool.h +++ b/tsl/platform/threadpool.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,230 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_THREADPOOL_H_ #define TENSORFLOW_TSL_PLATFORM_THREADPOOL_H_ -#include -#include - -#include "absl/types/optional.h" -#include "tsl/platform/env.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/threadpool_interface.h" -#include "tsl/platform/types.h" - -namespace Eigen { -class Allocator; -class ThreadPoolInterface; -struct ThreadPoolDevice; - -template -class ThreadPoolTempl; -} // namespace Eigen - -namespace tsl { -namespace thread { - -struct EigenEnvironment; - -class ThreadPool { - public: - // Scheduling strategies for ParallelFor. The strategy governs how the given - // units of work are distributed among the available threads in the - // threadpool. - enum class SchedulingStrategy { - // The Adaptive scheduling strategy adaptively chooses the shard sizes based - // on the cost of each unit of work, and the cost model of the underlying - // threadpool device. - // - // The 'cost_per_unit' is an estimate of the number of CPU cycles (or - // nanoseconds if not CPU-bound) to complete a unit of work. Overestimating - // creates too many shards and CPU time will be dominated by per-shard - // overhead, such as Context creation. Underestimating may not fully make - // use of the specified parallelism, and may also cause inefficiencies due - // to load balancing issues and stragglers. - kAdaptive, - // The Fixed Block Size scheduling strategy shards the given units of work - // into shards of fixed size. In case the total number of units is not - // evenly divisible by 'block_size', at most one of the shards may be of - // smaller size. The exact number of shards may be found by a call to - // NumShardsUsedByFixedBlockSizeScheduling. - // - // Each shard may be executed on a different thread in parallel, depending - // on the number of threads available in the pool. Note that when there - // aren't enough threads in the pool to achieve full parallelism, function - // calls will be automatically queued. - kFixedBlockSize - }; - - // Contains additional parameters for either the Adaptive or the Fixed Block - // Size scheduling strategy. - class SchedulingParams { - public: - explicit SchedulingParams(SchedulingStrategy strategy, - absl::optional cost_per_unit, - absl::optional block_size) - : strategy_(strategy), - cost_per_unit_(cost_per_unit), - block_size_(block_size) {} - - SchedulingStrategy strategy() const { return strategy_; } - absl::optional cost_per_unit() const { return cost_per_unit_; } - absl::optional block_size() const { return block_size_; } - - private: - // The underlying Scheduling Strategy for which this instance contains - // additional parameters. - SchedulingStrategy strategy_; - - // The estimated cost per unit of work in number of CPU cycles (or - // nanoseconds if not CPU-bound). Only applicable for Adaptive scheduling - // strategy. - absl::optional cost_per_unit_; - - // The block size of each shard. Only applicable for Fixed Block Size - // scheduling strategy. - absl::optional block_size_; - }; - - // Constructs a pool that contains "num_threads" threads with specified - // "name". env->StartThread() is used to create individual threads with the - // given ThreadOptions. If "low_latency_hint" is true the thread pool - // implementation may use it as a hint that lower latency is preferred at the - // cost of higher CPU usage, e.g. by letting one or more idle threads spin - // wait. Conversely, if the threadpool is used to schedule high-latency - // operations like I/O the hint should be set to false. - // - // REQUIRES: num_threads > 0 - ThreadPool(Env* env, const ThreadOptions& thread_options, - const std::string& name, int num_threads, bool low_latency_hint, - Eigen::Allocator* allocator = nullptr); - - // Constructs a pool for low-latency ops that contains "num_threads" threads - // with specified "name". env->StartThread() is used to create individual - // threads. - // REQUIRES: num_threads > 0 - ThreadPool(Env* env, const std::string& name, int num_threads); - - // Constructs a pool for low-latency ops that contains "num_threads" threads - // with specified "name". env->StartThread() is used to create individual - // threads with the given ThreadOptions. - // REQUIRES: num_threads > 0 - ThreadPool(Env* env, const ThreadOptions& thread_options, - const std::string& name, int num_threads); - - // Constructs a pool that wraps around the thread::ThreadPoolInterface - // instance provided by the caller. Caller retains ownership of - // `user_threadpool` and must ensure its lifetime is longer than the - // ThreadPool instance. - explicit ThreadPool(thread::ThreadPoolInterface* user_threadpool); - - // Waits until all scheduled work has finished and then destroy the - // set of threads. - ~ThreadPool(); - - // Schedules fn() for execution in the pool of threads. - void Schedule(std::function fn); - - void SetStealPartitions( - const std::vector>& partitions); - - void ScheduleWithHint(std::function fn, int start, int limit); - - // Returns the number of shards used by ParallelForFixedBlockSizeScheduling - // with these parameters. - int NumShardsUsedByFixedBlockSizeScheduling(const int64_t total, - const int64_t block_size); - - // Returns the number of threads spawned by calling TransformRangeConcurrently - // with these parameters. - // Deprecated. Use NumShardsUsedByFixedBlockSizeScheduling. - int NumShardsUsedByTransformRangeConcurrently(const int64_t block_size, - const int64_t total); - - // ParallelFor shards the "total" units of work assuming each unit of work - // having roughly "cost_per_unit" cost, in cycles. Each unit of work is - // indexed 0, 1, ..., total - 1. Each shard contains 1 or more units of work - // and the total cost of each shard is roughly the same. - // - // "cost_per_unit" is an estimate of the number of CPU cycles (or nanoseconds - // if not CPU-bound) to complete a unit of work. Overestimating creates too - // many shards and CPU time will be dominated by per-shard overhead, such as - // Context creation. Underestimating may not fully make use of the specified - // parallelism, and may also cause inefficiencies due to load balancing - // issues and stragglers. - void ParallelFor(int64_t total, int64_t cost_per_unit, - const std::function& fn); - - // Similar to ParallelFor above, but takes the specified scheduling strategy - // into account. - void ParallelFor(int64_t total, const SchedulingParams& scheduling_params, - const std::function& fn); - - // Same as ParallelFor with Fixed Block Size scheduling strategy. - // Deprecated. Prefer ParallelFor with a SchedulingStrategy argument. - void TransformRangeConcurrently( - const int64_t block_size, const int64_t total, - const std::function& fn); - - // Shards the "total" units of work. For more details, see "ParallelFor". - // - // The function is passed a thread_id between 0 and NumThreads() *inclusive*. - // This is because some work can happen on the caller thread while the threads - // in the pool are also being used. - // - // The caller can allocate NumThreads() + 1 separate buffers for each thread. - // Each thread can safely write to the buffer given by its id without - // synchronization. However, the worker fn may be called multiple times - // sequentially with the same id. - // - // At most NumThreads() unique ids will actually be used, and only a few may - // be used for small workloads. If each buffer is expensive, the buffers - // should be stored in an array initially filled with null, and a buffer - // should be allocated by fn the first time that the id is used. - void ParallelForWithWorkerId( - int64_t total, int64_t cost_per_unit, - const std::function& fn); - - // Similar to ParallelForWithWorkerId above, but takes the specified - // scheduling strategy into account. - void ParallelForWithWorkerId( - int64_t total, const SchedulingParams& scheduling_params, - const std::function& fn); - - // Returns the number of threads in the pool. - int NumThreads() const; - - // Returns current thread id between 0 and NumThreads() - 1, if called from a - // thread in the pool. Returns -1 otherwise. - int CurrentThreadId() const; - - // If ThreadPool implementation is compatible with Eigen::ThreadPoolInterface, - // returns a non-null pointer. The caller does not own the object the returned - // pointer points to, and should not attempt to delete. - Eigen::ThreadPoolInterface* AsEigenThreadPool() const; - - private: - // Divides the work represented by the range [0, total) into k shards. - // Calls fn(i*block_size, (i+1)*block_size) from the ith shard (0 <= i < k). - // Each shard may be executed on a different thread in parallel, depending on - // the number of threads available in the pool. - // When (i+1)*block_size > total, fn(i*block_size, total) is called instead. - // Here, k = NumShardsUsedByFixedBlockSizeScheduling(total, block_size). - // Requires 0 < block_size <= total. - void ParallelForFixedBlockSizeScheduling( - const int64_t total, const int64_t block_size, - const std::function& fn); - - // underlying_threadpool_ is the user_threadpool if user_threadpool is - // provided in the constructor. Otherwise it is the eigen_threadpool_. - Eigen::ThreadPoolInterface* underlying_threadpool_; - // eigen_threadpool_ is instantiated and owned by thread::ThreadPool if - // user_threadpool is not in the constructor. - std::unique_ptr> eigen_threadpool_; - std::unique_ptr threadpool_device_; - ThreadPool(const ThreadPool&) = delete; - void operator=(const ThreadPool&) = delete; -}; - -} // namespace thread -} // namespace tsl +#include "xla/tsl/platform/threadpool.h" #endif // TENSORFLOW_TSL_PLATFORM_THREADPOOL_H_ diff --git a/tsl/platform/threadpool_async_executor.h b/tsl/platform/threadpool_async_executor.h index 59f14aab1..deadc9511 100644 --- a/tsl/platform/threadpool_async_executor.h +++ b/tsl/platform/threadpool_async_executor.h @@ -16,35 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_THREADPOOL_ASYNC_EXECUTOR_H_ #define TENSORFLOW_TSL_PLATFORM_THREADPOOL_ASYNC_EXECUTOR_H_ -#include - -#include "xla/tsl/concurrency/async_value.h" -#include "tsl/platform/threadpool.h" - -namespace tsl::thread { - -// An adaptor for a ThreadPool that converts it into the AsyncValue:Executor. -// -// AsncValue::Executor task is a move-only absl::AnyInvocable, and ThreadPool -// expects a copyable std::function. This class adapts the two and makes sure -// that the task is deleted when it's done executing. -class ThreadPoolAsyncExecutor : public AsyncValue::Executor { - public: - explicit ThreadPoolAsyncExecutor(ThreadPool* thread_pool) - : thread_pool_(thread_pool) {} - - void Execute(Task task) final { - auto* task_ptr = new Task(std::move(task)); - thread_pool_->Schedule([task_ptr] { - (*task_ptr)(); - delete task_ptr; - }); - } - - private: - ThreadPool* thread_pool_; -}; - -} // namespace tsl::thread +#include "xla/tsl/platform/threadpool_async_executor.h" #endif // TENSORFLOW_TSL_PLATFORM_THREADPOOL_ASYNC_EXECUTOR_H_ diff --git a/tsl/platform/threadpool_async_executor_test.cc b/tsl/platform/threadpool_async_executor_test.cc deleted file mode 100644 index acc00aa21..000000000 --- a/tsl/platform/threadpool_async_executor_test.cc +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tsl/platform/threadpool_async_executor.h" - -#include "absl/synchronization/notification.h" -#include "tsl/platform/env.h" -#include "tsl/platform/test.h" -#include "tsl/platform/threadpool.h" - -namespace tsl::thread { -namespace { - -TEST(ThreadPoolAsyncExecutorTest, ExecuteTasks) { - ThreadPool thread_pool(Env::Default(), "test", 4); - ThreadPoolAsyncExecutor executor(&thread_pool); - - absl::Notification notification; - executor.Execute([&] { notification.Notify(); }); - notification.WaitForNotification(); -} - -} // namespace -} // namespace tsl::thread diff --git a/tsl/platform/threadpool_interface.h b/tsl/platform/threadpool_interface.h index 0dac04d5e..930d8bcd2 100644 --- a/tsl/platform/threadpool_interface.h +++ b/tsl/platform/threadpool_interface.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,16 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_THREADPOOL_INTERFACE_H_ #define TENSORFLOW_TSL_PLATFORM_THREADPOOL_INTERFACE_H_ -#include "unsupported/Eigen/CXX11/ThreadPool" // from @eigen_archive -#include "tsl/platform/mutex.h" -#include "tsl/platform/types.h" - -namespace tsl { -namespace thread { - -class ThreadPoolInterface : public Eigen::ThreadPoolInterface {}; - -} // namespace thread -} // namespace tsl +#include "xla/tsl/platform/threadpool_interface.h" #endif // TENSORFLOW_TSL_PLATFORM_THREADPOOL_INTERFACE_H_ diff --git a/tsl/platform/threadpool_options.h b/tsl/platform/threadpool_options.h index 21c74fbaa..ea884edfc 100644 --- a/tsl/platform/threadpool_options.h +++ b/tsl/platform/threadpool_options.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,20 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_THREADPOOL_OPTIONS_H_ #define TENSORFLOW_TSL_PLATFORM_THREADPOOL_OPTIONS_H_ -#include "tsl/platform/threadpool_interface.h" - -namespace tsl { -namespace thread { - -struct ThreadPoolOptions { - // If not null, use this threadpool to schedule inter-op operation - thread::ThreadPoolInterface* inter_op_threadpool = nullptr; - - // If not null, use this threadpool to schedule intra-op operation - thread::ThreadPoolInterface* intra_op_threadpool = nullptr; -}; - -} // namespace thread -} // namespace tsl +#include "xla/tsl/platform/threadpool_options.h" #endif // TENSORFLOW_TSL_PLATFORM_THREADPOOL_OPTIONS_H_ diff --git a/tsl/platform/types.h b/tsl/platform/types.h index 1768d57bb..90aa7993f 100644 --- a/tsl/platform/types.h +++ b/tsl/platform/types.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,59 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_TYPES_H_ #define TENSORFLOW_TSL_PLATFORM_TYPES_H_ -#include - -#include "tsl/platform/bfloat16.h" -#include "tsl/platform/ml_dtypes.h" // IWYU pragma: export -#include "tsl/platform/platform.h" -#include "tsl/platform/tstring.h" - -// Include appropriate platform-dependent implementations -#if defined(PLATFORM_GOOGLE) || defined(GOOGLE_INTEGRAL_TYPES) -#include "xla/tsl/platform/google/integral_types.h" // IWYU pragma: export -#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \ - defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \ - defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS) -#include "xla/tsl/platform/default/integral_types.h" // IWYU pragma: export -#else -#error Define the appropriate PLATFORM_ macro for this platform -#endif - -namespace tsl { - -// Alias tsl::string to std::string. -using std::string; - -static const uint4 kuint4max = static_cast(0x0F); -static const uint8 kuint8max = static_cast(0xFF); -static const uint16 kuint16max = static_cast(0xFFFF); -static const uint32 kuint32max = static_cast(0xFFFFFFFF); -static const uint64 kuint64max = static_cast(0xFFFFFFFFFFFFFFFFull); -static const int8_t kint8min = static_cast(~0x7F); -static const int8_t kint8max = static_cast(0x7F); -static const int4 kint4min = static_cast(0x08); -static const int4 kint4max = static_cast(0x07); -static const int16_t kint16min = static_cast(~0x7FFF); -static const int16_t kint16max = static_cast(0x7FFF); -static const int32_t kint32min = static_cast(~0x7FFFFFFF); -static const int32_t kint32max = static_cast(0x7FFFFFFF); -static const int64_t kint64min = static_cast(~0x7FFFFFFFFFFFFFFFll); -static const int64_t kint64max = static_cast(0x7FFFFFFFFFFFFFFFll); - -// A typedef for a uint64 used as a short fingerprint. -using Fprint = uint64; - -} // namespace tsl - -// Alias namespace ::stream_executor as ::tensorflow::se. -namespace stream_executor {} -namespace tensorflow { -namespace se = ::stream_executor; -} // namespace tensorflow - -#if defined(PLATFORM_WINDOWS) -#include -typedef std::ptrdiff_t ssize_t; -#endif +#include "xla/tsl/platform/types.h" #endif // TENSORFLOW_TSL_PLATFORM_TYPES_H_