Skip to content

Commit

Permalink
Add filter functionality to TraceMeRecorder to filter events based on…
Browse files Browse the repository at this point in the history
… filter parameter.

PiperOrigin-RevId: 696613849
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Nov 14, 2024
1 parent b807244 commit 8f72cbc
Showing 1 changed file with 44 additions and 18 deletions.
62 changes: 44 additions & 18 deletions tsl/profiler/lib/traceme.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ limitations under the License.
#ifndef TENSORFLOW_TSL_PROFILER_LIB_TRACEME_H_
#define TENSORFLOW_TSL_PROFILER_LIB_TRACEME_H_

#include <sys/types.h>

#include <cstdint>
#include <limits>
#include <string>
#include <type_traits>
#include <utility>
Expand All @@ -34,6 +37,9 @@ limitations under the License.
namespace tsl {
namespace profiler {

constexpr uint64_t kTraceMeDefaultFilterMask =
std::numeric_limits<uint64_t>::max();

// Predefined levels:
// - Level 1 (kCritical) is the default and used only for user instrumentation.
// - Level 2 (kInfo) is used by profiler for instrumenting high level program
Expand Down Expand Up @@ -88,10 +94,12 @@ class TraceMe {
// - Can be a value in enum TraceMeLevel.
// Users are welcome to use level > 3 in their code, if they wish to filter
// out their host traces based on verbosity.
explicit TraceMe(absl::string_view name, int level = 1) {
explicit TraceMe(absl::string_view name, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
DCHECK_GE(level, 1);
#if !defined(IS_MOBILE_PLATFORM)
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level) &&
TraceMeRecorder::CheckFilter(filter_mask))) {
name_.Emplace(std::string(name));
start_time_ = GetCurrentTimeNanos();
}
Expand All @@ -102,19 +110,22 @@ class TraceMe {
// string should only be incurred when tracing is enabled. Wrap the temporary
// string generation (e.g., StrCat) in a lambda and use the name_generator
// template instead.
explicit TraceMe(std::string&& name, int level = 1) = delete;
explicit TraceMe(std::string&& name, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask) = delete;

// Do not allow passing strings by reference or value since the caller
// may unintentionally maintain ownership of the name.
// Explicitly wrap the name in a string_view if you really wish to maintain
// ownership of a string already generated for other purposes. For temporary
// strings (e.g., result of StrCat) use the name_generator template.
explicit TraceMe(const std::string& name, int level = 1) = delete;
explicit TraceMe(const std::string& name, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask) = delete;

// This overload is necessary to make TraceMe's with string literals work.
// Otherwise, the name_generator template would be used.
explicit TraceMe(const char* raw, int level = 1)
: TraceMe(absl::string_view(raw), level) {}
explicit TraceMe(const char* raw, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask)
: TraceMe(absl::string_view(raw), level, filter_mask) {}

// This overload only generates the name (and possibly metadata) if tracing is
// enabled. Useful for avoiding expensive operations (e.g., string
Expand All @@ -135,10 +146,12 @@ class TraceMe {
// });
template <typename NameGeneratorT,
std::enable_if_t<std::is_invocable_v<NameGeneratorT>, bool> = true>
explicit TraceMe(NameGeneratorT&& name_generator, int level = 1) {
explicit TraceMe(NameGeneratorT&& name_generator, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
DCHECK_GE(level, 1);
#if !defined(IS_MOBILE_PLATFORM)
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level) &&
TraceMeRecorder::CheckFilter(filter_mask))) {
name_.Emplace(std::forward<NameGeneratorT>(name_generator)());
start_time_ = GetCurrentTimeNanos();
}
Expand Down Expand Up @@ -215,9 +228,12 @@ class TraceMe {
// Calls `name_generator` to get the name for activity.
template <typename NameGeneratorT,
std::enable_if_t<std::is_invocable_v<NameGeneratorT>, bool> = true>
static int64_t ActivityStart(NameGeneratorT&& name_generator, int level = 1) {
static int64_t ActivityStart(
NameGeneratorT&& name_generator, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
#if !defined(IS_MOBILE_PLATFORM)
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level) &&
TraceMeRecorder::CheckFilter(filter_mask))) {
int64_t activity_id = TraceMeRecorder::NewActivityId();
TraceMeRecorder::Record({std::forward<NameGeneratorT>(name_generator)(),
GetCurrentTimeNanos(), -activity_id});
Expand All @@ -229,9 +245,12 @@ class TraceMe {

// Record the start time of an activity.
// Returns the activity ID, which is used to stop the activity.
static int64_t ActivityStart(absl::string_view name, int level = 1) {
static int64_t ActivityStart(
absl::string_view name, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
#if !defined(IS_MOBILE_PLATFORM)
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level) &&
TraceMeRecorder::CheckFilter(filter_mask))) {
int64_t activity_id = TraceMeRecorder::NewActivityId();
TraceMeRecorder::Record(
{std::string(name), GetCurrentTimeNanos(), -activity_id});
Expand All @@ -242,13 +261,17 @@ class TraceMe {
}

// Same as ActivityStart above, an overload for "const std::string&"
static int64_t ActivityStart(const std::string& name, int level = 1) {
return ActivityStart(absl::string_view(name), level);
static int64_t ActivityStart(
const std::string& name, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
return ActivityStart(absl::string_view(name), level, filter_mask);
}

// Same as ActivityStart above, an overload for "const char*"
static int64_t ActivityStart(const char* name, int level = 1) {
return ActivityStart(absl::string_view(name), level);
static int64_t ActivityStart(
const char* name, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
return ActivityStart(absl::string_view(name), level, filter_mask);
}

// Record the end time of an activity started by ActivityStart().
Expand All @@ -267,9 +290,12 @@ class TraceMe {
// Records the time of an instant activity.
template <typename NameGeneratorT,
std::enable_if_t<std::is_invocable_v<NameGeneratorT>, bool> = true>
static void InstantActivity(NameGeneratorT&& name_generator, int level = 1) {
static void InstantActivity(
NameGeneratorT&& name_generator, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
#if !defined(IS_MOBILE_PLATFORM)
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) {
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level) &&
TraceMeRecorder::CheckFilter(filter_mask))) {
int64_t now = GetCurrentTimeNanos();
TraceMeRecorder::Record({std::forward<NameGeneratorT>(name_generator)(),
/*start_time=*/now, /*end_time=*/now});
Expand Down

0 comments on commit 8f72cbc

Please sign in to comment.