Skip to content

Commit

Permalink
minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
by321 committed Dec 11, 2023
1 parent 57abf4e commit 96ddd58
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 23 deletions.
15 changes: 6 additions & 9 deletions safetensors_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,18 @@ def __init__(self,name:str,dtype:str,shape:list[int],offset0:int,offset1:int):
class SafeTensorsFile:
def __init__(self):
self.f=None #file handle
self.hdrbuf=None #header byet buffer
self.hdrbuf=None #header byte buffer
self.header=None #parsed header as a dict
self.error=0

def __del__(self):
self.close_file()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
if self.f is not None:
self.f.close()

self.close_file()

def close_file(self):
if self.f is not None:
Expand Down Expand Up @@ -71,17 +69,16 @@ def open_file(filename:str,quiet=False):

def open(self,fn:str,quiet=False)->int:
st=os.stat(fn)
if st.st_size<8:
if st.st_size<8: #test file: zero_len_file.safetensors
raise SafeTensorsException.invalid_file(fn,"length less than 8 bytes")

f=open(fn,"rb")
b8=f.read(8) #read header size
if len(b8)!=8:
raise SafeTensorsException.invalid_file(fn,f"read only {len(b8)} bytes at start of file")
headerlen=int.from_bytes(b8,'little',signed=False)
if headerlen==0:
raise SafeTensorsException.invalid_file(fn,"header size is 0")
if (8+headerlen>st.st_size):

if (8+headerlen>st.st_size): #test file: header_size_too_big.safetensors
raise SafeTensorsException.invalid_file(fn,"header extends past end of file")

if quiet==False:
Expand Down
28 changes: 14 additions & 14 deletions safetensors_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def PrintHeader(cmdLine:dict,input_file:str) -> int:
return 0

def _ParseMore(d:dict):
'''Basically when printing, try to turn this:
'''Basically try to turn this:
"ss_dataset_dirs":"{\"abc\": {\"n_repeats\": 2, \"img_count\": 60}}",
Expand Down Expand Up @@ -106,18 +106,18 @@ def _ParseMore(d:dict):
_ParseMore(value)

def PrintMetadata(cmdLine:dict,input_file:str) -> int:
s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
js=s.get_header()
if js is None: return -1

if not "__metadata__" in js:
print("file header does not contain a __metadata__ item",file=sys.stderr)
return -2

md=js["__metadata__"]
if cmdLine['parse_more']:
_ParseMore(md)
json.dump({"__metadata__":md},fp=sys.stdout,ensure_ascii=False,separators=(',',':'),indent=1)
with SafeTensorsFile.open_file(input_file,cmdLine['quiet']) as s:
js=s.get_header()
if js is None: return -1

if not "__metadata__" in js:
print("file header does not contain a __metadata__ item",file=sys.stderr)
return -2

md=js["__metadata__"]
if cmdLine['parse_more']:
_ParseMore(md)
json.dump({"__metadata__":md},fp=sys.stdout,ensure_ascii=False,separators=(',',':'),indent=1)
return 0

def HeaderKeysToLists(cmdLine:dict,input_file:str) -> int:
Expand Down Expand Up @@ -225,7 +225,7 @@ def _CheckLoRA_internal(s:SafeTensorsFile)->int:
def CheckLoRA(cmdLine:dict,input_file:str)->int:
s=SafeTensorsFile.open_file(input_file)
i:int=_CheckLoRA_internal(s)
if i==0: print("looks like an OK LoRA file")
if i==0: print("looks like an OK SD 1.x LoRA file")
return 0

def ExtractData(cmdLine:dict,input_file:str,key_name:str,output_file:str)->int:
Expand Down

0 comments on commit 96ddd58

Please sign in to comment.