Skip to content

Commit

Permalink
Merge pull request #7 from duanemoody/main
Browse files Browse the repository at this point in the history
Add context manager protocol support, rudimentary tags reader
  • Loading branch information
by321 authored Dec 11, 2023
2 parents 63c3586 + 974669d commit f1ee506
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
8 changes: 8 additions & 0 deletions safetensors_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ def __init__(self):

def __del__(self):
    """Finalizer: hand cleanup over to close_file() when the object is destroyed."""
    self.close_file()

def __enter__(self):
    """Enter a ``with`` block; the object itself is bound to the ``as`` target."""
    return self

def __exit__(self, exc_type, exc_value, traceback):
    """Leave a ``with`` block by closing the file.

    Cleanup is delegated to close_file() -- the same path __del__ uses --
    instead of closing ``self.f`` directly, so both teardown routes stay
    in agreement about how the object is released.

    The exception arguments are ignored and None is returned, so any
    exception raised inside the ``with`` body propagates to the caller.
    """
    # close_file() performs its own "is a file open?" check, so no guard
    # is needed here.
    self.close_file()


def close_file(self):
if self.f is not None:
Expand Down
20 changes: 20 additions & 0 deletions safetensors_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import json, sys
from safetensors_file import SafeTensorsFile
from safetensors_worker import _ParseMore
"""
This script extracts the JSON header's ss_tag_frequency from a safetensors file, then outputs it.
TODO: gracefully error out on safetensors files without ["__metadata__"]["ss_tag_frequency"]
"""

def get_tags(tensorsfile: str) -> str:
    """Return a safetensors file's ``ss_tag_frequency`` metadata as JSON text.

    Args:
        tensorsfile: Name of the safetensors file, passed straight to
            ``SafeTensorsFile.open_file``.

    Returns:
        The ``ss_tag_frequency`` entry found under the header's
        ``__metadata__`` key, serialized with a 4-space indent and with
        non-ASCII characters left unescaped.

    Raises:
        KeyError: If the header has no ``__metadata__`` key, or the
            metadata has no ``ss_tag_frequency`` key.
    """
    # quiet=True: per the original author's note, this omits the first
    # non-JSON line that would otherwise be emitted.
    # The with-statement uses SafeTensorsFile's __enter__/__exit__ so the
    # file is closed even when a key lookup below raises.
    # NOTE(review): assumes open_file() returns a SafeTensorsFile instance
    # (it is already used as one via get_header()) -- confirm.
    with SafeTensorsFile.open_file(tensorsfile, quiet=True) as s:
        js = s.get_header()
        md = js["__metadata__"]
        # Original note: "pretty print the metadata". It runs before the
        # lookup, so it presumably rewrites md in place -- confirm against
        # safetensors_worker._ParseMore.
        _ParseMore(md)
        stf = md["ss_tag_frequency"]
    return json.dumps(stf, ensure_ascii=False, separators=(', ', ': '), indent=4)

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse get_tags) does
    # not read sys.argv or print anything.
    if len(sys.argv) < 2:
        # sys.exit with a string writes it to stderr and exits with status 1.
        sys.exit(f"usage: {sys.argv[0]} <file.safetensors>")
    tensorsfile = sys.argv[1]
    try:
        hdata = get_tags(tensorsfile)
    except KeyError as missing:
        # Files without ["__metadata__"]["ss_tag_frequency"] end up here;
        # report the absent key instead of dumping a traceback.
        sys.exit(f"{tensorsfile}: header has no {missing} entry")
    print(hdata)

0 comments on commit f1ee506

Please sign in to comment.