Skip to content

Commit d97c99b

Browse files
committed
Fix predict
1 parent 3611a24 commit d97c99b

1 file changed

Lines changed: 9 additions & 6 deletions

File tree

kmedoids/__init__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,8 @@ def fit(self, X, y=None):
804804

805805
if self.metric != "precomputed":
806806
from sklearn.metrics.pairwise import pairwise_distances
807-
X = pairwise_distances(X, metric=self.metric, *self.metric_params)
807+
Xd = X
808+
X = pairwise_distances(X, metric=self.metric)
808809
if self.method == "fasterpam":
809810
result = fasterpam(X, self.n_clusters, self.max_iter, self.init, random_state=self.random_state)
810811
elif self.method == "fastpam1":
@@ -834,7 +835,7 @@ def fit(self, X, y=None):
834835
if self.metric == "precomputed":
835836
self.cluster_centers_ = None
836837
else:
837-
self.cluster_centers_ = X[result.medoids]
838+
self.cluster_centers_ = Xd[result.medoids]
838839
return self
839840

840841
def predict(self, X):
@@ -847,10 +848,12 @@ def predict(self, X):
847848
:rtype: array, shape = (n_query,)
848849
"""
849850
if self.metric != "precomputed":
850-
from sklearn.metrics.pairwise import pairwise_distances
851-
X = pairwise_distances(X, metric=self.metric)
852-
import numpy as np
853-
return np.argmin(X[:, self.medoid_indices_], axis=1)
851+
from sklearn.metrics.pairwise import pairwise_distances_argmin
852+
Y = self.cluster_centers_
853+
X = pairwise_distances_argmin(X, Y=Y, metric=self.metric)
854+
else:
855+
raise NotImplementedError("This API is not safe to use with precomputed distances. Use the argmin of the distances to the medoids.")
856+
return self.medoid_indices_[X]
854857

855858
def transform(self, X):
856859
"""Transforms X to cluster-distance space.

0 commit comments

Comments
 (0)