diff --git a/safetensors_file.py b/safetensors_file.py index 9d07325..b584def 100644 --- a/safetensors_file.py +++ b/safetensors_file.py @@ -24,20 +24,18 @@ def __init__(self,name:str,dtype:str,shape:list[int],offset0:int,offset1:int): class SafeTensorsFile: def __init__(self): self.f=None #file handle - self.hdrbuf=None #header byet buffer + self.hdrbuf=None #header byte buffer self.header=None #parsed header as a dict self.error=0 def __del__(self): self.close_file() - + def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): - if self.f is not None: - self.f.close() - + self.close_file() def close_file(self): if self.f is not None: @@ -71,7 +69,7 @@ def open_file(filename:str,quiet=False): def open(self,fn:str,quiet=False)->int: st=os.stat(fn) - if st.st_size<8: + if st.st_size<8: #test file: zero_len_file.safetensors raise SafeTensorsException.invalid_file(fn,"length less than 8 bytes") f=open(fn,"rb") @@ -79,9 +77,8 @@ def open(self,fn:str,quiet=False)->int: if len(b8)!=8: raise SafeTensorsException.invalid_file(fn,f"read only {len(b8)} bytes at start of file") headerlen=int.from_bytes(b8,'little',signed=False) - if headerlen==0: - raise SafeTensorsException.invalid_file(fn,"header size is 0") - if (8+headerlen>st.st_size): + + if (8+headerlen>st.st_size): #test file: header_size_too_big.safetensors raise SafeTensorsException.invalid_file(fn,"header extends past end of file") if quiet==False: diff --git a/safetensors_worker.py b/safetensors_worker.py index 75e9ec5..82c3a66 100644 --- a/safetensors_worker.py +++ b/safetensors_worker.py @@ -78,7 +78,7 @@ def PrintHeader(cmdLine:dict,input_file:str) -> int: return 0 def _ParseMore(d:dict): - '''Basically when printing, try to turn this: + '''Basically try to turn this: "ss_dataset_dirs":"{\"abc\": {\"n_repeats\": 2, \"img_count\": 60}}", @@ -106,18 +106,18 @@ def _ParseMore(d:dict): _ParseMore(value) def PrintMetadata(cmdLine:dict,input_file:str) -> int: - s=SafeTensorsFile.open_file(input_file,cmdLine['quiet']) - js=s.get_header() - if js is None: return -1 - - if not "__metadata__" in js: - print("file header does not contain a __metadata__ item",file=sys.stderr) - return -2 - - md=js["__metadata__"] - if cmdLine['parse_more']: - _ParseMore(md) - json.dump({"__metadata__":md},fp=sys.stdout,ensure_ascii=False,separators=(',',':'),indent=1) + with SafeTensorsFile.open_file(input_file,cmdLine['quiet']) as s: + js=s.get_header() + if js is None: return -1 + + if not "__metadata__" in js: + print("file header does not contain a __metadata__ item",file=sys.stderr) + return -2 + + md=js["__metadata__"] + if cmdLine['parse_more']: + _ParseMore(md) + json.dump({"__metadata__":md},fp=sys.stdout,ensure_ascii=False,separators=(',',':'),indent=1) return 0 def HeaderKeysToLists(cmdLine:dict,input_file:str) -> int: @@ -225,7 +225,7 @@ def _CheckLoRA_internal(s:SafeTensorsFile)->int: def CheckLoRA(cmdLine:dict,input_file:str)->int: s=SafeTensorsFile.open_file(input_file) i:int=_CheckLoRA_internal(s) - if i==0: print("looks like an OK LoRA file") + if i==0: print("looks like an OK SD 1.x LoRA file") return 0 def ExtractData(cmdLine:dict,input_file:str,key_name:str,output_file:str)->int: