Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel initialization for k-means #1754

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

Hakdag97
Copy link
Collaborator

@Hakdag97 Hakdag97 commented Dec 17, 2024

Description

The bottleneck of k-means clustering (concerning runtime) is the initialization of centroids, which was previously built on a cost intensive serial algorithm. The aim of this pull request is to replace this algorithm by the more sophisticated k-means || initialization of centroids.

Issue/s resolved:

Changes proposed:

  • Complete new implementation of the initialization of centroids used for k-means, k-medians, and k-medoids
  • Adjustment of classes (like KMeans) to match with the new implementation

Type of change

  • Bug fix
  • New feature

Performance

  • Reducing the runtime of initialization of clustering algorithm in distributed and non-distributed mode with split=None and split not None by (at least) an order of magnitude (depending on the setting concerning, e.g., size of data and chosen parameters)

Does this change modify the behaviour of other functions? If so, which?

  • yes: the classes KMeans, KMedoids, KMedians and the function where are affected

Copy link
Contributor

Thank you for the PR!

@Hakdag97 Hakdag97 force-pushed the features/1674-Optimization_of_k-means_initialization branch from 0889330 to 7f860c4 Compare January 6, 2025 10:38
@github-actions github-actions bot added cluster core features testing Implementation of tests, or test-related issues labels Jan 6, 2025
Copy link
Contributor

github-actions bot commented Jan 6, 2025

Thank you for the PR!

Copy link

codecov bot commented Jan 6, 2025

Codecov Report

Attention: Patch coverage is 97.82609% with 1 line in your changes missing coverage. Please review.

Project coverage is 92.45%. Comparing base (87f2812) to head (7f860c4).

Files with missing lines Patch % Lines
heat/cluster/_kcluster.py 97.05% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1754      +/-   ##
==========================================
+ Coverage   92.26%   92.45%   +0.18%     
==========================================
  Files          84       84              
  Lines       12445    12438       -7     
==========================================
+ Hits        11482    11499      +17     
+ Misses        963      939      -24     
Flag Coverage Δ
unit 92.45% <97.82%> (+0.18%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@Hakdag97 Hakdag97 requested a review from mrfh92 January 6, 2025 11:59
@JuanPedroGHM JuanPedroGHM changed the title Features/1674 optimization of k means initialization Optimization of k means initialization Jan 13, 2025
@JuanPedroGHM JuanPedroGHM self-requested a review January 13, 2025 09:20
@ClaudiaComito ClaudiaComito added this to the 1.6 milestone Jan 13, 2025
@ClaudiaComito ClaudiaComito changed the title Optimization of k means initialization Parallel initialisation for k-means Jan 13, 2025
@Hakdag97 Hakdag97 changed the title Parallel initialisation for k-means Parallel initialization for k-means Jan 13, 2025
@Hakdag97 Hakdag97 removed the PR talk label Jan 20, 2025
@@ -169,7 +174,7 @@ def functional_value_(self) -> float:
"""
return self._functional_value

def fit(self, x: DNDarray):
def fit(self, x: DNDarray, weights: torch.tensor = 1):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the fit of batchparallel clustering. If we allow for specification of weights here, it must be a DNDarray, and below the corresponding local arrays would have to be used for the local clusterings.

@@ -233,6 +241,7 @@ def fit(self, x: DNDarray):
self.max_iter,
self.tol,
local_random_state,
weights,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above

@@ -102,7 +102,7 @@ def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray):

return new_cluster_centers

def fit(self, x: DNDarray) -> self:
def fit(self, x: DNDarray, oversampling: float = 100, iter_multiplier: float = 20) -> self:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there references for these default values? I had a look into dask and they use oversampling=2 as default.

# output format: scalar
#
# Iteratively fill the tensor storing the centroids
for _ in ht.arange(0, iter_multiplier * ht.log(init_cost)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for loop counters, the standard python range should be more effective

Copy link
Collaborator

@mrfh92 mrfh92 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks fine 👍. I have only some minor comments concerning some details.

prob = oversampling * min_distance / min_distance.sum()
# --> probability distribution with oversampling factor
# output format: vector
idx = ht.where(sample <= prob)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one may think about moving the creation of sample here to have some new sampling every step of the loop

# Evaluate distance between final centroids and data points
if centroids.shape[0] <= self.n_clusters:
raise ValueError(
"The oversampling factor and/or the number of iterations are chosen"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why two strings?
It may also be helpful to use sth like f"The oversampling factor (={oversampling}) ... " to give the user the values in the error message.

reclustered_centroids = torch.zeros(
(self.n_clusters, centroids.shape[1]),
dtype=x.dtype.torch_type(),
device=centroids.device,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

centroids.device.torch_device()?

ht.MPI_WORLD.Bcast(
reclustered_centroids, root=0
) # by default it is broadcasted from process 0
reclustered_centroids = ht.array(reclustered_centroids, split=x.split)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont know whether we want to split the centroids as they are probably only a small number. So this might produce overhead compared to having split=None here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
benchmark PR cluster core features testing Implementation of tests, or test-related issues
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants