Saltar al contenido

¿Cómo funciona el parámetro class_weight en scikit-learn?

Solución:

En primer lugar, puede que no sea bueno limitarse a recordar solo. Simplemente puede lograr una recuperación del 100% clasificando todo como la clase positiva. Por lo general, sugiero usar AUC para seleccionar parámetros y luego encontrar un umbral para el punto de operación (digamos un nivel de precisión dado) que le interesa.

Por cuanto class_weight funciona: Penaliza errores en muestras de class[i] con class_weight[i] en lugar de 1. Por lo tanto, un peso de clase más alto significa que desea poner más énfasis en una clase. Por lo que dice, parece que la clase 0 es 19 veces más frecuente que la clase 1. Por lo tanto, debe aumentar la class_weight de la clase 1 en relación con la clase 0, digamos {0: .1, 1: .9}. Si el class_weight no suma 1, básicamente cambiará el parámetro de regularización.

Por cuanto class_weight="auto" funciona, puede echar un vistazo a esta discusión. En la versión dev puedes usar class_weight="balanced", que es más fácil de entender: básicamente significa replicar la clase más pequeña hasta tener tantas muestras como en la más grande, pero de forma implícita.

La primera respuesta es buena para comprender cómo funciona. Pero quería entender cómo debería usarlo en la práctica.

RESUMEN

  • para datos moderadamente desequilibrados SIN ruido, no hay mucha diferencia en la aplicación de ponderaciones de clase
  • para datos moderadamente desequilibrados CON ruido y fuertemente desequilibrados, es mejor aplicar ponderaciones de clase
  • param class_weight="balanced" funciona decente en ausencia de que desee optimizar manualmente
  • con class_weight="balanced" captura más eventos verdaderos (mayor recuperación VERDADERA) pero también es más probable que reciba alertas falsas (menor precisión VERDADERA)
    • como resultado, el% TRUE total podría ser más alto que el real debido a todos los falsos positivos
    • Las AUC podrían confundirlo aquí si las falsas alarmas son un problema
  • no es necesario cambiar el umbral de decisión al% de desequilibrio, incluso para un desequilibrio fuerte, está bien mantener 0.5 (o en algún lugar alrededor de eso, dependiendo de lo que necesite)

nótese bien

El resultado puede diferir al utilizar RF o GBM. sklearn no tiene class_weight="balanced" para GBM pero lightgbm tiene LGBMClassifier(is_unbalance=False)

CÓDIGO

# scikit-learn==0.21.3
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np
import pandas as pd

# case: moderate imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.8]) #,flip_y=0.1,class_sep=0.5)
np.mean(y) # 0.2

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.184
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X).mean() # 0.296 => seems to make things worse?
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.292 => seems to make things worse?

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.83
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X)) # 0.86 => about the same
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.86 => about the same

# case: strong imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.95])
np.mean(y) # 0.06

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.02
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X).mean() # 0.25 => huh??
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.22 => huh??
(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).mean() # same as last

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.64
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X)) # 0.84 => much better
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.85 => similar to manual
roc_auc_score(y,(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).astype(int)) # same as last

print(classification_report(y,LogisticRegression(C=1e9).fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True,normalize="index") # few prediced TRUE with only 28% TRUE recall and 86% TRUE precision so 6%*28%~=2%

print(classification_report(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True,normalize="index") # 88% TRUE recall but also lot of false positives with only 23% TRUE precision, making total predicted % TRUE > actual % TRUE
¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)



Utiliza Nuestro Buscador

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *