Skip to content

Commit

Permalink
new parseHeader flag when opening file
Browse files Browse the repository at this point in the history
  • Loading branch information
by321 committed Dec 13, 2023
1 parent 27e49df commit 791ad0b
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 17 deletions.
13 changes: 6 additions & 7 deletions safetensors_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ def parse_object_pairs(pairs):
raise SafeTensorsException.invalid_file(self.filename,"duplicate keys in header")

@staticmethod
def open_file(filename:str,quiet=False):
def open_file(filename:str,quiet=False,parseHeader=True):
s=SafeTensorsFile()
s.open(filename,quiet)
s.open(filename,quiet,parseHeader)
return s

def open(self,fn:str,quiet=False)->int:
def open(self,fn:str,quiet=False,parseHeader=True)->int:
st=os.stat(fn)
if st.st_size<8: #test file: zero_len_file.safetensors
raise SafeTensorsException.invalid_file(fn,"length less than 8 bytes")
Expand All @@ -93,13 +93,12 @@ def open(self,fn:str,quiet=False)->int:
self.hdrbuf=hdrbuf
self.error=0
self.headerlen=headerlen
if parseHeader==True:
self._CheckDuplicateHeaderKeys()
self.header=json.loads(self.hdrbuf)
return 0

def get_header(self):
if self.header is not None: return self.header
if self.hdrbuf is None: raise Exception("SafetensorsFile no header buffer")
self._CheckDuplicateHeaderKeys()
self.header=json.loads(self.hdrbuf)
return self.header

def load_one_tensor(self,tensor_name:str):
Expand Down
2 changes: 1 addition & 1 deletion safetensors_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
help="Quiet mode, don't print informational stuff" )

@click.group()
@click.version_option(version=6)
@click.version_option(version=7)
@quiet_flag

@click.pass_context
Expand Down
11 changes: 2 additions & 9 deletions safetensors_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def WriteMetadataToHeader(cmdLine:dict,in_st_file:str,in_json_file:str,output_fi

s=SafeTensorsFile.open_file(in_st_file)
js=s.get_header()
if js is None: return -1

if inmeta==[]:
js.pop("__metadata__",0)
Expand Down Expand Up @@ -57,8 +56,6 @@ def WriteMetadataToHeader(cmdLine:dict,in_st_file:str,in_json_file:str,output_fi
def PrintHeader(cmdLine:dict,input_file:str) -> int:
s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
js=s.get_header()
if js is None: return -1


# All the .safetensors files I've seen have long key names, and as a result,
# neither json nor pprint package prints text in very readable format,
Expand Down Expand Up @@ -108,7 +105,6 @@ def _ParseMore(d:dict):
def PrintMetadata(cmdLine:dict,input_file:str) -> int:
with SafeTensorsFile.open_file(input_file,cmdLine['quiet']) as s:
js=s.get_header()
if js is None: return -1

if not "__metadata__" in js:
print("file header does not contain a __metadata__ item",file=sys.stderr)
Expand All @@ -123,7 +119,6 @@ def PrintMetadata(cmdLine:dict,input_file:str) -> int:
def HeaderKeysToLists(cmdLine:dict,input_file:str) -> int:
s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
js=s.get_header()
if js is None: return -1

_lora_keys:list[tuple(str,bool)]=[] # use list to sort by name
for key in js:
Expand Down Expand Up @@ -156,7 +151,7 @@ def printkeylist(kl):
def ExtractHeader(cmdLine:dict,input_file:str,output_file:str)->int:
if _need_force_overwrite(output_file,cmdLine): return -1

s=SafeTensorsFile.open_file(input_file)
s=SafeTensorsFile.open_file(input_file,parseHeader=False)
if s.error!=0: return s.error

hdrbuf=s.hdrbuf
Expand All @@ -171,10 +166,8 @@ def ExtractHeader(cmdLine:dict,input_file:str,output_file:str)->int:


def _CheckLoRA_internal(s:SafeTensorsFile)->int:
js=s.get_header()
if js is None: return -1

import lora_keys
js=s.get_header()
set_scalar=set()
set_nonscalar=set()
for x in lora_keys._lora_keys:
Expand Down

0 comments on commit 791ad0b

Please sign in to comment.