Commit

don't shard norm weights
jspark1105 committed Sep 22, 2023
1 parent 0224797 commit db6a1c7
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -775,7 +775,7 @@ def _shard_parameters_(self) -> None:
                 assert p.dtype == torch.float32
 
             # If world_size is 1, then we all-reduce grads instead of sharding.
-            p._is_sharded = self.world_size > 1
+            p._is_sharded = self.world_size > 1 and isinstance(p, FlatParameter)
             p._orig_size = p.data.size()
 
             if not p._is_sharded:
@@ -1159,7 +1159,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
                     non_shared_params
                 ), f"{len(full_tensors)} vs. {len(non_shared_params)}"
                 for p, (full_tensor, safe_to_free) in zip(non_shared_params, full_tensors):
-                    if not volatile:
+                    if not volatile and p._is_sharded:
                         # Copy any changes made to the full params back into
                         # the corresponding local shards.
                         local_shard, _ = self._get_shard(full_tensor)
@@ -1773,7 +1773,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
             # Currently the only way for _is_sharded to be False is if
             # world_size == 1. This could be relaxed in the future, in which
             # case grads should be all-reduced here.
-            assert self.world_size == 1
+            # assert self.world_size == 1
             self._post_reduction_hook(param, param.grad.data)
 
         # After _post_backward_hook returns, orig_grad_data will eventually
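Taken together, the three hunks let a wrapped module keep some parameters unsharded even when world_size > 1: only FlatParameter instances are sharded, so plain Parameters (per the commit message, norm weights kept outside the flattened buffer) stay whole on every rank, summon_full_params skips the copy-back for them, and _post_backward_hook can now see _is_sharded == False with more than one rank, hence the dropped assert. Below is a minimal, standalone sketch of that decision and the gradient fallback it implies; the FlatParameter check, _is_sharded, and _get_shard come from the diff, while the helper names and the explicit all-reduce are illustrative assumptions rather than fairscale's exact internals.

import torch
import torch.distributed as dist
from torch.nn import Parameter


class FlatParameter(Parameter):
    # Stand-in for fairscale's FlatParameter, only so the isinstance()
    # check from the diff is concrete in this self-contained sketch.
    pass


def should_shard(p: Parameter, world_size: int) -> bool:
    # Mirrors the changed line in _shard_parameters_: shard only when there
    # is more than one rank AND the parameter is a flattened one. A norm
    # weight kept as a plain Parameter therefore stays whole on every rank.
    return world_size > 1 and isinstance(p, FlatParameter)


def reduce_unsharded_grad(p: Parameter, world_size: int) -> None:
    # Gradient fallback implied by the commented-out assert in
    # _post_backward_hook: an unsharded parameter's gradient is all-reduced
    # across ranks instead of reduce-scattered. Averaging/scaling details
    # are omitted here; fairscale handles them via its reduction hooks.
    if p.grad is not None and world_size > 1 and dist.is_initialized():
        dist.all_reduce(p.grad.data)

For example, should_shard(layer_norm.weight, world_size=8) returns False for a LayerNorm weight left as a plain Parameter, so that weight is replicated and its gradient all-reduced, while the flattened bulk of the model is still sharded as before.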
