diff --git a/safetensors_file.py b/safetensors_file.py index 782a1eb..34e7b58 100644 --- a/safetensors_file.py +++ b/safetensors_file.py @@ -63,12 +63,12 @@ def parse_object_pairs(pairs): raise SafeTensorsException.invalid_file(self.filename,"duplicate keys in header") @staticmethod - def open_file(filename:str,quiet=False): + def open_file(filename:str,quiet=False,parseHeader=True): s=SafeTensorsFile() - s.open(filename,quiet) + s.open(filename,quiet,parseHeader) return s - def open(self,fn:str,quiet=False)->int: + def open(self,fn:str,quiet=False,parseHeader=True)->int: st=os.stat(fn) if st.st_size<8: #test file: zero_len_file.safetensors raise SafeTensorsException.invalid_file(fn,"length less than 8 bytes") @@ -93,13 +93,12 @@ def open(self,fn:str,quiet=False)->int: self.hdrbuf=hdrbuf self.error=0 self.headerlen=headerlen + if parseHeader==True: + self._CheckDuplicateHeaderKeys() + self.header=json.loads(self.hdrbuf) return 0 def get_header(self): - if self.header is not None: return self.header - if self.hdrbuf is None: raise Exception("SafetensorsFile no header buffer") - self._CheckDuplicateHeaderKeys() - self.header=json.loads(self.hdrbuf) return self.header def load_one_tensor(self,tensor_name:str): diff --git a/safetensors_util.py b/safetensors_util.py index ffd7310..7d1ce69 100644 --- a/safetensors_util.py +++ b/safetensors_util.py @@ -17,7 +17,7 @@ help="Quiet mode, don't print informational stuff" ) @click.group() -@click.version_option(version=6) +@click.version_option(version=7) @quiet_flag @click.pass_context diff --git a/safetensors_worker.py b/safetensors_worker.py index 82c3a66..f0b0de6 100644 --- a/safetensors_worker.py +++ b/safetensors_worker.py @@ -22,7 +22,6 @@ def WriteMetadataToHeader(cmdLine:dict,in_st_file:str,in_json_file:str,output_fi s=SafeTensorsFile.open_file(in_st_file) js=s.get_header() - if js is None: return -1 if inmeta==[]: js.pop("__metadata__",0) @@ -57,8 +56,6 @@ def WriteMetadataToHeader(cmdLine:dict,in_st_file:str,in_json_file:str,output_fi def PrintHeader(cmdLine:dict,input_file:str) -> int: s=SafeTensorsFile.open_file(input_file,cmdLine['quiet']) js=s.get_header() - if js is None: return -1 - # All the .safetensors files I've seen have long key names, and as a result, # neither json nor pprint package prints text in very readable format, @@ -108,7 +105,6 @@ def _ParseMore(d:dict): def PrintMetadata(cmdLine:dict,input_file:str) -> int: with SafeTensorsFile.open_file(input_file,cmdLine['quiet']) as s: js=s.get_header() - if js is None: return -1 if not "__metadata__" in js: print("file header does not contain a __metadata__ item",file=sys.stderr) @@ -123,7 +119,6 @@ def PrintMetadata(cmdLine:dict,input_file:str) -> int: def HeaderKeysToLists(cmdLine:dict,input_file:str) -> int: s=SafeTensorsFile.open_file(input_file,cmdLine['quiet']) js=s.get_header() - if js is None: return -1 _lora_keys:list[tuple(str,bool)]=[] # use list to sort by name for key in js: @@ -156,7 +151,7 @@ def printkeylist(kl): def ExtractHeader(cmdLine:dict,input_file:str,output_file:str)->int: if _need_force_overwrite(output_file,cmdLine): return -1 - s=SafeTensorsFile.open_file(input_file) + s=SafeTensorsFile.open_file(input_file,parseHeader=False) if s.error!=0: return s.error hdrbuf=s.hdrbuf @@ -171,10 +166,8 @@ def ExtractHeader(cmdLine:dict,input_file:str,output_file:str)->int: def _CheckLoRA_internal(s:SafeTensorsFile)->int: - js=s.get_header() - if js is None: return -1 - import lora_keys + js=s.get_header() set_scalar=set() set_nonscalar=set() for x in lora_keys._lora_keys: