-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The correct ```download_model.py``` was not uploaded prior to last release, thus preventing the transcriber and image summary generator from working properly. Also removed a de-bugging print statement from ```gui.py```.
- Loading branch information
Showing
2 changed files
with
40 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,51 +1,44 @@ | ||
import os | ||
import subprocess | ||
from pathlib import Path | ||
from PySide6.QtCore import QObject, Signal | ||
from huggingface_hub import snapshot_download, HfApi | ||
import logging | ||
import threading | ||
|
||
class ModelDownloadedSignal(QObject): | ||
downloaded = Signal(str, str) | ||
logging.getLogger("transformers").setLevel(logging.ERROR) | ||
|
||
model_downloaded_signal = ModelDownloadedSignal() | ||
|
||
MODEL_DIRECTORIES = { | ||
"vector": "Vector", | ||
"chat": "Chat" | ||
} | ||
|
||
class ModelDownloader: | ||
def __init__(self, model_name, model_type): | ||
self.model_name = model_name | ||
self.model_type = model_type | ||
self._model_directory = None | ||
|
||
def get_model_directory_name(self): | ||
return self.model_name.replace("/", "--") | ||
|
||
def get_model_directory(self): | ||
if not self._model_directory: | ||
model_type_dir = MODEL_DIRECTORIES.get(self.model_type, "") | ||
self._model_directory = Path("Models") / model_type_dir / self.get_model_directory_name() | ||
return self._model_directory | ||
|
||
def get_model_url(self): | ||
return f"https://huggingface.co/{self.model_name}" | ||
|
||
def download_model(self): | ||
model_url = self.get_model_url() | ||
target_directory = self.get_model_directory() | ||
print(f"Downloading {self.model_name}...") | ||
|
||
env = os.environ.copy() | ||
env["GIT_CLONE_PROTECTION_ACTIVE"] = "false" | ||
def download_model_files(repo_id, local_dir): | ||
try: | ||
api = HfApi() | ||
files_list = api.list_repo_files(repo_id) | ||
|
||
try: | ||
subprocess.run( | ||
["git", "clone", "--depth", "1", model_url, str(target_directory)], | ||
check=True, | ||
env=env | ||
) | ||
print("\033[92mModel downloaded and ready to use.\033[0m") | ||
model_downloaded_signal.downloaded.emit(self.model_name, self.model_type) | ||
except subprocess.CalledProcessError as e: | ||
print(f"Command 'git clone' returned non-zero exit status {e.returncode}.") | ||
top_level_files = [f for f in files_list if '/' not in f] | ||
if not top_level_files: | ||
raise ValueError("No top-level files found in the repository.") | ||
snapshot_download( | ||
repo_id, | ||
local_dir=local_dir, | ||
allow_patterns=top_level_files, | ||
local_dir_use_symlinks=False, | ||
) | ||
print(f"Downloaded top-level files from {repo_id} to {local_dir}") | ||
return True | ||
except Exception as e: | ||
print(f"Failed to download model: {e}") | ||
return False | ||
|
||
def download_model(repo_id): | ||
folder_name = repo_id.replace('/', '_') # CHANGED FROM TWO DASHES | ||
current_dir = Path(__file__).resolve().parent | ||
models_dir = current_dir / "Models" / "vector" | ||
local_dir = models_dir / folder_name | ||
|
||
os.makedirs(local_dir, exist_ok=True) | ||
|
||
thread = threading.Thread(target=download_model_files, args=(repo_id, local_dir)) | ||
thread.start() | ||
return thread | ||
|
||
if __name__ == "__main__": | ||
test_repo_id = "thenlper/gte-large" | ||
download_thread = download_model(test_repo_id) | ||
download_thread.join() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters