Skip to content

Commit

Permalink
Create safetensors_tags.py
Browse files Browse the repository at this point in the history
Utility to read ['ss_tag_frequency'] if it exists in the metadata, then pretty-print the output
  • Loading branch information
duanemoody authored Nov 14, 2023
1 parent aa6a159 commit 94502e9
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions safetensors_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import json, sys
from safetensors_file import SafeTensorsFile
from safetensors_worker import _ParseMore
"""
This script reads ss_tag_frequency from the "__metadata__" section of a safetensors file's JSON header, then prints it as indented JSON.
TODO: gracefully error out on safetensors files without ["__metadata__"]["ss_tag_frequency"]
"""

def get_tags(tensorsfile: str) -> str:
    """Return the ``ss_tag_frequency`` metadata of a safetensors file as JSON text.

    The file's header is read, its ``"__metadata__"`` section is looked up,
    and the ``"ss_tag_frequency"`` entry of that section is serialized with
    4-space indentation and non-ASCII characters left unescaped.

    Args:
        tensorsfile: Path of the safetensors file to inspect.

    Returns:
        The tag-frequency data formatted as an indented JSON string.

    Raises:
        KeyError: If the header has no ``"__metadata__"`` section, or that
            section has no ``"ss_tag_frequency"`` entry.
    """
    # quiet=True: per the original author, this omits a leading non-JSON
    # line of output -- NOTE(review): confirm against SafeTensorsFile.
    st_file = SafeTensorsFile.open_file(tensorsfile, quiet=True)
    header = st_file.get_header()

    try:
        metadata = header["__metadata__"]
    except KeyError:
        raise KeyError('header has no "__metadata__" section') from None

    # NOTE(review): the original comment said this "pretty prints" the
    # metadata. It is applied to the metadata before the tag entry is read,
    # so it presumably decodes JSON-encoded string values in place -- confirm
    # against safetensors_worker._ParseMore.
    _ParseMore(metadata)

    try:
        tag_frequency = metadata["ss_tag_frequency"]
    except KeyError:
        raise KeyError('metadata has no "ss_tag_frequency" entry') from None

    # Separators are kept exactly as before so the output stays byte-identical
    # (note that ', ' combined with indent leaves a trailing space after each
    # comma at end of line).
    return json.dumps(
        tag_frequency,
        ensure_ascii=False,
        separators=(', ', ': '),
        indent=4,
    )

if __name__ == "__main__":
    # Guarded so that importing this module to reuse get_tags() does not
    # run the command-line behavior.
    if len(sys.argv) != 2:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(f"usage: {sys.argv[0]} <file.safetensors>")

    tensorsfile = sys.argv[1]
    try:
        hdata = get_tags(tensorsfile)
    except KeyError as err:
        # Files without "__metadata__" / "ss_tag_frequency" end here with a
        # short message instead of a traceback.
        sys.exit(f"{tensorsfile}: no tag frequency data found ({err.args[0]})")
    print(hdata)

0 comments on commit 94502e9

Please sign in to comment.