@@ -804,7 +804,8 @@ def fit(self, X, y=None):
804804
805805 if self .metric != "precomputed" :
806806 from sklearn .metrics .pairwise import pairwise_distances
807- X = pairwise_distances (X , metric = self .metric , * self .metric_params )
807+ Xd = X
808+ X = pairwise_distances (X , metric = self .metric )
808809 if self .method == "fasterpam" :
809810 result = fasterpam (X , self .n_clusters , self .max_iter , self .init , random_state = self .random_state )
810811 elif self .method == "fastpam1" :
@@ -834,7 +835,7 @@ def fit(self, X, y=None):
834835 if self .metric == "precomputed" :
835836 self .cluster_centers_ = None
836837 else :
837- self .cluster_centers_ = X [result .medoids ]
838+ self .cluster_centers_ = Xd [result .medoids ]
838839 return self
839840
840841 def predict (self , X ):
@@ -847,10 +848,12 @@ def predict(self, X):
847848 :rtype: array, shape = (n_query,)
848849 """
849850 if self .metric != "precomputed" :
850- from sklearn .metrics .pairwise import pairwise_distances
851- X = pairwise_distances (X , metric = self .metric )
852- import numpy as np
853- return np .argmin (X [:, self .medoid_indices_ ], axis = 1 )
851+ from sklearn .metrics .pairwise import pairwise_distances_argmin
852+ Y = self .cluster_centers_
853+ X = pairwise_distances_argmin (X , Y = Y , metric = self .metric )
854+ else :
855+ raise NotImplementedError ("This API is not safe to use with precomputed distances. Use the argmin of the distances to the medoids." )
856+ return self .medoid_indices_ [X ]
854857
855858 def transform (self , X ):
856859 """Transforms X to cluster-distance space.
0 commit comments