Commit

correct saving shared weights
TOKENIZERS_PARALLELISM=false while finetuning
JegernOUTT committed Nov 1, 2023
1 parent 40c81d7 commit a9dd691
Showing 1 changed file with 20 additions and 0 deletions.
self_hosting_machinery/finetune/scripts/script_aux/model.py (20 additions, 0 deletions)
@@ -1,6 +1,8 @@
 import importlib
 import json
 import logging
+import os
+from collections import defaultdict
 from functools import partial
 from pathlib import Path
 from typing import Dict, Any, List, Tuple
@@ -31,6 +33,17 @@ def _lora_state_dict(model, *args, destination=None, prefix='', keep_vars=False,
     }
 
 
+def _shared_pointers(tensors):
+    ptrs = defaultdict(list)
+    for k, v in tensors.items():
+        ptrs[v.data_ptr()].append(k)
+    failing = []
+    for ptr, names in ptrs.items():
+        if len(names) > 1:
+            failing.append(names)
+    return failing
+
+
 class ModelContext:
     def __init__(
             self,
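
The _shared_pointers helper groups state-dict keys that alias the same underlying storage, which is what happens when a model ties weights (for example, an output head reusing the embedding matrix). A minimal sketch of the idea on a hypothetical toy module, rather than the real checkpoint dict:

import torch
from collections import defaultdict

# Hypothetical tied weights: the head reuses the embedding matrix, so both
# state-dict entries point at the same storage.
emb = torch.nn.Embedding(10, 4)
head = torch.nn.Linear(4, 10, bias=False)
head.weight = emb.weight

state = {"emb.weight": emb.weight.data, "head.weight": head.weight.data}

ptrs = defaultdict(list)
for name, tensor in state.items():
    ptrs[tensor.data_ptr()].append(name)

shared = [names for names in ptrs.values() if len(names) > 1]
print(shared)  # [['emb.weight', 'head.weight']] -- one storage, two names
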
@@ -207,7 +220,12 @@ def save_model_state(
             _ = [p.unlink() for p in cp_path.iterdir() if 'model_states' not in p.name]
         for cp_path in model_cps:
             cp = torch.load(str(cp_path), map_location='cpu')
+            shared = _shared_pointers(cp["module"])
+            for shared_weights in shared:
+                for name in shared_weights[1:]:
+                    cp["module"].pop(name)
             tensors = {k: v.contiguous() for k, v in cp["module"].items()}
+
             meta: Dict[str, str] = {
                 "skipped_steps": str(cp["skipped_steps"]),
                 "global_steps": str(cp["global_steps"]),
@@ -236,6 +254,8 @@ def _setup_encoding(
             weights_path: str,
             repo_id: str
     ) -> AutoTokenizer:
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
         assert "tokenizer" in self.model_mappings_config
         encoding = AutoTokenizer.from_pretrained(
             repo_id, cache_dir=weights_path,
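
Setting TOKENIZERS_PARALLELISM=false before the tokenizer is built silences the Rust tokenizers backend's "The current process just got forked, after parallelism has already been used" warning and avoids the associated deadlock risk when DataLoader workers fork during finetuning. A short sketch of the ordering; the model id and cache directory are placeholders:

import os

# Read by the Rust `tokenizers` backend when it decides whether to spawn its
# own thread pool, so it must be set before the tokenizer is used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("some-org/some-model", cache_dir="/tmp/weights")
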
