diff --git a/snac/snac.py b/snac/snac.py index bb38506..d2fccbf 100644 --- a/snac/snac.py +++ b/snac/snac.py @@ -1,5 +1,6 @@ import json import math +import os from typing import List, Tuple import numpy as np @@ -100,10 +101,14 @@ def from_config(cls, config_path): def from_pretrained(cls, repo_id, **kwargs): from huggingface_hub import hf_hub_download - config_path = hf_hub_download(repo_id=repo_id, filename="config.json", **kwargs) - model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", **kwargs) - model = cls.from_config(config_path) - state_dict = torch.load(model_path, map_location="cpu") + if not os.path.isdir(repo_id): + config_path = hf_hub_download(repo_id=repo_id, filename="config.json", **kwargs) + model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", **kwargs) + model = cls.from_config(config_path) + state_dict = torch.load(model_path, map_location="cpu") + else: + model = cls.from_config(os.path.join(repo_id, "config.json")) + state_dict = torch.load(os.path.join(repo_id, "pytorch_model.bin"), map_location="cpu") model.load_state_dict(state_dict) model.eval() return model