Skip to content

Commit

Permalink
added extractdata command
Browse files Browse the repository at this point in the history
  • Loading branch information
by321 committed Dec 11, 2023
1 parent f16739c commit 72bd8d1
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 53 deletions.
1 change: 1 addition & 0 deletions lora_keys.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# use list to keep insertion order
# SD 1.5 LoRA keys
_lora_keys:list[tuple[str,bool]]=[
('lora_te_text_model_encoder_layers_0_mlp_fc1.alpha', True),
('lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight', False),
Expand Down
13 changes: 7 additions & 6 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ This is a light-weight utility program for [safetensors files](https://github.co
--help Show this message and exit.

Commands:
checklora see if input file is a SD 1.x LoRA file
extracthdr extract file header and save to output file
header print file header
listkeys print header key names (except __metadata__) as a Python list
metadata print only __metadata__ in file header
writemd write metadata to safetensors file header
checklora see if input file is a SD 1.x LoRA file
extractdata extract one tensor and save to file
extracthdr extract file header and save to output file
header print file header
listkeys print header key names (except \_\_metadata\_\_) as a Python list
metadata print only \_\_metadata\_\_ in file header
writemd read \_\_metadata\_\_ from json and write to safetensors file


The most useful commands are probably the ones that read and write metadata. To read metadata:
Expand Down
53 changes: 6 additions & 47 deletions safetensors_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,60 +95,19 @@ def get_header(self):
if self.hdrbuf is None: raise Exception("SafetensorsFile no header buffer")
self._CheckDuplicateHeaderKeys()
self.header=json.loads(self.hdrbuf)

return self.header

def load_one_tensor(self,tensor_name:str):
self.get_header()
if tensor_name not in self.header: return None

t=self.header[tensor_name]
if t['dtype']=='F16':
dt=numpy.half
elif t['dtype']=='F32':
dt=numpy.single
#elif t['dtype']=='F64':
# dt=numpy.double
else:
msg=f"unsupported tensor data type in {self.filename}: {t['dtype']}"
raise SafeTensorsException(msg)
self.f.seek(8+self.headerlen+t['data_offsets'][0])
bytesToRead=t['data_offsets'][1]-t['data_offsets'][0]
bytes=self.f.read(bytesToRead)

n=1
for v in t['shape']: n=n*v
vals=numpy.frombuffer(bytes,dtype=dt,count=n,offset=0)
if dt!=numpy.single: vals=vals.astype(numpy.single)
print(self.header[tensor_name],dt,n,bytesToRead)
self.header[tensor_name]['values']=vals
"""a = np.array([1, 2, 3])
print(a)
print(a.dtype)
# [1 2 3]
# int64
a_float = a.astype(np.float32)"""

def load_data(self):
self.get_header()
#n:int=max([self.header[key].data_offsets[1] for key in self.header])
#self.f.seek(8+self.headerlen)
#databytes=self.f.read(n)
#if len(databytes)!=n:
# msg=f"error reading file {self.filename}, tried to read {n} bytes, only read {len(databytes)}"
# raise SafeTensorsException(msg)
d={}
d['__metadata__']=''
read_order:list[tuple(str,int)]=[]
for k,v in self.header.items():
if k=="__metadata__": #if metadata, just copy
d['__metadata__']=v
continue
read_order.append((k,v['data_offsets'][0]))
read_order.sort(key=lambda x:x[1])
for x in read_order:
self.load_one_tensor(x[0])
#print("asdfas",x[0])
#print(x[0],self.header[x[0]])
# 'dtype': 'F16', 'shape': [32, 768], 'data_offsets': [491526, 540678]
if len(bytes)!=bytesToRead:
print(f"{tensor_name}: length={bytesToRead}, only read {len(bytes)} bytes",file=sys.stderr)
return bytes

def copy_data_to_file(self,file_handle) -> int:

Expand Down
17 changes: 17 additions & 0 deletions safetensors_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ def cli(ctx,quiet:bool):
ctx.ensure_object(dict)
ctx.obj['quiet'] = quiet


@cli.command(name="header",short_help="print file header")
@readonly_input_file
@click.pass_context
def cmd_header(ctx, input_file: str) -> int:
    # CLI entry point for the "header" command: pass the shared options
    # (ctx.obj) and the input path to the worker, then terminate the process
    # with the worker's return value as the exit status.
    # Kept as a comment on purpose: click shows a docstring as help text.
    exit_code = safetensors_worker.PrintHeader(ctx.obj, input_file)
    sys.exit(exit_code)


@cli.command(name="metadata",short_help="print only __metadata__ in file header")
@readonly_input_file
@fix_ued_flag
Expand All @@ -41,12 +43,14 @@ def cmd_meta(ctx,input_file:str,parse_more:bool)->int:
ctx.obj['parse_more'] = parse_more
sys.exit( safetensors_worker.PrintMetadata(ctx.obj,input_file) )


@cli.command(name="listkeys",short_help="print header key names (except __metadata__) as a Python list")
@readonly_input_file
@click.pass_context
def cmd_keyspy(ctx, input_file: str) -> int:
    # CLI entry point for the "listkeys" command: delegate to the worker with
    # the shared options (ctx.obj) and the input path, then exit the process
    # with whatever status the worker returned.
    # Kept as a comment on purpose: click shows a docstring as help text.
    exit_code = safetensors_worker.HeaderKeysToLists(ctx.obj, input_file)
    sys.exit(exit_code)


@cli.command(name="writemd",short_help="read __metadata__ from json and write to safetensors file")
@click.argument("in_st_file", metavar='input_st_file',
type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True))
Expand All @@ -60,6 +64,7 @@ def cmd_writemd(ctx,in_st_file:str,in_json_file:str,output_file:str,force_overwr
ctx.obj['force_overwrite'] = force_overwrite
sys.exit( safetensors_worker.WriteMetadataToHeader(ctx.obj,in_st_file,in_json_file,output_file) )


@cli.command(name="extracthdr",short_help="extract file header and save to output file")
@readonly_input_file
@output_file
Expand All @@ -69,6 +74,18 @@ def cmd_extractheader(ctx,input_file:str,output_file:str,force_overwrite:bool) -
ctx.obj['force_overwrite'] = force_overwrite
sys.exit( safetensors_worker.ExtractHeader(ctx.obj,input_file,output_file) )


@cli.command(name="extractdata",short_help="extract one tensor and save to file")
@readonly_input_file
@click.argument("key_name", metavar='key_name',type=click.STRING)
@output_file
@force_overwrite_flag
@click.pass_context
def cmd_extractheader(ctx, input_file: str, key_name: str, output_file: str, force_overwrite: bool) -> int:
    # CLI entry point for the "extractdata" command: record the overwrite flag
    # in the shared options, let the worker extract the tensor named key_name
    # from input_file into output_file, and exit with the worker's status.
    # NOTE(review): the "extracthdr" command earlier in this file uses the same
    # function name, so this definition rebinds the module-level name. It looks
    # like a copy-paste leftover; consider renaming to cmd_extractdata.
    options = ctx.obj
    options['force_overwrite'] = force_overwrite
    exit_code = safetensors_worker.ExtractData(options, input_file, key_name, output_file)
    sys.exit(exit_code)


@cli.command(name="checklora",short_help="see if input file is a SD 1.x LoRA file")
@readonly_input_file
@click.pass_context
Expand Down
19 changes: 19 additions & 0 deletions safetensors_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,22 @@ def CheckLoRA(cmdLine:dict,input_file:str)->int:
if i==0: print("looks like an OK LoRA file")
return 0

def ExtractData(cmdLine:dict,input_file:str,key_name:str,output_file:str)->int:
    """Save the raw bytes of one tensor from a safetensors file to output_file.

    Args:
        cmdLine: command-line options; 'quiet' is read here, and the whole
            dict is passed on to _need_force_overwrite().
        input_file: path of the safetensors file to read.
        key_name: header key of the tensor to extract (case-sensitive).
        output_file: path of the file that receives the tensor bytes.

    Returns:
        0 on success; the error code from opening the input file if that
        failed; -1 if the _need_force_overwrite() check fails, the key is
        not in the header, or fewer bytes were written than were read.
    """
    if _need_force_overwrite(output_file,cmdLine): return -1

    s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    if s.error!=0: return s.error

    try:
        bindata=s.load_one_tensor(key_name)
    finally:
        # Always release the input file, even when reading raised. Closing it
        # before writing also covers the case where the user wants to write
        # back to input_file itself.
        s.close_file()

    if bindata is None:
        print(f'key "{key_name}" not found in header (key names are case-sensitive)',file=sys.stderr)
        return -1

    with open(output_file,"wb") as fo:
        wn=fo.write(bindata)
        if wn!=len(bindata):
            print(f"write output file failed, tried to write {len(bindata)} bytes, only wrote {wn} bytes",file=sys.stderr)
            return -1

    if not cmdLine['quiet']:
        print(f"{key_name} saved to {output_file}, len={wn}")
    return 0

0 comments on commit 72bd8d1

Please sign in to comment.