Parallel initialization for k-means #1754
base: main
Conversation
Thank you for the PR!
Codecov Report
Attention: Patch coverage is …

@@            Coverage Diff             @@
##             main    #1754      +/-   ##
==========================================
+ Coverage   92.26%   92.45%   +0.18%
==========================================
  Files          84       84
  Lines       12445    12438       -7
==========================================
+ Hits        11482    11499      +17
+ Misses        963      939      -24
@@ -169,7 +174,7 @@ def functional_value_(self) -> float:
         """
         return self._functional_value

-    def fit(self, x: DNDarray):
+    def fit(self, x: DNDarray, weights: torch.tensor = 1):
This is the fit of batch-parallel clustering. If we allow specification of weights here, they must be a DNDarray, and below the corresponding local arrays would have to be used for the local clusterings.
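A rough sketch of what is meant (the helper name, signature, and exact handling are assumptions, not the actual implementation): accept the weights only as a DNDarray with the same split as x and hand the process-local torch tensor to the local clustering.

```python
from typing import Optional

import torch
import heat as ht


def local_weights(x: ht.DNDarray, weights: Optional[ht.DNDarray] = None) -> Optional[torch.Tensor]:
    """Hypothetical helper: validate weights and return the process-local part."""
    if weights is None:
        return None
    if not isinstance(weights, ht.DNDarray) or weights.split != x.split:
        raise ValueError("weights must be a DNDarray with the same split as x")
    return weights.larray  # process-local torch tensor used by the local clustering
```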
@@ -233,6 +241,7 @@ def fit(self, x: DNDarray):
             self.max_iter,
             self.tol,
             local_random_state,
+            weights,
see above
@@ -102,7 +102,7 @@ def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray):

         return new_cluster_centers

-    def fit(self, x: DNDarray) -> self:
+    def fit(self, x: DNDarray, oversampling: float = 100, iter_multiplier: float = 20) -> self:
Are there references for these default values? I had a look into Dask, and they use oversampling=2 as default.
            # output format: scalar
            #
            # Iteratively fill the tensor storing the centroids
            for _ in ht.arange(0, iter_multiplier * ht.log(init_cost)):
For loop counters, the standard Python range should be more efficient.
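For instance (a minimal sketch; the toy values and the cast to a Python scalar are assumptions about how the bound would be handled):

```python
import math

# Toy stand-ins for quantities computed earlier in the routine (assumptions):
iter_multiplier = 20.0
init_cost = 1.0e6  # cost of the initial clustering, as a Python scalar

# A plain Python range drives the loop without creating a DNDarray
# just for the loop counter.
for _ in range(int(iter_multiplier * math.log(init_cost))):
    pass  # ... fill the tensor storing the centroid candidates ...
```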
Looks fine 👍. I only have some minor comments concerning details.
                prob = oversampling * min_distance / min_distance.sum()
                # --> probability distribution with oversampling factor
                # output format: vector
                idx = ht.where(sample <= prob)
One may think about moving the creation of sample here, so that a new sample is drawn in every step of the loop.
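Roughly like this (a sketch only; the toy arrays and loop count stand in for the variables of the actual routine):

```python
import heat as ht

# Toy stand-ins (assumptions) for variables that exist in the actual routine:
x = ht.random.rand(1000, 2, split=0)
min_distance = ht.random.rand(1000, split=0)
oversampling, n_rounds = 2.0, 5

for _ in range(n_rounds):
    # Drawing sample inside the loop re-randomizes the candidate
    # selection in every round.
    sample = ht.random.rand(x.shape[0], split=x.split)
    prob = oversampling * min_distance / min_distance.sum()
    idx = ht.where(sample <= prob)
    # ... add x[idx] to the candidate centroids and update min_distance ...
```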
            # Evaluate distance between final centroids and data points
            if centroids.shape[0] <= self.n_clusters:
                raise ValueError(
                    "The oversampling factor and/or the number of iterations are chosen"
Why two strings?
It may also be helpful to use something like f"The oversampling factor (={oversampling}) ..." to give the user the actual values in the error message.
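Something along these lines (just a sketch; the exact wording is an assumption):

```python
if centroids.shape[0] <= self.n_clusters:
    raise ValueError(
        f"Only {centroids.shape[0]} candidate centroids were sampled for "
        f"{self.n_clusters} clusters; the oversampling factor "
        f"(={oversampling}) and/or the number of iterations "
        f"(={iter_multiplier}) may be chosen too small."
    )
```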
            reclustered_centroids = torch.zeros(
                (self.n_clusters, centroids.shape[1]),
                dtype=x.dtype.torch_type(),
                device=centroids.device,
centroids.device.torch_device()?
        ht.MPI_WORLD.Bcast(
            reclustered_centroids, root=0
        )  # by default it is broadcasted from process 0
        reclustered_centroids = ht.array(reclustered_centroids, split=x.split)
I don't know whether we want to split the centroids, as there are probably only a few of them. So this might produce overhead compared to having split=None here.
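I.e., something like (sketch):

```python
# Every process already holds the full reclustered_centroids after the Bcast,
# so wrapping them with split=None avoids redistributing a small array.
ht.MPI_WORLD.Bcast(reclustered_centroids, root=0)
reclustered_centroids = ht.array(reclustered_centroids, split=None)
```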
Description
The runtime bottleneck of k-means clustering is the initialization of the centroids, which was previously based on a cost-intensive serial algorithm. The aim of this pull request is to replace this algorithm with the more sophisticated k-means|| initialization of the centroids.
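For illustration, a rough single-process sketch of the k-means|| idea (Bahmani et al., 2012); this is not the PR's distributed Heat implementation, and the fixed number of rounds as well as the final reduction to k centers are simplifications:

```python
import numpy as np


def kmeans_parallel_init(x, k, oversampling=2.0, rounds=5, seed=None):
    """Illustrative k-means|| initialization on a single process."""
    rng = np.random.default_rng(seed)
    # Start from one uniformly sampled data point.
    centers = x[rng.integers(x.shape[0])][None, :]
    for _ in range(rounds):
        # Squared distance of every point to its closest current center.
        d2 = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(-1).min(1)
        # Each point becomes a candidate independently, with probability
        # proportional to oversampling * d2 / current cost.
        prob = np.minimum(oversampling * d2 / d2.sum(), 1.0)
        picked = rng.random(x.shape[0]) <= prob
        centers = np.vstack([centers, x[picked]])
    # The full algorithm then reclusters the (weighted) candidates down to
    # exactly k centers; that step is simplified away here.
    return centers[:k]
```

In the distributed setting the distance evaluation and the candidate sampling in each round run independently on the local data of every process, which is what makes this initialization cheaper than serial k-means++ seeding.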
Issue/s resolved:
Changes proposed:
Type of change
Performance: the runtime of the initialization improves for both split=None and split not None by (at least) an order of magnitude (depending on the setting, e.g., the size of the data and the chosen parameters).

Does this change modify the behaviour of other functions? If so, which?