process_group: add PG timeouts + automatically assign manager port
d4l3k committed Jan 9, 2025
1 parent 2ae42a0 commit d6f256c
Showing 3 changed files with 125 additions and 40 deletions.
22 changes: 18 additions & 4 deletions torchft/manager.py
@@ -46,7 +46,7 @@
from torchft.process_group import ProcessGroup

MANAGER_ADDR_KEY: str = "manager_addr"
MANAGER_DEFAULT_PORT: int = int(os.environ.get("TORCHFT_MANAGER_PORT", 29511))
MANAGER_PORT_ENV: str = "TORCHFT_MANAGER_PORT"
REPLICA_ID_KEY: str = "replica_id"

T = TypeVar("T")
@@ -74,6 +74,12 @@ class Manager:
"""
Manager manages the full fault tolerant training loop.
This requires that the TCPStore specified by the store_addr and
store_port arguments, or the MASTER_ADDR and MASTER_PORT environment
variables, be started prior to creating this manager. If using a modern
version of torchelastic this will already be the case. Otherwise, it
should be started via torch.distributed.init_process_group prior to
creating this manager.
NOTE: when saving periodic checkpoints you must save and restore the
Manager's state_dict as well to avoid synchronization issues.
"""
@@ -84,7 +90,6 @@ def __init__(
load_state_dict: Callable[[T], None],
state_dict: Callable[[], T],
min_replica_size: int,
port: int = MANAGER_DEFAULT_PORT,
use_async_quorum: bool = True,
timeout: timedelta = timedelta(seconds=60),
rank: Optional[int] = None,
@@ -94,13 +99,18 @@
store_port: Optional[int] = None,
lighthouse_addr: Optional[str] = None,
replica_id: Optional[str] = None,
port: Optional[int] = None,
) -> None:
"""
Args:
load_state_dict: function to load the state dict when recovering
state_dict: function to save the state dict when recovering
min_replica_size: minimum number of replicas on each step
port: if rank==0, the port to run the manager server on
port: if rank==0, the port to run the manager server on.
Port assignment priority:
1. this argument
2. TORCHFT_MANAGER_PORT env var
3. an arbitrary free port, assigned by binding to port 0
use_async_quorum: whether to run the quorum asynchronously during the forward pass
timeout: timeout for all operations
rank: the replica group local rank
@@ -147,6 +157,10 @@ def _manager_state_dict() -> Dict[str, T]:

if rank == 0:
hostname = socket.gethostname()

if port is None:
port = int(os.environ.get(MANAGER_PORT_ENV, 0))

addr = f"http://{hostname}:{port}"
bind = f"[::]:{port}"
lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]
@@ -163,7 +177,7 @@ def _manager_state_dict() -> Dict[str, T]:
world_size=world_size,
)

self._store.set(MANAGER_ADDR_KEY, addr)
self._store.set(MANAGER_ADDR_KEY, self._manager.address())
self._store.set(REPLICA_ID_KEY, replica_id)

addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8")
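The net effect of the manager.py changes is a three-step port resolution, plus advertising the server's actual bound address via self._manager.address() rather than the precomputed addr, since port 0 yields an OS-assigned port. A hedged sketch of the resolution order; the constant name matches the diff, but the helper function itself is illustrative and not torchft API:

import os
from typing import Optional

MANAGER_PORT_ENV: str = "TORCHFT_MANAGER_PORT"

def resolve_manager_port(port: Optional[int] = None) -> int:
    # 1. an explicit ``port`` argument wins
    if port is not None:
        return port
    # 2. else the TORCHFT_MANAGER_PORT env var
    # 3. else 0, so the server binds an arbitrary free port
    return int(os.environ.get(MANAGER_PORT_ENV, 0))

assert resolve_manager_port(29511) == 29511
os.environ[MANAGER_PORT_ENV] = "29512"
assert resolve_manager_port() == 29512
del os.environ[MANAGER_PORT_ENV]
assert resolve_manager_port() == 0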
94 changes: 69 additions & 25 deletions torchft/process_group.py
@@ -17,10 +17,11 @@
"""

import logging
import queue
import threading
from abc import ABC
from datetime import timedelta
from typing import TYPE_CHECKING, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union

import torch
import torch.distributed as dist
@@ -53,8 +54,23 @@
_FUTURE_EXCEPTION = "fut_exception"


def _get(queue: mp.Queue, timeout: float) -> object:
v = queue.get(timeout=timeout)
def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
"""
Gets an item from a queue with a timeout. If the timeout is exceeded then
a TimeoutError is raised.
If an exception is returned from the queue then it is raised.
Args:
q: queue to get from
timeout: timeout in seconds, or as a timedelta
"""
if isinstance(timeout, timedelta):
timeout = timeout.total_seconds()
try:
v = q.get(timeout=timeout)
except queue.Empty as e:
raise TimeoutError(f"queue.get() timed out after {timeout} seconds") from e
if isinstance(v, Exception):
raise v
return v
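A quick sketch of the three behaviors of the reworked _get helper; the queue contents here are illustrative:

import multiprocessing as mp
from datetime import timedelta

q = mp.get_context("spawn").Queue()

q.put(None)
assert _get(q, timedelta(seconds=5)) is None  # success sentinel passes through

q.put(RuntimeError("worker init failed"))
try:
    _get(q, 5.0)  # plain float seconds are still accepted
except RuntimeError:
    pass  # the queued exception is re-raised in the caller

try:
    _get(q, timedelta(seconds=0.01))  # empty queue: queue.Empty becomes TimeoutError
except TimeoutError as e:
    print(e)  # queue.get() timed out after 0.01 seconds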
@@ -95,6 +111,9 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
Every time this is called it must be provided with a unique prefixed
store address, e.g. localhost:1234/my/prefix/1
This function will block until the underlying ProcessGroup is created.
If an error occurs this will throw.
Args:
store_addr: address of the store to use
rank: rank of this process
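One simple way to satisfy the unique-prefix requirement is to append a per-configure counter to the base store address, as sketched here; the prefix scheme is hypothetical, and the tests below get away with a fixed "/prefix" because they configure only once:

configure_count = 0

def next_store_addr(host: str, port: int) -> str:
    # Illustrative: a fresh prefix per reconfiguration keeps stale keys from
    # colliding with the new ProcessGroup's rendezvous keys.
    global configure_count
    configure_count += 1
    return f"{host}:{port}/pg/{configure_count}"

# e.g. pg.configure(next_store_addr("localhost", store.port), rank, world_size)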
@@ -187,7 +206,6 @@ def __repr__(self) -> str:


class ProcessGroupWrapper(ProcessGroup):
PG_CLASS: Type[BaseProcessGroup] # pyre-fixme[13]: never initialized
"""
This is a wrapper around any ProcessGroup with a reconfiguration method.
"""
@@ -209,9 +227,10 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

store = create_store_client(store_addr)

# TODO: set global timeout
# pyre-fixme[20]: expects argument options
self._pg = self.PG_CLASS(store, rank, world_size)
self._pg = self._create_pg(store, rank, world_size)

def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
raise NotImplementedError("not implemented")

def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
return self.parent.allreduce(tensors, opts)
@@ -244,9 +263,13 @@ class ProcessGroupGloo(ProcessGroupWrapper):
This is a reconfigurable version of ProcessGroupGloo.
"""

PG_CLASS: Type[BaseProcessGroup] = (
BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
)
def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
super().__init__()
self._timeout = timeout

def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
# pyre-fixme[16]: no attribute ProcessGroupGloo
return BaseProcessGroupGloo(store, rank, world_size, self._timeout)

def getBackendName(self) -> str:
return "torchft-gloo"
@@ -263,9 +286,9 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
abort when reconfiguring, we need to ensure this is safe.
"""

PG_CLASS: Type[BaseProcessGroup] = (
BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
)
def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
# pyre-fixme[16]: no attribute ProcessGroupNCCL
return BaseProcessGroupNCCL(store, rank, world_size)

def getBackendName(self) -> str:
return "torchft-nccl"
@@ -546,10 +569,9 @@ class ProcessGroupBaby(ProcessGroup):
"""

PG_CLASS: Type[BaseProcessGroup] # pyre-fixme[13]: never initialized
WORK_CLASS: Type[_BabyWork] = _BabyWork

def __init__(self, timeout: float = 60.0) -> None:
def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
super().__init__(0, 1)

self._world_size = -1
@@ -562,7 +584,10 @@ def __init__(self, timeout: float = 60.0) -> None:
self._futures: Dict[int, Future[object]] = {}
self._futures_lock = threading.Lock()

self._timeout = timeout
if isinstance(timeout, timedelta):
timeout = timeout.total_seconds()

self._timeout: float = timeout

def configure(self, store_addr: str, rank: int, world_size: int) -> None:
if self._p is not None:
@@ -581,7 +606,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

ctx = mp.get_context("spawn")
self._tx = ctx.Queue()
self._rx = ctx.Queue()
self._rx = rx = ctx.Queue()

# futures need thread to fire callbacks
self._future_queue = ctx.Queue()
@@ -602,6 +627,17 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
)
self._p.start()

# fetch the status of the PG init
# if an exception was returned _get will throw
assert _get(rx, self._timeout) is None

@classmethod
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
"""
This is a class method so the spawned worker does not need to pickle a ProcessGroup instance.
"""
raise NotImplementedError("not implemented")

@classmethod
def _worker(
cls,
@@ -615,8 +651,13 @@ def _worker(
try:
store = create_store_client(store_addr)

# pyre-fixme[20]: expects argument options
pg = cls.PG_CLASS(store, rank, world_size)
try:
pg = cls._create_pg(store, rank, world_size)
except Exception as e:
logger.exception(f"got exception in worker: {e}")
tx.put(e)
return
tx.put(None)

work = {}
next_op_id: int = 0
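The added lines implement a common spawn-handshake pattern: the child pushes either the init exception or a None success sentinel onto a queue, and the parent's configure blocks on that status. A stripped-down sketch of the pattern outside torchft:

import multiprocessing as mp

def worker(tx) -> None:
    try:
        resource = object()  # stand-in for cls._create_pg(store, rank, world_size)
    except Exception as e:
        tx.put(e)  # the parent re-raises this via _get
        return
    tx.put(None)  # success sentinel: unblocks the parent's configure()
    # ... serve ops from the command queue ...

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    rx = ctx.Queue()
    ctx.Process(target=worker, args=(rx,), daemon=True).start()
    status = rx.get(timeout=60.0)  # the real code wraps this in _get()
    assert status is None, f"worker init failed: {status}"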
@@ -737,9 +778,10 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):
ProcessGroupBabyNCCL.
"""

PG_CLASS: Type[BaseProcessGroup] = (
BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
)
@classmethod
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
# pyre-fixme[16]: no attribute ProcessGroupGloo
return BaseProcessGroupGloo(store, rank, world_size)

def getBackendName(self) -> str:
return "torchft-baby-gloo"
@@ -761,11 +803,13 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
tensors may leak in the current PyTorch implementation. TODO fix
"""

PG_CLASS: Type[BaseProcessGroup] = (
BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
)
WORK_CLASS = _BabyWorkNCCL

@classmethod
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
# pyre-fixme[16]: no attribute ProcessGroupNCCL
return BaseProcessGroupNCCL(store, rank, world_size)

def getBackendName(self) -> str:
return "torchft-baby-nccl"

49 changes: 38 additions & 11 deletions torchft/process_group_test.py
@@ -6,6 +6,7 @@

import os
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Any, Dict, Tuple
from unittest import TestCase, skipUnless
from unittest.mock import Mock
@@ -122,6 +123,16 @@ def test_gloo(self) -> None:
m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
m(torch.rand(2, 3))

def test_gloo_timeout(self) -> None:
store = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

store_addr = f"localhost:{store.port}/prefix"
pg = ProcessGroupGloo(timeout=timedelta(seconds=0.01))
with self.assertRaisesRegex(RuntimeError, "timeout after 10ms"):
pg.configure(store_addr, 0, 2)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
@skipUnless(torch.cuda.is_available(), "needs CUDA")
def test_nccl(self) -> None:
@@ -155,28 +166,44 @@ def test_baby_gloo(self) -> None:
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

store_addr = f"localhost:{store.port}/prefix"
store_addr: str = f"localhost:{store.port}/prefix"

def run(rank: int) -> Tuple[torch.Tensor, Work]:
a = ProcessGroupBabyGloo()
a.configure(store_addr, rank, 2)

a = ProcessGroupBabyGloo()
b = ProcessGroupBabyGloo()
self.assertEqual(a.size(), 2)

a.configure(store_addr, 0, 2)
b.configure(store_addr, 1, 2)
at = torch.tensor([rank + 1])

self.assertEqual(a.size(), 2)
a_work = a.allreduce([at], ReduceOp.SUM)
return at, a_work

at = torch.tensor([1])
bt = torch.tensor([2])
with ThreadPoolExecutor(max_workers=2) as executor:
a_fut = executor.submit(run, 0)
b_fut = executor.submit(run, 1)

a_work = a.allreduce([at], ReduceOp.SUM)
b_work = b.allreduce([bt], ReduceOp.SUM)
at, a_work = a_fut.result()
bt, b_work = b_fut.result()

a_work.wait()
fut = b_work.get_future()

fut.wait()

torch.testing.assert_close(at, bt)
torch.testing.assert_close(at, torch.tensor([3]))
torch.testing.assert_close(bt, torch.tensor([3]))

def test_baby_gloo_timeout(self) -> None:
store = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

store_addr = f"localhost:{store.port}/prefix"

a = ProcessGroupBabyGloo(timeout=timedelta(seconds=0.01))
with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"):
a.configure(store_addr, 0, 2)

def test_dummy(self) -> None:
pg = ProcessGroupDummy(0, 1)