Fix non-contiguous tensors in consolidation (#736)
* Fix non-contiguous tensor issue in checkpoint consolidation

* Update PR to handle more edge cases where the tensor may not be contiguous after being placed on CPU

---------

Co-authored-by: gioannides <[email protected]>
dacorvo and gioannides authored Nov 14, 2024
1 parent 02c331d commit 2897a08
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions optimum/neuron/distributed/checkpointing.py
@@ -145,15 +145,21 @@ def consolidate_tensor_parallel_checkpoints(
         # This might not be the case anymore when `ParameterMetadata` uses slices.
         sharded_metadata = sharded_metadatas[name]
         if sharded_metadata.is_tied:
-            consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu")
+            consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu").contiguous()
         else:
-            weights = [state_dict[name] for state_dict in state_dicts]
+            # Ensure that all tensors are contiguous before concatenating or further processing
+            weights = [state_dict[name].contiguous() for state_dict in state_dicts]
             tp_size = len(weights)
-            full_weight = torch.cat(
-                weights,
-                dim=sharded_metadata.partition_dim,
-            )
-            full_weight = full_weight.to("cpu")
+
+            full_weight = (
+                torch.cat(
+                    weights,
+                    dim=sharded_metadata.partition_dim,
+                )
+                .to("cpu")
+                .contiguous()
+            )  # Ensure the result is also contiguous
+
             if weight_name in ["weight_k", "weight_v", "bias_k", "bias_v"]:
                 full_weight = (
                     torch.chunk(full_weight, gqa_qkv_metadata["kv_size_multiplier"], dim=0)[0].detach().clone()
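For context, a minimal standalone sketch (not part of the commit) of the failure mode this change addresses: tensor-parallel shards are often views, such as transposed or sliced tensors, whose memory layout is not contiguous, and serializers such as safetensors refuse to write non-contiguous tensors. The shapes and variable names below are illustrative only.

import torch

# Illustrative only: a transposed shard is a view over the original storage,
# so its memory layout is no longer contiguous.
shard = torch.arange(12.0).reshape(3, 4).t()      # shape (4, 3)
print(shard.is_contiguous())                      # False

# torch.cat accepts non-contiguous inputs, but saving the consolidated tensor
# (e.g. with safetensors) requires contiguous storage, so each shard and the
# concatenated result are normalized with .contiguous(), mirroring the fix.
shards = [shard, shard * 2]
full_weight = (
    torch.cat([s.contiguous() for s in shards], dim=0)
    .to("cpu")
    .contiguous()
)
print(full_weight.shape)                          # torch.Size([8, 3])
print(full_weight.is_contiguous())                # True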
