Saltar al contenido

Scikit Learn GridSearchCV sin validación cruzada (aprendizaje no supervisado)

Solución:

Después de mucha búsqueda, pude encontrar este hilo. Parece que puede deshacerse de la validación cruzada en GridSearchCV si usa:

cv=[(slice(None), slice(None))]

He probado esto con mi propia versión codificada de búsqueda de cuadrícula sin validación cruzada y obtengo los mismos resultados de ambos métodos. Estoy publicando esta respuesta a mi propia pregunta en caso de que otros tengan el mismo problema.

Editar: para responder a la pregunta de jjrr en los comentarios, aquí hay un ejemplo de caso de uso:

from sklearn.metrics import silhouette_score as sc

def cv_silhouette_scorer(estimator, X):
    estimator.fit(X)
    cluster_labels = estimator.labels_
    num_labels = len(set(cluster_labels))
    num_samples = len(X.index)
    if num_labels == 1 or num_labels == num_samples:
        return -1
    else:
        return sc(X, cluster_labels)

cv = [(slice(None), slice(None))]
gs = GridSearchCV(estimator=sklearn.cluster.MeanShift(), param_grid=param_dict, 
                  scoring=cv_silhouette_scorer, cv=cv, n_jobs=-1)
gs.fit(df[cols_of_interest])

Voy a responder a su pregunta ya que parece que aún no ha sido respondida. Usando el método del paralelismo con el for bucle, puede utilizar el multiprocessing módulo.

from multiprocessing.dummy import Pool
from sklearn.cluster import KMeans
import functools

kmeans = KMeans()

# define your custom function for passing into each thread
def find_cluster(n_clusters, kmeans, X):
    from sklearn.metrics import silhouette_score  # you want to import in the scorer in your function

    kmeans.set_params(n_clusters=n_clusters)  # set n_cluster
    labels = kmeans.fit_predict(X)  # fit & predict
    score = silhouette_score(X, labels)  # get the score

    return score

# Now's the parallel implementation
clusters = [3, 4, 5]
pool = Pool()
results = pool.map(functools.partial(find_cluster, kmeans=kmeans, X=X), clusters)
pool.close()
pool.join()

# print the results
print(results)  # will print a list of scores that corresponds to the clusters list

Creo que usar cv = ShuffleSplit (test_size = 0.20, n_splits = 1) con n_splits = 1 es una mejor solución como esta publicación sugerida

¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)



Utiliza Nuestro Buscador

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *