Skip to content

Commit

Permalink
Update version number and fix formatting in calculations.py and logger.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cobanov committed Mar 13, 2024
1 parent 1adc725 commit fbd37f5
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 11 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="tasnif",
version="0.1.9",
version="0.1.10",
install_requires=[
"numpy",
"scikit-learn",
Expand Down
5 changes: 4 additions & 1 deletion tasnif/calculations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from img2vec_pytorch import Img2Vec
from scipy.cluster.vq import kmeans2
from sklearn.decomposition import PCA

from .logger import info


Expand Down Expand Up @@ -67,7 +68,9 @@ def calculate_kmeans(pca_embeddings, num_classes, iter=10):
)

try:
centroid, labels = kmeans2(data=pca_embeddings, k=num_classes, minit="points", iter=iter)
centroid, labels = kmeans2(
data=pca_embeddings, k=num_classes, minit="points", iter=iter
)
counts = np.bincount(labels)
info("KMeans calculated.")
return centroid, labels, counts
Expand Down
8 changes: 6 additions & 2 deletions tasnif/logger.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import logging

from rich.logging import RichHandler

log_format = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(
level="INFO", format=log_format, datefmt="[%X]", handlers=[RichHandler(show_time=False, show_level=False)]
level="INFO",
format=log_format,
datefmt="[%X]",
handlers=[RichHandler(show_time=False, show_level=False)],
)


Expand All @@ -15,6 +19,6 @@ def error(msg):
logging.error(msg)


if __name__ == '__main__':
if __name__ == "__main__":
info("info message")
error("error message")
23 changes: 16 additions & 7 deletions tasnif/tasnif.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
import shutil
import warnings
from itertools import compress

import numpy as np
from tqdm import tqdm

from .calculations import calculate_kmeans, calculate_pca, get_embeddings
from .logger import error, info
from .utils import (
create_dir,
create_image_grid,
read_images_from_directory,
read_with_pil,
)
from .logger import info, error

warnings.filterwarnings("ignore")

Expand Down Expand Up @@ -52,15 +54,19 @@ def calculate(self, pca=True, iter=10):
"""

if not self.images:
raise ValueError("The images list can not be empty. Please call the read method before calculating.")
raise ValueError(
"The images list can not be empty. Please call the read method before calculating."
)

self.embeddings = get_embeddings(use_gpu=self.use_gpu, images=self.images)
if pca:
self.pca_embeddings = calculate_pca(self.embeddings, self.pca_dim)
self.centroid, self.labels, self.counts = calculate_kmeans(self.pca_embeddings, self.num_classes, iter = iter)
self.centroid, self.labels, self.counts = calculate_kmeans(
self.pca_embeddings, self.num_classes, iter=iter
)
else:
self.centroid, self.labels, self.counts = calculate_kmeans(
self.embeddings, self.num_classes, iter = iter
self.embeddings, self.num_classes, iter=iter
)

def export(self, output_folder="./"):
Expand All @@ -76,7 +82,6 @@ def export(self, output_folder="./"):
create_dir(project_path)

for label_number in tqdm(range(self.num_classes)):

label_mask = self.labels == label_number
path_images = list(compress(self.image_paths, label_mask))
target_directory = os.path.join(project_path, f"cluster_{label_number}")
Expand Down Expand Up @@ -106,8 +111,12 @@ def export_embeddings(self, output_folder="./"):
"""

if self.embeddings is None:
raise ValueError("Embeddings can not be empty. Please call the calculate method first.")
raise ValueError(
"Embeddings can not be empty. Please call the calculate method first."
)

embeddings_path = os.path.join(output_folder, f"{self.project_name}_embeddings.npy")
embeddings_path = os.path.join(
output_folder, f"{self.project_name}_embeddings.npy"
)
np.save(embeddings_path, self.embeddings)
info(f"Embeddings have been saved to {embeddings_path}")
2 changes: 2 additions & 0 deletions tasnif/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import glob
import os
from pathlib import Path

import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

from .logger import info


Expand Down

0 comments on commit fbd37f5

Please sign in to comment.