Skip to content

Commit

Permalink
process_group: register via public API (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k authored Nov 22, 2024
1 parent 4631d6f commit ee864cf
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 25 deletions.
32 changes: 21 additions & 11 deletions torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def size(self) -> int:
def getBackendName(self) -> str:
raise NotImplementedError("not implemented")

def register(self, name: str) -> None:
def register(self, name: str) -> BaseProcessGroup:
"""
Registers the process group with the global registry. This enables usage
with things like functional_collectives which are compilable.
Expand All @@ -113,32 +113,42 @@ def register(self, name: str) -> None:
name: name must be a unique name for this process group
"""

self._group_name = f"{self.getBackendName()}:{name}"
_register_process_group(self.group_name, self)
group_name = f"{self.getBackendName()}:{name}"

# This is needed for DeviceMesh to work
# This is needed for DeviceMesh and functional collectives to work.
# Resizable worlds don't fit well into DeviceMesh so we register a world
# size 1 PG.
_world.pg_map[self] = (None, None)
_world.pg_names[self] = self._group_name
_world.pg_to_tag[self] = self._group_name
_world.tags_to_pg.setdefault(self._group_name, []).append(self)
# these PGs can be resized so we lie about the rank mapping
_world.pg_group_ranks[self] = {get_rank(): 0}

def create_pg(
prefix_store: PrefixStore, rank: int, world_size: int, timeout: float
) -> ProcessGroup:
return self

dist.Backend.register_backend(group_name, create_pg)

return dist.new_group(
ranks=[dist.get_rank()],
backend=group_name,
group_desc=group_name,
timeout=timedelta(seconds=60.0), # this timeout isn't used
)

@property
def group_name(self) -> str:
if self._group_name is None:
raise ValueError("ProcessGroup name not set")
return self._group_name

def _set_group_name(self, name: str) -> None:
self._group_name = name

def unregister(self) -> None:
"""
Unregisters the process group with the global registry.
Must be registered first.
"""
_unregister_process_group(self.group_name)
dist.destroy_process_group(self)


class ProcessGroupWrapper(ProcessGroup):
Expand Down
34 changes: 20 additions & 14 deletions torchft/process_group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,35 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest import TestCase, skipUnless
from concurrent.futures import ThreadPoolExecutor
import os
from concurrent.futures import ThreadPoolExecutor
from unittest import skipUnless, TestCase

import torch
from torch.distributed import TCPStore, ReduceOp
import torch.distributed as dist
from torch import nn
from torch._C._distributed_c10d import (
_resolve_process_group,
)
from torch.distributed import _functional_collectives
from torch._C._distributed_c10d import _resolve_process_group
from torch.distributed import _functional_collectives, ReduceOp, TCPStore
from torch.distributed.device_mesh import init_device_mesh

from torchft.process_group import (
extend_device_mesh,
ProcessGroup,
ProcessGroupBabyGloo,
ProcessGroupBabyNCCL,
ProcessGroupDummy,
ProcessGroupGloo,
ProcessGroupNCCL,
ProcessGroupDummy,
ProcessGroup,
extend_device_mesh,
)


def dummy_init_pg() -> None:
if not dist.is_initialized():
dist.init_process_group(
backend="gloo", rank=0, world_size=1, store=dist.HashStore()
)


class ProcessGroupTest(TestCase):
def test_gloo(self) -> None:
store = TCPStore(
Expand Down Expand Up @@ -168,18 +172,20 @@ def test_device_mesh(self) -> None:
mesh_2d = extend_device_mesh(mesh_1d, pg)
assert mesh_2d.ndim == 2

pg.unregister()

def test_functional_collectives(self) -> None:
dummy_init_pg()

store = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
store_addr = f"localhost:{store.port}/prefix"

pg = ProcessGroupGloo()
pg = ProcessGroupGloo().register("test_func_col")
pg.configure(store_addr, 0, 1)

pg.register("test_func_col")

self.assertEqual(pg.group_name, "torchft-gloo:test_func_col")
self.assertEqual(pg.group_name, str(dist.get_pg_count() - 1))

self.assertIs(_resolve_process_group(pg.group_name), pg)

Expand Down

0 comments on commit ee864cf

Please sign in to comment.