Skip to content

Commit

Permalink
: basic tbe input dump framework (#3593)
Browse files Browse the repository at this point in the history
Summary:

Adds a plugin capability to dump TBE input; the OSS implementation is a no-op.

Reviewed By: damianr99

Differential Revision: D68446857
  • Loading branch information
Sihui Han authored and facebook-github-bot committed Jan 22, 2025
1 parent b858408 commit a12f4ed
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

# @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers

from fbgemm_gpu.config import FeatureGate, FeatureGateName
from fbgemm_gpu.runtime_monitor import (
AsyncSeriesTimer,
Expand All @@ -49,6 +48,7 @@
generate_vbe_metadata,
is_torchdynamo_compiling,
)
from fbgemm_gpu.tbe_input_dump import TBEInputDump, TBEInputDumpConfig

from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc

Expand Down Expand Up @@ -647,6 +647,7 @@ def __init__( # noqa C901
global_weight_decay: Optional[GlobalWeightDecayDefinition] = None,
uvm_host_mapped: bool = False,
extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None,
tbe_input_dump_config: Optional[TBEInputDumpConfig] = None,
) -> None:
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()

Expand Down Expand Up @@ -820,6 +821,21 @@ def __init__( # noqa C901
self.feature_table_map: List[int] = (
feature_table_map if feature_table_map is not None else list(range(T_))
)

self.tbe_input_dump: Optional[TBEInputDump] = (
tbe_input_dump_config.create_tbe_input_dump(
table_names=(
table_names
if table_names
else [f"table-{i}" for i in range(len(embedding_specs))]
),
table_heights=rows,
tbe_uuid=self.uuid,
feature_table_map=self.feature_table_map,
)
if tbe_input_dump_config is not None
else None
)
T = len(self.feature_table_map)
assert T_ <= T
table_has_feature = [False] * T_
Expand Down Expand Up @@ -1789,6 +1805,11 @@ def forward( # noqa: C901
self._report_io_size_count("fwd_input", indices)
self._report_tbe_mem_usage()

if self.tbe_input_dump is not None:
tbe_input_dump: TBEInputDump = self.tbe_input_dump
if tbe_input_dump.should_dump(self.step):
tbe_input_dump.run(indices, offsets, batch_size_per_feature_per_rank)

if len(self.timesteps_prefetched) == 0:
# In forward, we don't enable multi-pass prefetch as we want the process
# to be as fast as possible and memory usage doesn't matter (will be recycled
Expand Down
72 changes: 72 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe_input_dump.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import abc

from dataclasses import dataclass
from typing import List, Optional

from torch import Tensor


class TBEInputDump(abc.ABC):
    """
    Abstract interface for dumping TBE input data out; a concrete
    implementation may persist the data, e.g. to files.
    """

    @abc.abstractmethod
    def should_dump(self, step: int) -> bool:
        """
        Decide whether a dump should be triggered at the given step.

        Args:
            step: the current step
        Returns:
            True if the dump should be triggered, otherwise False
        """
        ...

    @abc.abstractmethod
    def run(
        self,
        indices: Tensor,
        offsets: Tensor,
        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
    ) -> None:
        """
        Perform the TBE input dump; invoked for every batch that needs
        to be dumped.

        Args:
            indices: a 1D tensor holding the indices to be looked up
                from all embedding tables
            offsets: a 1D tensor holding the offsets into ``indices``
            batch_size_per_feature_per_rank: optional nested list with a
                batch size for every rank and every feature; required to
                support VBE (variable batch size embedding)
        """
        ...


@dataclass(frozen=True)
class TBEInputDumpConfig:
    """
    Configuration for TBEInputDump.

    This base (OSS) config carries no actual dump implementation; it only
    validates that dumping was not requested and yields ``None`` from
    :meth:`create_tbe_input_dump`. Plugins override the factory to return
    a real ``TBEInputDump``.
    """

    # first batch to start dump, -1 means no dump
    monitored_batch_start: int = -1
    # total batch to dump
    monitored_total_batch: int = 0

    def create_tbe_input_dump(
        self,
        table_names: List[str],
        table_heights: List[int],
        tbe_uuid: str,
        feature_table_map: List[int],
    ) -> Optional["TBEInputDump"]:
        """
        Create the TBEInputDump instance for this config, or None.

        Args:
            table_names: name of each embedding table
            table_heights: number of rows of each embedding table
            tbe_uuid: unique identifier of the owning TBE module
            feature_table_map: mapping from feature index to table index

        Returns:
            None — this no-op base implementation never dumps.

        Raises:
            ValueError: if ``monitored_batch_start`` requests dumping
                (i.e. is not -1) while no actual dump implementation exists.
        """
        # Validate with an explicit raise rather than `assert`: asserts are
        # stripped under `python -O`, which would silently ignore a
        # misconfigured dump request.
        if self.monitored_batch_start != -1:
            raise ValueError(
                "Cannot specify monitored_batch_start without an actual "
                "implementation of tbe dump"
            )
        return None

0 comments on commit a12f4ed

Please sign in to comment.