Skip to content

Commit

Permalink
fix mistake
Browse files Browse the repository at this point in the history
The correct `download_model.py` was not uploaded prior to the last release, which prevented the transcriber and image summary generator from working properly. Also removed a debugging print statement from `gui.py`.
  • Loading branch information
BBC-Esq authored Aug 6, 2024
1 parent abe02d8 commit d02a06e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 47 deletions.
85 changes: 39 additions & 46 deletions src/download_model.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,44 @@
import os
import subprocess
from pathlib import Path
from PySide6.QtCore import QObject, Signal
from huggingface_hub import snapshot_download, HfApi
import logging
import threading

class ModelDownloadedSignal(QObject):
    """Qt object exposing a ``downloaded`` signal that carries two strings."""

    # Declared with two ``str`` arguments.
    downloaded = Signal(str, str)
logging.getLogger("transformers").setLevel(logging.ERROR)

model_downloaded_signal = ModelDownloadedSignal()

# Maps a model-type key to the name of its sub-folder.
MODEL_DIRECTORIES = dict(vector="Vector", chat="Chat")

# Describes one model by name and type, and derives its local directory and
# its huggingface.co URL from those two values.
class ModelDownloader:
# Store the model name and type; the directory path is computed lazily.
def __init__(self, model_name, model_type):
self.model_name = model_name
self.model_type = model_type
self._model_directory = None

# Folder name for the model: the model name with every "/" replaced by "--".
def get_model_directory_name(self):
return self.model_name.replace("/", "--")

# Build the path Models/<type folder>/<model folder> on first use and cache it.
# A model_type missing from MODEL_DIRECTORIES yields an empty type folder.
def get_model_directory(self):
if not self._model_directory:
model_type_dir = MODEL_DIRECTORIES.get(self.model_type, "")
self._model_directory = Path("Models") / model_type_dir / self.get_model_directory_name()
return self._model_directory

# URL of the model page on huggingface.co.
def get_model_url(self):
return f"https://huggingface.co/{self.model_name}"

# Resolve the URL and target directory, announce the download, and prepare
# a copy of the process environment for the clone.
def download_model(self):
model_url = self.get_model_url()
target_directory = self.get_model_directory()
print(f"Downloading {self.model_name}...")

# NOTE(review): this is a diff view with its +/- markers and indentation
# stripped; the rest of this method appears to resume further down, after
# an interleaved function header — confirm against the repository.
env = os.environ.copy()
env["GIT_CLONE_PROTECTION_ACTIVE"] = "false"
# Download only the top-level files of the repository `repo_id` into
# `local_dir`. Returns True on success; on any exception it prints
# "Failed to download model: ..." and returns False.
def download_model_files(repo_id, local_dir):
try:
api = HfApi()
files_list = api.list_repo_files(repo_id)

# NOTE(review): the lines from the next `try:` through the
# CalledProcessError handler look like the removed git-clone implementation
# that the diff view interleaved with this function (the +/- markers were
# stripped) — confirm against the repository before editing.
try:
subprocess.run(
["git", "clone", "--depth", "1", model_url, str(target_directory)],
check=True,
env=env
)
print("\033[92mModel downloaded and ready to use.\033[0m")
model_downloaded_signal.downloaded.emit(self.model_name, self.model_type)
except subprocess.CalledProcessError as e:
print(f"Command 'git clone' returned non-zero exit status {e.returncode}.")
# Keep only entries without a "/", i.e. paths that are not inside a sub-folder.
top_level_files = [f for f in files_list if '/' not in f]
if not top_level_files:
raise ValueError("No top-level files found in the repository.")
# Restrict the snapshot to exactly those file names.
# NOTE(review): local_dir_use_symlinks may be deprecated in newer
# huggingface_hub releases — confirm against the installed version.
snapshot_download(
repo_id,
local_dir=local_dir,
allow_patterns=top_level_files,
local_dir_use_symlinks=False,
)
print(f"Downloaded top-level files from {repo_id} to {local_dir}")
return True
# The ValueError raised above is also caught here and reported as a failure.
except Exception as e:
print(f"Failed to download model: {e}")
return False

def download_model(repo_id):
    """Start downloading ``repo_id`` on a background thread and return it.

    The destination is ``Models/vector/<folder>`` beside this source file,
    where ``<folder>`` is ``repo_id`` with each "/" replaced by "_". The
    directory is created (if missing) before the thread is started.

    Returns:
        The already-started ``threading.Thread``; call ``join()`` on it to
        wait for the download to finish.
    """
    # Separator is a single underscore (previously two dashes).
    destination = (
        Path(__file__).resolve().parent
        / "Models"
        / "vector"
        / repo_id.replace("/", "_")
    )
    os.makedirs(destination, exist_ok=True)

    worker = threading.Thread(
        target=download_model_files,
        args=(repo_id, destination),
    )
    worker.start()
    return worker

if __name__ == "__main__":
    # Manual check when run as a script: fetch one known repository and
    # block until the background download thread has finished.
    download_model("thenlper/gte-large").join()
2 changes: 1 addition & 1 deletion src/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from utilities import list_theme_files, make_theme_changer, load_stylesheet

# Print the current working directory
print(f"Current working directory: {os.getcwd()}")
# print(f"Current working directory: {os.getcwd()}")

# Check if we can write to the current directory
try:
Expand Down

0 comments on commit d02a06e

Please sign in to comment.