Skip to content

Commit

Permalink
fix: sharding_manager
Browse files Browse the repository at this point in the history
  • Loading branch information
piercus committed Jan 16, 2024
1 parent 12cc6af commit 2d6476f
Showing 1 changed file with 0 additions and 8 deletions.
8 changes: 0 additions & 8 deletions src/refiners/training_utils/sharding_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from functools import partial, update_wrapper
from typing import Any, Callable, Dict, List

from torch import Tensor, device as Device
Expand Down Expand Up @@ -89,14 +88,7 @@ def recursive_to(self, obj: Any, device: Device) -> Any:

def add_device_hook(self, module: Hookable, device: Device, method_name: str) -> None:
old_method = getattr(module, method_name)

new_method = self.wrap_device(old_method, device)
# new_method = update_wrapper(partial(new_method, module), old_method)

new_method = self.wrap_device(old_method, device)

new_method = update_wrapper(partial(new_method, module), old_method)

setattr(module, method_name, new_method)

def wrap_device(self, method: WrappableMethod, device: Device) -> WrappableMethod:
Expand Down

0 comments on commit 2d6476f

Please sign in to comment.